@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/cuda/roi_align/roi_align.cu | |||||
* \file dnn/src/cuda/correlation/correlation_cuda.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
@@ -28,9 +28,9 @@ | |||||
#include "src/naive/convolution/opr_impl.h" | #include "src/naive/convolution/opr_impl.h" | ||||
#include "src/naive/convolution3d/opr_impl.h" | #include "src/naive/convolution3d/opr_impl.h" | ||||
#include "src/naive/convpooling/opr_impl.h" | #include "src/naive/convpooling/opr_impl.h" | ||||
#include "src/naive/correlation/opr_impl.h" | |||||
#include "src/naive/cumsum/opr_impl.h" | #include "src/naive/cumsum/opr_impl.h" | ||||
#include "src/naive/cvt_color/opr_impl.h" | #include "src/naive/cvt_color/opr_impl.h" | ||||
#include "src/naive/correlation/opr_impl.h" | |||||
#include "src/naive/dct/opr_impl.h" | #include "src/naive/dct/opr_impl.h" | ||||
#include "src/naive/deformable_conv/opr_impl.h" | #include "src/naive/deformable_conv/opr_impl.h" | ||||
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | #include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | ||||
@@ -38,6 +38,7 @@ | |||||
#include "src/naive/elemwise/opr_impl.h" | #include "src/naive/elemwise/opr_impl.h" | ||||
#include "src/naive/elemwise_multi_type/opr_impl.h" | #include "src/naive/elemwise_multi_type/opr_impl.h" | ||||
#include "src/naive/eye/opr_impl.h" | #include "src/naive/eye/opr_impl.h" | ||||
#include "src/naive/fake_quant/opr_impl.h" | |||||
#include "src/naive/flip/opr_impl.h" | #include "src/naive/flip/opr_impl.h" | ||||
#include "src/naive/gaussian_blur/opr_impl.h" | #include "src/naive/gaussian_blur/opr_impl.h" | ||||
#include "src/naive/group_local/opr_impl.h" | #include "src/naive/group_local/opr_impl.h" | ||||
@@ -75,13 +76,11 @@ | |||||
#include "src/naive/tensor_remap/opr_impl.h" | #include "src/naive/tensor_remap/opr_impl.h" | ||||
#include "src/naive/tile/opr_impl.h" | #include "src/naive/tile/opr_impl.h" | ||||
#include "src/naive/topk/opr_impl.h" | #include "src/naive/topk/opr_impl.h" | ||||
#include "src/naive/tqt/opr_impl.h" | |||||
#include "src/naive/transpose/opr_impl.h" | #include "src/naive/transpose/opr_impl.h" | ||||
#include "src/naive/type_cvt/opr_impl.h" | #include "src/naive/type_cvt/opr_impl.h" | ||||
#include "src/naive/warp_affine/opr_impl.h" | #include "src/naive/warp_affine/opr_impl.h" | ||||
#include "src/naive/warp_perspective/opr_impl.h" | #include "src/naive/warp_perspective/opr_impl.h" | ||||
#include "src/naive/remap/opr_impl.h" | |||||
#include "src/naive/fake_quant/opr_impl.h" | |||||
#include "src/naive/tqt/opr_impl.h" | |||||
static size_t g_image2d_pitch_alignment = 1; | static size_t g_image2d_pitch_alignment = 1; | ||||
@@ -45,19 +45,6 @@ inline static std::vector<TestArg> get_args() { | |||||
TensorShape{batch_size, channel, height, width}, | TensorShape{batch_size, channel, height, width}, | ||||
TensorShape{batch_size, channel, height, width}); | TensorShape{batch_size, channel, height, width}); | ||||
// cur_param.is_multiply = false; | |||||
// cur_param.kernel_size = 1; | |||||
// cur_param.max_displacement = 2; | |||||
// cur_param.pad_size = 1; | |||||
// cur_param.stride1 = 1; | |||||
// cur_param.stride2 = 1; | |||||
// cur_param.format = | |||||
// megdnn::param::Correlation::Format::NCHW; | |||||
// args.emplace_back( | |||||
// cur_param, | |||||
// TensorShape{batch_size, channel, height, width}, | |||||
// TensorShape{batch_size, channel, height, width}); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -106,6 +106,43 @@ def roi_pooling( | |||||
return result | return result | ||||
def correlation(
    data1: Tensor,
    data2: Tensor,
    kernel_size: int = 1,
    max_displacement: int = 1,
    stride1: int = 1,
    stride2: int = 1,
    pad_size: int = 0,
    is_multiply: bool = True,
) -> Tensor:
    """Applies the correlation operator to a pair of inputs.

    The underlying ``Correlation`` op is always built with ``format="NCHW"``.

    :param data1: first input of the correlation; must be in NCHW layout.
    :param data2: second input of the correlation; must be in NCHW layout.
    :param kernel_size: non-negative int, default 1. Kernel size of the
        correlation; must be an odd number.
    :param max_displacement: non-negative int, default 1. Maximum
        displacement of the correlation.
    :param stride1: non-negative int, default 1. Quantizes ``data1`` globally.
    :param stride2: non-negative int, default 1. Quantizes ``data2`` within the
        neighborhood centered around ``data1``.
    :param pad_size: non-negative int, default 0. Padding used by the
        correlation.
    :param is_multiply: default True. Selects the operation type:
        multiplication when True, absolute difference when False.
    :return: the first output produced by the ``Correlation`` op.
    """
    op_kwargs = dict(
        format="NCHW",
        kernel_size=kernel_size,
        max_displacement=max_displacement,
        stride1=stride1,
        stride2=stride2,
        pad_size=pad_size,
        is_multiply=is_multiply,
    )
    corr_op = builtin.Correlation(**op_kwargs)
    outputs = apply(corr_op, data1, data2)
    # Only the first output is exposed; any further outputs are discarded.
    result, *_ = outputs
    return result
