Browse Source

feat(megbrain): add correlation opr

GitOrigin-RevId: 6d44598891
release-1.4
Megvii Engine Team 4 years ago
parent
commit
4fa9096d20
12 changed files with 476 additions and 18 deletions
  1. +1
    -1
      dnn/src/cuda/correlation/correlation_cuda.cu
  2. +3
    -4
      dnn/src/naive/handle.cpp
  3. +0
    -13
      dnn/test/common/correlation.h
  4. +37
    -0
      imperative/python/megengine/functional/vision.py
  5. +100
    -0
      imperative/python/test/unit/functional/test_functional.py
  6. +16
    -0
      imperative/src/impl/ops/specializations.cpp
  7. +1
    -0
      src/core/include/megbrain/ir/ops.td
  8. +109
    -0
      src/opr/impl/dnn/correlation.cpp
  9. +5
    -0
      src/opr/impl/dnn/dnn.sereg.h
  10. +69
    -0
      src/opr/include/megbrain/opr/dnn/correlation.h
  11. +134
    -0
      src/opr/test/dnn/correlation.cpp
  12. +1
    -0
      src/serialization/impl/schema.fbs

+ 1
- 1
dnn/src/cuda/correlation/correlation_cuda.cu View File

@@ -1,5 +1,5 @@
/**
* \file dnn/src/cuda/roi_align/roi_align.cu
* \file dnn/src/cuda/correlation/correlation_cuda.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.


+ 3
- 4
dnn/src/naive/handle.cpp View File

@@ -28,9 +28,9 @@
#include "src/naive/convolution/opr_impl.h"
#include "src/naive/convolution3d/opr_impl.h"
#include "src/naive/convpooling/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/cumsum/opr_impl.h"
#include "src/naive/cvt_color/opr_impl.h"
#include "src/naive/correlation/opr_impl.h"
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
@@ -38,6 +38,7 @@
#include "src/naive/elemwise/opr_impl.h"
#include "src/naive/elemwise_multi_type/opr_impl.h"
#include "src/naive/eye/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/flip/opr_impl.h"
#include "src/naive/gaussian_blur/opr_impl.h"
#include "src/naive/group_local/opr_impl.h"
@@ -75,13 +76,11 @@
#include "src/naive/tensor_remap/opr_impl.h"
#include "src/naive/tile/opr_impl.h"
#include "src/naive/topk/opr_impl.h"
#include "src/naive/tqt/opr_impl.h"
#include "src/naive/transpose/opr_impl.h"
#include "src/naive/type_cvt/opr_impl.h"
#include "src/naive/warp_affine/opr_impl.h"
#include "src/naive/warp_perspective/opr_impl.h"
#include "src/naive/remap/opr_impl.h"
#include "src/naive/fake_quant/opr_impl.h"
#include "src/naive/tqt/opr_impl.h"

static size_t g_image2d_pitch_alignment = 1;



+ 0
- 13
dnn/test/common/correlation.h View File

@@ -45,19 +45,6 @@ inline static std::vector<TestArg> get_args() {
TensorShape{batch_size, channel, height, width},
TensorShape{batch_size, channel, height, width});

// cur_param.is_multiply = false;
// cur_param.kernel_size = 1;
// cur_param.max_displacement = 2;
// cur_param.pad_size = 1;
// cur_param.stride1 = 1;
// cur_param.stride2 = 1;
// cur_param.format =
// megdnn::param::Correlation::Format::NCHW;

// args.emplace_back(
// cur_param,
// TensorShape{batch_size, channel, height, width},
// TensorShape{batch_size, channel, height, width});
}
}
}


+ 37
- 0
imperative/python/megengine/functional/vision.py View File

@@ -106,6 +106,43 @@ def roi_pooling(
return result


def correlation(
    data1: Tensor,
    data2: Tensor,
    kernel_size: int = 1,
    max_displacement: int = 1,
    stride1: int = 1,
    stride2: int = 1,
    pad_size: int = 0,
    is_multiply: bool = True,
) -> Tensor:
    """Applies correlation to two inputs.

    The inputs are passed to a ``Correlation`` op that is always built with
    ``format="NCHW"``.

    :param data1: first input of the correlation; format must be NCHW.
    :param data2: second input of the correlation; format must be NCHW.
    :param kernel_size: (non-negative int, optional, default=1) kernel size
        of the correlation; must be an odd number.
    :param max_displacement: (non-negative int, optional, default=1) max
        displacement of the correlation.
    :param stride1: (non-negative int, optional, default=1) stride1 quantizes
        ``data1`` globally.
    :param stride2: (non-negative int, optional, default=1) stride2 quantizes
        ``data2`` within the neighborhood centered around ``data1``.
    :param pad_size: (non-negative int, optional, default=0) padding for the
        correlation.
    :param is_multiply: (bool, optional, default=True) operation type, either
        multiplication or absolute difference.
    :return: the first output produced by the ``Correlation`` op.
    """
    op_kwargs = dict(
        format="NCHW",
        kernel_size=kernel_size,
        max_displacement=max_displacement,
        stride1=stride1,
        stride2=stride2,
        pad_size=pad_size,
        is_multiply=is_multiply,
    )
    corr_op = builtin.Correlation(**op_kwargs)

    # Only the leading output is handed back; any further outputs are dropped.
    first_output, *_ = apply(corr_op, data1, data2)
    return first_output


def roi_align(
inp: Tensor,
rois: Tensor,


+ 100
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -228,6 +228,106 @@ def test_roi_align():
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape)


def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
    """Builds a pair of equally shaped input tensors for the correlation tests.

    :param random: if True, both tensors are drawn from ``np.random.randn``;
        otherwise both are filled with ``constant``.
    :param constant: fill value used when ``random`` is False.
    :param image_shape: shape of each generated tensor.
    :return: tuple ``(inp_feat1, inp_feat2)`` of tensors.
    """
    if random:
        # randn takes the dimensions as separate arguments; unpacking the whole
        # shape keeps this branch in step with the constant branch for any rank.
        inp_feat1 = np.random.randn(*image_shape)
        inp_feat2 = np.random.randn(*image_shape)
    else:
        inp_feat1 = np.ones(image_shape) * constant
        inp_feat2 = np.ones(image_shape) * constant

    return tensor(inp_feat1), tensor(inp_feat2)


def test_correlation():
    """Exercises ``F.vision.correlation`` on random and constant inputs."""
    # case 0: the gradient w.r.t. data1 must have the shape of data1
    data1, data2 = _gen_correlation()
    grad = Grad().wrt(data1, callback=_save_to(data1))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=5,
        max_displacement=4,
        stride1=2,
        stride2=2,
        pad_size=2,
        is_multiply=True,
    )

    grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)

    # case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=True,
    )
    assert abs(out_feat.sum() - 1) < 1e-9

    # case 2: subtracting an image from itself must give zero
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=0,
        stride1=1,
        stride2=1,
        pad_size=0,
        is_multiply=False,
    )
    # two-sided check: a large negative sum must fail as well
    assert abs(out_feat.sum()) < 1e-9

    # case 3: constant inputs 2.0 and 1.0 in absolute-difference mode
    data1, _ = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=2.0
    )
    _, data2 = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=1.0
    )

    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=2,
        stride1=1,
        stride2=2,
        pad_size=0,
        is_multiply=False,
    )
    assert abs(out_feat.mean() - 1) < 1e-9


def test_roi_pooling():
inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))


+ 16
- 0
imperative/src/impl/ops/specializations.cpp View File

@@ -19,6 +19,7 @@
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
@@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign)
.fallback();
}} // roi_align

namespace { namespace correlation {
// Turns a Correlation op definition plus its two input vars into the
// corresponding graph operator; the node is named via the op's make_name().
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& corr_def = static_cast<const Correlation&>(def);
    OperatorNodeConfig config{corr_def.make_name()};
    return opr::Correlation::make(inputs[0], inputs[1], corr_def.param(),
                                  config);
}
OP_TRAIT_REG(Correlation, Correlation)
        .apply_on_var_node(apply_on_var_node)
        .fallback();
}} // correlation

#if MGB_CUDA
namespace { namespace nvof {
auto apply_on_var_node(


+ 1
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;

def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>;

def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;



+ 109
- 0
src/opr/impl/dnn/correlation.cpp View File

@@ -0,0 +1,109 @@
/**
* \file src/opr/impl/dnn/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/dnn/correlation.h"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"

#include "../internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;

/* ==================== CorrelationForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationForward);
// Constructs the forward correlation operator on the graph that owns data1.
// Both inputs must share one dtype of the FLOAT category; output(0) is given
// that same dtype.
CorrelationForward::CorrelationForward(VarNode* data1, VarNode* data2,
const Param& param,
const OperatorNodeConfig& config)
: Super{data1->owner_graph(), config, "correlation", {data1, data2}} {
init_megdnn_opr(*this, param);
// Reject mismatched or non-floating-point inputs.
mgb_assert(data1->dtype() == data2->dtype());
mgb_assert(data1->dtype().category() == DTypeCategory::FLOAT);

add_input({data1, data2});
// The result uses the dtype of the inputs.
output(0)->dtype(data1->dtype());
}

// Inserts a CorrelationForward operator for the two symbolic inputs and
// returns its single output.
SymbolVar CorrelationForward::make(SymbolVar data1, SymbolVar data2,
                                   const Param& param,
                                   const OperatorNodeConfig& config) {
    // Unwrap the symbolic inputs into the raw nodes the constructor expects.
    VarNode* const first = data1.node();
    VarNode* const second = data2.node();
    return data1.insert_single_output_opr<CorrelationForward>(first, second,
                                                              param, config);
}

#if MGB_ENABLE_GRAD
// Gradient of the forward operator: input 0 is handled by
// CorrelationBackwardData1 and input 1 by CorrelationBackwardData2. Both
// receive the output gradient, the two forward inputs, and the forward
// operator's param and config.
MGB_IMPL_OPR_GRAD(CorrelationForward) {
    if (wrt_idx == 0) {
        // wrt the first input
        return CorrelationBackwardData1::make(out_grad[0], opr.input(0),
                                              opr.input(1), opr.param(),
                                              opr.config())
                .node();
    }

    // wrt the second input; no other index is valid
    mgb_assert(wrt_idx == 1);
    return CorrelationBackwardData2::make(out_grad[0], opr.input(0),
                                          opr.input(1), opr.param(),
                                          opr.config())
            .node();
}
#endif

/* ==================== CorrelationBackwardData1 ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData1);
MEGDNN_OPR_INIT3(CorrelationBackwardData1, "correlation_backward_data1", 1,
                 true);

// Runs the megdnn kernel on the device tensors of the three inputs
// (diff, data1, data2), writing into output(0) and using output(1) as the
// workspace.
void CorrelationBackwardData1::scn_do_execute() {
    const auto diff = input(0)->dev_tensor().as_megdnn();
    const auto data1 = input(1)->dev_tensor().as_megdnn();
    const auto data2 = input(2)->dev_tensor().as_megdnn();
    const auto grad1 = output(0)->dev_tensor().as_megdnn();
    const auto workspace = intl::get_megdnn_workspace_from_var(output(1));

    megdnn_opr()->exec(diff, data1, data2, grad1, workspace);
}

// Asks the megdnn operator how much workspace it needs for the given input
// and output shapes.
size_t CorrelationBackwardData1::get_workspace_size_bytes(
        const TensorShapeArray& inp_shapes,
        const TensorShapeArray& out_shapes) const {
    // Combine each shape with the dtype and format of its variable.
    const TensorLayout diff_layout{inp_shapes[0], input(0)->dtype(),
                                   input(0)->format()};
    const TensorLayout data1_layout{inp_shapes[1], input(1)->dtype(),
                                    input(1)->format()};
    const TensorLayout data2_layout{inp_shapes[2], input(2)->dtype(),
                                    input(2)->format()};
    const TensorLayout grad1_layout{out_shapes[0], output(0)->dtype(),
                                    output(0)->format()};

    return megdnn_opr()->get_workspace_in_bytes(diff_layout, data1_layout,
                                                data2_layout, grad1_layout);
}

/* ==================== CorrelationBackwardData2 ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData2);
MEGDNN_OPR_INIT3(CorrelationBackwardData2, "correlation_backward_data2", 1,
                 true);

// Runs the megdnn kernel on the device tensors of the three inputs
// (diff, data1, data2), writing into output(0) and using output(1) as the
// workspace.
void CorrelationBackwardData2::scn_do_execute() {
    const auto diff = input(0)->dev_tensor().as_megdnn();
    const auto data1 = input(1)->dev_tensor().as_megdnn();
    const auto data2 = input(2)->dev_tensor().as_megdnn();
    const auto grad2 = output(0)->dev_tensor().as_megdnn();
    const auto workspace = intl::get_megdnn_workspace_from_var(output(1));

    megdnn_opr()->exec(diff, data1, data2, grad2, workspace);
}

// Asks the megdnn operator how much workspace it needs for the given input
// and output shapes.
size_t CorrelationBackwardData2::get_workspace_size_bytes(
        const TensorShapeArray& inp_shapes,
        const TensorShapeArray& out_shapes) const {
    // Combine each shape with the dtype and format of its variable.
    const TensorLayout diff_layout{inp_shapes[0], input(0)->dtype(),
                                   input(0)->format()};
    const TensorLayout data1_layout{inp_shapes[1], input(1)->dtype(),
                                    input(1)->format()};
    const TensorLayout data2_layout{inp_shapes[2], input(2)->dtype(),
                                    input(2)->format()};
    const TensorLayout grad2_layout{out_shapes[0], output(0)->dtype(),
                                    output(0)->format()};

    return megdnn_opr()->get_workspace_in_bytes(diff_layout, data1_layout,
                                                data2_layout, grad2_layout);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 5
- 0
src/opr/impl/dnn/dnn.sereg.h View File

@@ -11,6 +11,7 @@

#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
@@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0);
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0);

MGB_SEREG_OPR(CorrelationForward, 2);
MGB_SEREG_OPR(CorrelationBackwardData1, 3);
MGB_SEREG_OPR(CorrelationBackwardData2, 3);

MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);



+ 69
- 0
src/opr/include/megbrain/opr/dnn/correlation.h View File

@@ -0,0 +1,69 @@
/**
* \file src/opr/include/megbrain/opr/dnn/correlation.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"

namespace mgb {
namespace opr {

/*!
 * \brief graph operator for the forward correlation of two inputs
 *
 * Wraps megdnn::CorrelationForward through intl::MegDNNOprWrapperFwd.
 */
MGB_DEFINE_OPR_CLASS(CorrelationForward,
intl::MegDNNOprWrapperFwd<megdnn::CorrelationForward>) // {
public:
//! construct from the two input vars, the operator param and the node config
CorrelationForward(VarNode* data1, VarNode* data2, const Param& param,
const OperatorNodeConfig& config);

//! create the operator from symbolic inputs; param and config are
//! default-constructed when omitted
static SymbolVar make(SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});
};
//! shorter name for the forward operator
using Correlation = CorrelationForward;

/*!
 * \brief backward operator wrapping megdnn::CorrelationBackwardData1
 *
 * Takes a diff var together with the two forward inputs data1 and data2.
 */
MGB_DEFINE_OPR_CLASS(
CorrelationBackwardData1, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData1>) // {
public:
//! construct from diff, data1 and data2 plus the operator param and config
CorrelationBackwardData1(VarNode* diff, VarNode* data1, VarNode* data2,
const Param& param, const OperatorNodeConfig& config);

//! create the operator from symbolic inputs; param and config are
//! default-constructed when omitted
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});

private:
//! execution hook overridden by this operator
void scn_do_execute() override;
//! workspace size in bytes for the given input and output shapes
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
};

/*!
 * \brief backward operator wrapping megdnn::CorrelationBackwardData2
 *
 * Takes a diff var together with the two forward inputs data1 and data2.
 */