def roi_align( | def roi_align( | ||||
inp: Tensor, | inp: Tensor, | ||||
rois: Tensor, | rois: Tensor, | ||||
@@ -228,6 +228,106 @@ def test_roi_align(): | |||||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | ||||
def _gen_correlation(random=True, constant=1, image_shape=(2, 1, 160, 160)):
    """Builds a pair of input tensors for the correlation tests.

    :param random: when True both arrays are drawn from ``np.random.randn``;
        otherwise both are filled with ``constant``.
    :param constant: fill value used when ``random`` is False.
    :param image_shape: shape of both arrays; the random branch reads its
        first four entries.
    :return: tuple ``(tensor(feat1), tensor(feat2))``.
    """
    if not random:
        feat1 = np.ones(image_shape) * constant
        feat2 = np.ones(image_shape) * constant
        return tensor(feat1), tensor(feat2)

    dims = (image_shape[0], image_shape[1], image_shape[2], image_shape[3])
    # Draw the first array before the second so the RNG stream is consumed
    # in the same order on every call.
    feat1 = np.random.randn(*dims)
    feat2 = np.random.randn(*dims)
    return tensor(feat1), tensor(feat2)
def test_correlation():
    """Checks F.vision.correlation: gradient shape plus four constant-input cases."""
    # Settings shared by the 3x3 cases: kernel 3, no displacement, unit
    # strides, no padding.
    unit_3x3 = dict(
        kernel_size=3, max_displacement=0, stride1=1, stride2=1, pad_size=0
    )

    # test case 0: the gradient w.r.t. data1 has the same shape as data1
    data1, data2 = _gen_correlation()
    grad = Grad().wrt(data1, callback=_save_to(data1))
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=5,
        max_displacement=4,
        stride1=2,
        stride2=2,
        pad_size=2,
        is_multiply=True,
    )
    grad(out_feat, tensor(F.ones_like(out_feat)))
    assert make_shape_tuple(data1.grad.shape) == make_shape_tuple(data1.shape)

    # test case 1: from https://github.com/NVIDIA/flownet2-pytorch/issues/194
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(data1, data2, is_multiply=True, **unit_3x3)
    assert abs(out_feat.sum() - 1) < 1e-9

    # test case 2: subtraction of two identical images
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(data1, data2, is_multiply=False, **unit_3x3)
    assert out_feat.sum() < 1e-9

    # test case 3: subtraction of two identical images (same setup as case 2)
    data1, data2 = _gen_correlation(random=False, image_shape=(1, 1, 3, 3))
    out_feat = F.vision.correlation(data1, data2, is_multiply=False, **unit_3x3)
    assert out_feat.sum() < 1e-9

    # test case 4: constant images of 2.0 and 1.0 in subtraction mode
    data1, _ = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=2.0
    )
    _, data2 = _gen_correlation(
        random=False, image_shape=(1, 1, 220, 220), constant=1.0
    )
    out_feat = F.vision.correlation(
        data1,
        data2,
        kernel_size=3,
        max_displacement=2,
        stride1=1,
        stride2=2,
        pad_size=0,
        is_multiply=False,
    )
    assert abs(out_feat.mean() - 1) < 1e-9
def test_roi_pooling(): | def test_roi_pooling(): | ||||
inp_feat, rois = _gen_roi_inp() | inp_feat, rois = _gen_roi_inp() | ||||
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) | ||||
@@ -19,6 +19,7 @@ | |||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/dnn/roi_align.h" | #include "megbrain/opr/dnn/roi_align.h" | ||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/opr/dnn/roi_pooling.h" | #include "megbrain/opr/dnn/roi_pooling.h" | ||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
#include "megbrain/opr/blas.h" | #include "megbrain/opr/blas.h" | ||||
@@ -445,6 +446,21 @@ OP_TRAIT_REG(ROIAlign, ROIAlign) | |||||
.fallback(); | .fallback(); | ||||
}} // roi_align | }} // roi_align | ||||
namespace { namespace correlation { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
auto&& op = static_cast<const Correlation&>(def); | |||||
mgb_assert(inputs.size() == 2); | |||||
OperatorNodeConfig config{op.make_name()}; | |||||
return opr::Correlation::make( | |||||
inputs[0], inputs[1], op.param(), config); | |||||
} | |||||
OP_TRAIT_REG(Correlation, Correlation) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
}} // correlation | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
namespace { namespace nvof { | namespace { namespace nvof { | ||||
auto apply_on_var_node( | auto apply_on_var_node( | ||||
@@ -82,6 +82,7 @@ def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, Executio | |||||
def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; | ||||
def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; | ||||
def Correlation: MgbHashableOp<"Correlation", [CorrelationParam]>; | |||||
def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; | def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; | ||||
@@ -0,0 +1,109 @@ | |||||
/** | |||||
* \file src/opr/impl/dnn/correlation.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/graph/grad_impl.h" | |||||
#include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
/* ==================== CorrelationForward ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationForward); | |||||
CorrelationForward::CorrelationForward(VarNode* data1, VarNode* data2, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) | |||||
: Super{data1->owner_graph(), config, "correlation", {data1, data2}} { | |||||
init_megdnn_opr(*this, param); | |||||
mgb_assert(data1->dtype() == data2->dtype()); | |||||
mgb_assert(data1->dtype().category() == DTypeCategory::FLOAT); | |||||
add_input({data1, data2}); | |||||
output(0)->dtype(data1->dtype()); | |||||
} | |||||
SymbolVar CorrelationForward::make(SymbolVar data1, SymbolVar data2, | |||||
const Param& param, | |||||
const OperatorNodeConfig& config) { | |||||
return data1.insert_single_output_opr<CorrelationForward>( | |||||
data1.node(), data2.node(), param, config); | |||||
} | |||||
#if MGB_ENABLE_GRAD | |||||
MGB_IMPL_OPR_GRAD(CorrelationForward) { | |||||
if (wrt_idx == 0) { | |||||
// wrt src | |||||
SymbolVar grad = CorrelationBackwardData1::make( | |||||
out_grad[0], opr.input(0), opr.input(1), opr.param(), | |||||
opr.config()); | |||||
return grad.node(); | |||||
} else { | |||||
mgb_assert(wrt_idx == 1); | |||||
SymbolVar grad = CorrelationBackwardData2::make( | |||||
out_grad[0], opr.input(0), opr.input(1), opr.param(), | |||||
opr.config()); | |||||
return grad.node(); | |||||
} | |||||
} | |||||
#endif | |||||
/* ==================== CorrelationBackwardData1 ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData1); | |||||
MEGDNN_OPR_INIT3(CorrelationBackwardData1, "correlation_backward_data1", 1, | |||||
true); | |||||
void CorrelationBackwardData1::scn_do_execute() { | |||||
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
input(1)->dev_tensor().as_megdnn(), | |||||
input(2)->dev_tensor().as_megdnn(), | |||||
output(0)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output(1))); | |||||
} | |||||
size_t CorrelationBackwardData1::get_workspace_size_bytes( | |||||
const TensorShapeArray& inp_shapes, | |||||
const TensorShapeArray& out_shapes) const { | |||||
TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, | |||||
data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, | |||||
data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, | |||||
grad1{out_shapes[0], output(0)->dtype(), output(0)->format()}; | |||||
return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad1); | |||||
} | |||||
/* ==================== CorrelationBackwardData2 ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CorrelationBackwardData2); | |||||
MEGDNN_OPR_INIT3(CorrelationBackwardData2, "correlation_backward_data2", 1, | |||||
true); | |||||
void CorrelationBackwardData2::scn_do_execute() { | |||||
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), | |||||
input(1)->dev_tensor().as_megdnn(), | |||||
input(2)->dev_tensor().as_megdnn(), | |||||
output(0)->dev_tensor().as_megdnn(), | |||||
intl::get_megdnn_workspace_from_var(output(1))); | |||||
} | |||||
size_t CorrelationBackwardData2::get_workspace_size_bytes( | |||||
const TensorShapeArray& inp_shapes, | |||||
const TensorShapeArray& out_shapes) const { | |||||
TensorLayout diff{inp_shapes[0], input(0)->dtype(), input(0)->format()}, | |||||
data1{inp_shapes[1], input(1)->dtype(), input(1)->format()}, | |||||
data2{inp_shapes[2], input(2)->dtype(), input(2)->format()}, | |||||
grad2{out_shapes[0], output(0)->dtype(), output(0)->format()}; | |||||
return megdnn_opr()->get_workspace_in_bytes(diff, data1, data2, grad2); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,6 +11,7 @@ | |||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/opr/dnn/images2neibs.h" | #include "megbrain/opr/dnn/images2neibs.h" | ||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/dnn/adaptive_pooling.h" | #include "megbrain/opr/dnn/adaptive_pooling.h" | ||||
@@ -573,6 +574,10 @@ MGB_SEREG_OPR(DeformableConvForwardV1, 0); | |||||
MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); | MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0); | ||||
MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); | MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0); | ||||
MGB_SEREG_OPR(CorrelationForward, 2); | |||||
MGB_SEREG_OPR(CorrelationBackwardData1, 3); | |||||
MGB_SEREG_OPR(CorrelationBackwardData2, 3); | |||||
MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); | MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3); | ||||
MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5); | ||||
@@ -0,0 +1,69 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/dnn/correlation.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megdnn/oprs.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
MGB_DEFINE_OPR_CLASS(CorrelationForward, | |||||
intl::MegDNNOprWrapperFwd<megdnn::CorrelationForward>) // { | |||||
public: | |||||
CorrelationForward(VarNode* data1, VarNode* data2, const Param& param, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar data1, SymbolVar data2, | |||||
const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
}; | |||||
using Correlation = CorrelationForward; | |||||
MGB_DEFINE_OPR_CLASS( | |||||
CorrelationBackwardData1, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData1>) // { | |||||
public: | |||||
CorrelationBackwardData1(VarNode* diff, VarNode* data1, VarNode* data2, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, | |||||
const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void scn_do_execute() override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
}; | |||||
MGB_DEFINE_OPR_CLASS( | |||||
CorrelationBackwardData2, intl::MegDNNOprWrapperBwd<megdnn::CorrelationBackwardData2>) // { | |||||
public: | |||||
CorrelationBackwardData2(VarNode* diff, VarNode* data1, VarNode* data2, | |||||
const Param& param, const OperatorNodeConfig& config); | |||||
static SymbolVar make(SymbolVar diff, SymbolVar data1, SymbolVar data2, | |||||
const Param& param = {}, | |||||
const OperatorNodeConfig& config = {}); | |||||
private: | |||||
void scn_do_execute() override; | |||||
size_t get_workspace_size_bytes( | |||||
const TensorShapeArray& input_shapes, | |||||
const TensorShapeArray& output_shapes) const override; | |||||
}; | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,134 @@ | |||||
/** | |||||
* \file src/opr/test/dnn/correlation.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/dnn/correlation.h" | |||||
#include "megbrain/test/autocheck.h" | |||||
#include "megbrain/test/helper.h" | |||||
#include "megbrain/test/megdnn_helper.h" | |||||
#include "megdnn/oprs.h" | |||||
#include <cmath> | |||||
#include <iomanip> | |||||
#include <random> | |||||
#include <sstream> | |||||
using namespace mgb; | |||||
namespace { | |||||
using Param = opr::CorrelationForward::Param; | |||||
void run_forward(bool is_multiply) { | |||||
RNGxorshf rng{next_rand_seed()}; | |||||
using Checker = AutoOprChecker<2, 1>; | |||||
Param param; | |||||
param.format = Param::Format::NCHW; | |||||
param.is_multiply = is_multiply; | |||||
param.kernel_size = 3; | |||||
param.max_displacement = 2; | |||||
param.pad_size = 1; | |||||
param.stride1 = 2; | |||||
param.stride2 = 2; | |||||
auto make_graph = | |||||
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||||
auto o0 = opr::CorrelationForward::make(inputs[0], inputs[1], param); | |||||
return {o0}; | |||||
}; | |||||
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||||
auto opr = megdnn_naive_handle() | |||||
->create_operator<megdnn::CorrelationForward>(); | |||||
opr->param() = param; | |||||
auto inp_shape = inp[0]->shape(); | |||||
auto num = inp_shape[0]; | |||||
auto height = inp_shape[2]; | |||||
auto width = inp_shape[3]; | |||||
uint32_t pad_size = param.pad_size; | |||||
uint32_t kernel_size = param.kernel_size; | |||||
uint32_t stride1 = param.stride1; | |||||
uint32_t stride2 = param.stride2; | |||||
uint32_t max_displacement = param.max_displacement; | |||||
int paddedbottomheight = height + 2 * pad_size; | |||||
int paddedbottomwidth = width + 2 * pad_size; | |||||
uint32_t kernel_radius = (kernel_size - 1) / 2; | |||||
uint32_t border_size = max_displacement + kernel_radius; | |||||
uint32_t top_width = | |||||
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) / | |||||
static_cast<float>(stride1)); | |||||
uint32_t top_height = | |||||
ceil(static_cast<float>(paddedbottomheight - border_size * 2) / | |||||
static_cast<float>(stride1)); | |||||
uint32_t neighborhood_grid_radius = max_displacement / stride2; | |||||
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
uint32_t top_channels = | |||||
neighborhood_grid_width * neighborhood_grid_width; | |||||
megdnn::TensorShape target_shape{num, top_channels, top_height, | |||||
top_width}; | |||||
dest[0].dtype(dtype::Float32()) | |||||
.comp_node(inp[0]->comp_node()) | |||||
.resize(target_shape); | |||||
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), | |||||
{}); | |||||
}; | |||||
auto rand_real = [&](float lo, float hi) { | |||||
std::uniform_real_distribution<float> dist(lo, hi); | |||||
return dist(rng); | |||||
}; | |||||
auto gen_inp1 = [&](HostTensorND &inp) { | |||||
auto ptr = inp.ptr<float>(); | |||||
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { | |||||
ptr[i] = rand_real(0.06f, 0.1f); | |||||
}; | |||||
}; | |||||
auto gen_inp2 = [&](HostTensorND &inp) { | |||||
auto ptr = inp.ptr<float>(); | |||||
for (size_t i = 0; i < inp.shape().total_nr_elems(); ++i) { | |||||
ptr[i] = rand_real(0.01f, 0.04f); | |||||
}; | |||||
}; | |||||
Checker::RunOptions option; | |||||
option.numdiff_eps = 1e-3; | |||||
option.numdiff_max_err = 1e-2; | |||||
Checker checker{make_graph, fwd}; | |||||
checker.set_input_generator(0, gen_inp1); | |||||
checker.set_input_generator(1, gen_inp2); | |||||
checker.run({TensorShape{1, 1, 10, 10}, TensorShape{1, 1, 10, 10}}, option) | |||||
.run({TensorShape{1, 3, 50, 50}, TensorShape{1, 3, 50, 50}}, option) | |||||
.run({TensorShape{1, 1, 100, 100}, TensorShape{1, 1, 100, 100}}, | |||||
option); | |||||
} | |||||
TEST(TestOprDNN, CorrelationForwardMultiply) { | |||||
// TODO: fix me, add correct backward of cpu | |||||
REQUIRE_GPU(1); | |||||
run_forward(true); | |||||
} | |||||
TEST(TestOprDNN, CorrelationForwardSubstract) { | |||||
// TODO: fix me, add correct backward of cpu | |||||
REQUIRE_GPU(1); | |||||
run_forward(false); | |||||
} | |||||
} // anonymous namespace | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -106,6 +106,7 @@ union OperatorParam { | |||||
param.DctChannelSelect = 72, | param.DctChannelSelect = 72, | ||||
param.FakeQuant = 73, | param.FakeQuant = 73, | ||||
param.TQT = 74, | param.TQT = 74, | ||||
param.Correlation = 75, | |||||
} | } | ||||
table Operator { | table Operator { | ||||