MGB_DEFINE_OPR_CLASS(
CorrelationBackwardData2, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData2>) // {
public:
//! construct from diff, data1 and data2 plus the operator param and config
CorrelationBackwardData2(VarNode* diff, VarNode* data1, VarNode* data2,
const Param& param, const OperatorNodeConfig& config);

//! create the operator from symbolic inputs; param and config are
//! default-constructed when omitted
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2,
const Param& param = {},
const OperatorNodeConfig& config = {});

private:
//! execution hook overridden by this operator
void scn_do_execute() override;
//! workspace size in bytes for the given input and output shapes
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
};

} // namespace opr
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 134
- 0
src/opr/test/dnn/correlation.cpp View File

@@ -0,0 +1,134 @@
/**
* \file src/opr/test/dnn/correlation.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"

#include "megdnn/oprs.h"

#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>

using namespace mgb;

namespace {
using Param = opr::CorrelationForward::Param;

void run_forward(bool is_multiply) {
RNGxorshf rng{next_rand_seed()};
using Checker = AutoOprChecker<2, 1>;

Param param;
param.format = Param::Format::NCHW;
param.is_multiply = is_multiply;
param.kernel_size = 3;
param.max_displacement = 2;
param.pad_size = 1;
param.stride1 = 2;
param.stride2 = 2;

auto make_graph =
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto o0 = opr::CorrelationForward::make(inputs[0], inputs[1], param);
return {o0};
};

auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr = megdnn_naive_handle()
->create_operator<megdnn::CorrelationForward>();
opr->param() = param;
auto inp_shape = inp[0]->shape();
auto num = inp_shape[0];
auto height = inp_shape[2];
auto width = inp_shape[3];

uint32_t pad_size = param.pad_size;
uint32_t kernel_size = param.kernel_size;
uint32_t stride1 = param.stride1;
uint32_t stride2 = param.stride2;
uint32_t max_displacement = param.max_displacement;

int paddedbottomheight = height + 2 * pad_size;
int paddedbottomwidth = width + 2 * pad_size;
uint32_t kernel_radius = (kernel_size - 1) / 2;
uint32_t border_size = max_displacement + kernel_radius;
uint32_t top_width =
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) /
static_cast<float>(stride1));
uint32_t top_height =
ceil(static_cast<float>(paddedbottomheight - border_size * 2) /
static_cast<float>(stride1));
uint32_t neighborhood_grid_radius = max_displacement / stride2;
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1;
uint32_t top_channels =
neighborhood_grid_width * neighborhood_grid_width;
megdnn::TensorShape target_shape{num, top_channels, top_height,
top_width};

dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize(target_shape);
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(),
{});
};

auto rand_real = [&](float lo, float hi) {
std::uniform_real_distribution<float> dist(lo, hi);
return dist(rng);
};
auto gen_inp1 = [&](HostTensorND &inp) {
auto ptr = inp.ptr<float>();
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) {
ptr[i] = rand_real(0.06f, 0.1f);
};
};

auto gen_inp2 = [&](HostTensorND &inp) {
auto ptr = inp.ptr<float>();
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) {
ptr[i] = rand_real(0.01f, 0.04f);
};
};

Checker::RunOptions option;
option.numdiff_eps = 1e-3;
option.numdiff_max_err = 1e-2;
Checker checker{make_graph, fwd};

checker.set_input_generator(0, gen_inp1);
checker.set_input_generator(1, gen_inp2);

checker.run({TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 10, 10}}, option)
.run({TensorShape{1, 3, 50, 50}, TensorShape{1, 3, 50, 50}}, option)
.run({TensorShape{1, 1, 100, 100}, TensorShape{1, 1, 100, 100}},
option);
}

// Runs the forward check with is_multiply set to true.
TEST(TestOprDNN, CorrelationForwardMultiply) {
// TODO: fix me, add correct backward of cpu
// NOTE(review): presumably restricted to GPU until the CPU backward is
// correct -- confirm before lifting the requirement
REQUIRE_GPU(1);
run_forward(true);
}

// Runs the forward check with is_multiply set to false.
TEST(TestOprDNN, CorrelationForwardSubstract) {
// TODO: fix me, add correct backward of cpu
// NOTE(review): presumably restricted to GPU until the CPU backward is
// correct -- confirm before lifting the requirement
REQUIRE_GPU(1);
run_forward(false);
}

} // anonymous namespace

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 0
src/serialization/impl/schema.fbs View File

@@ -106,6 +106,7 @@ union OperatorParam {
param.DctChannelSelect = 72,
param.FakeQuant = 73,
param.TQT = 74,
param.Correlation = 75,
}

table Operator {


Loading…
Cancel
Save