@@ -192,6 +192,87 @@ class ReduceForward: public OperatorBase { | |||
}; | |||
using Reduce = ReduceForward; | |||
class CorrelationBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase); | |||
DEF_OPR_PARAM(Correlation); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, | |||
TensorLayout& dst); | |||
void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, | |||
const TensorLayout& dst); | |||
}; | |||
class CorrelationForward : public CorrelationBase { | |||
DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1); | |||
public: | |||
/** | |||
* \param[in] data1 (n, c, ih, iw) | |||
* \param[in] data2 (n, c, ih, iw) | |||
* \param[out] dst (n, q, oh, ow), q is the number of neighborhood | |||
* */ | |||
virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& data1, const TensorLayout& data2, | |||
TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& data1, const TensorLayout& data2, | |||
const TensorLayout& dst, size_t workspace_in_bytes); | |||
}; | |||
using Correlation = CorrelationForward; | |||
class CorrelationBackwardData1 : public CorrelationBase { | |||
DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1); | |||
public: | |||
/** | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[in] data1 the `data1' parameter in CorrelationForward::exec | |||
* \param[in] data2 the `data2' parameter in CorrelationForward::exec | |||
* \param[out] grad1 the backpropagated gradient wrt. data1 | |||
*/ | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, | |||
const TensorLayout& data2, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& grad1) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, | |||
const TensorLayout& grad1, size_t workspace_in_bytes); | |||
}; | |||
class CorrelationBackwardData2 : public CorrelationBase { | |||
DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1); | |||
public: | |||
/** | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[in] data1 the `data1' parameter in CorrelationForward::exec | |||
* \param[in] data2 the `data2' parameter in CorrelationForward::exec | |||
* \param[out] grad2 the backpropagated gradient wrt. data2 | |||
*/ | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, | |||
const TensorLayout& data2, TensorLayout& dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& grad2) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, | |||
const TensorLayout& grad2, size_t workspace_in_bytes); | |||
}; | |||
class CumsumForward: public OperatorBase { | |||
DEF_OPR_PARAM(Cumsum); | |||
DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1); | |||
@@ -1053,6 +1053,16 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
'sample_width', '2') | |||
) | |||
(pdef('Correlation'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields('uint32', 'kernel_size', '1'). | |||
add_fields('uint32', 'max_displacement', '1'). | |||
add_fields('uint32', 'stride1', '1'). | |||
add_fields('uint32', 'stride2', '1'). | |||
add_fields('uint32', 'pad_size', '0'). | |||
add_fields('bool', 'is_multiply', 'true') | |||
) | |||
(pdef('DeformablePSROIPooling'). | |||
add_fields('bool', 'no_trans', 'true'). | |||
add_fields('float32', 'spatial_scale', 1, | |||
@@ -0,0 +1,132 @@ | |||
/** | |||
* \file dnn/src/common/correlation.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
TensorLayout& dst) { | |||
megdnn_assert_contiguous(data1); | |||
megdnn_assert_contiguous(data2); | |||
megdnn_assert_contiguous(dst); | |||
auto errmsg = [&]() { | |||
return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + | |||
", " + megdnn_layout_msg(dst); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
using Format = CorrelationBase::Param::Format; | |||
megdnn_assert(param().format == Format::NCHW); | |||
auto data1_dtype = data1.dtype, data2_dtype = data2.dtype; | |||
megdnn_assert(data1_dtype == data2_dtype && | |||
data1_dtype.category() == DTypeCategory::FLOAT); | |||
megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str()); | |||
megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str()); | |||
uint32_t pad_size = param().pad_size; | |||
uint32_t kernel_size = param().kernel_size; | |||
uint32_t stride1 = param().stride1; | |||
uint32_t stride2 = param().stride2; | |||
uint32_t max_displacement = param().max_displacement; | |||
int paddedbottomheight = data1[2] + 2 * pad_size; | |||
int paddedbottomwidth = data1[3] + 2 * pad_size; | |||
uint32_t kernel_radius = (kernel_size - 1) / 2; | |||
uint32_t border_size = max_displacement + kernel_radius; | |||
uint32_t top_width = | |||
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t top_height = | |||
ceil(static_cast<float>(paddedbottomheight - border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t neighborhood_grid_radius = max_displacement / stride2; | |||
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width; | |||
megdnn_assert(top_width >= 1 && top_height >= 1); | |||
dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, | |||
data1.dtype}; | |||
} | |||
void CorrelationBase::check_layout_fwd(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& dst) { | |||
TensorLayout dst_expected; | |||
megdnn_assert_eq_dtype(data1, dst); | |||
megdnn_assert_eq_shape(data1, data2); | |||
deduce_layout_fwd(data1, data2, dst_expected); | |||
megdnn_assert_eq_shape(dst_expected, dst); | |||
} | |||
void CorrelationForward::deduce_layout(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
TensorLayout& dst) { | |||
deduce_layout_fwd(data1, data2, dst); | |||
} | |||
void CorrelationForward::check_exec(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& dst, | |||
size_t workspace_in_bytes) { | |||
check_layout_fwd(data1, data2, dst); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(data1, data2, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void CorrelationBackwardData1::check_exec(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& grad1, | |||
size_t workspace_in_bytes) { | |||
check_layout_fwd(grad1, data2, diff); | |||
megdnn_assert_eq_shape(data1, data2); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(diff, data1, data2, grad1); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void CorrelationBackwardData2::check_exec(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& grad2, | |||
size_t workspace_in_bytes) { | |||
check_layout_fwd(data1, grad2, diff); | |||
megdnn_assert_eq_shape(data1, data2); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(diff, data1, data2, grad2); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
TensorLayout& grad) { | |||
megdnn_assert_eq_shape(data1, data2); | |||
check_layout_fwd(data1, data2, diff); | |||
grad = data2; | |||
} | |||
void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff, | |||
const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
TensorLayout& grad) { | |||
megdnn_assert_eq_shape(data1, data2); | |||
check_layout_fwd(data1, data2, diff); | |||
grad = data1; | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -194,6 +194,9 @@ private: | |||
cb(LocalShareBackwardFilter) \ | |||
cb(ROIAlignForward) \ | |||
cb(ROIAlignBackward) \ | |||
cb(CorrelationForward) \ | |||
cb(CorrelationBackwardData1) \ | |||
cb(CorrelationBackwardData2) \ | |||
cb(BatchConvBiasForward) \ | |||
cb(Remap) \ | |||
cb(RemapBackwardData) \ | |||
@@ -54,6 +54,9 @@ DEF(BNForward, 8, true, true); | |||
DEF(BNBackward, 8, true, false); | |||
DEF(ROIPoolingForward, 4, true, false); | |||
DEF(ROIPoolingBackward, 5, true, false); | |||
DEF(CorrelationForward, 3, true, true); | |||
DEF(CorrelationBackwardData1, 4, true, true); | |||
DEF(CorrelationBackwardData2, 4, true, true); | |||
DEF(WarpPerspectiveForward, 3, true, false); | |||
DEF(WarpPerspectiveBackwardData, 3, true, false); | |||
DEF(WarpPerspectiveBackwardMat, 4, true, false); | |||
@@ -0,0 +1,371 @@ | |||
/** | |||
* \file dnn/src/cuda/roi_align/roi_align.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/correlation/correlation_cuda.cuh" | |||
#include <cfloat> | |||
#include "megdnn/dtype.h" | |||
#include "src/cuda/query_blocksize.cuh" | |||
#include "src/cuda/utils.cuh" | |||
#define ROUND_OFF 50000 | |||
using namespace megdnn; | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace correlation { | |||
#define CUDA_KERNEL_LOOP(vtid, vthreads) \ | |||
for (int vtid = blockIdx.x * blockDim.x + threadIdx.x; vtid < vthreads; \ | |||
vtid += blockDim.x * gridDim.x) | |||
template <typename T> | |||
__global__ void forward_kernel(const int nthreads, const T* data1, | |||
const T* data2, T* dst, const int bchannels, | |||
const int bheight, const int bwidth, | |||
const int tchannels, const int theight, | |||
const int twidth, const int kernel_size, | |||
const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, | |||
const bool is_multiply) { | |||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
int x = idx % twidth; | |||
int y = (idx / twidth) % theight; | |||
int c = (idx / twidth / theight) % tchannels; | |||
int n = idx / twidth / theight / tchannels; | |||
// get src center position in image1 | |||
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||
// get offset of center in image2 | |||
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||
stride2; | |||
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||
stride2; | |||
int x2 = x1 + s2o; | |||
int y2 = y1 + s2p; | |||
// compute kernel correlation | |||
T sum = T(0.f); | |||
for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||
for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||
int in_x1 = x1 + i; | |||
int in_y1 = y1 + j; | |||
int in_x2 = x2 + i; | |||
int in_y2 = y2 + j; | |||
for (int channel = 0; channel < bchannels; channel++) { | |||
T tmp1 = T(0.f); | |||
T tmp2 = T(0.f); | |||
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||
in_y1 < bheight) { | |||
int idx1 = | |||
((n * bchannels + channel) * bheight + in_y1) * | |||
bwidth + | |||
in_x1; | |||
tmp1 = data1[idx1]; | |||
} | |||
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||
in_y2 < bheight) { | |||
int idx2 = | |||
((n * bchannels + channel) * bheight + in_y2) * | |||
bwidth + | |||
in_x2; | |||
tmp2 = data2[idx2]; | |||
} | |||
if (is_multiply) { | |||
sum += tmp1 * tmp2; | |||
} else { | |||
sum += fabsf(tmp1 - tmp2); | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
dst[idx] = sum / sumelems; | |||
} | |||
} | |||
template <typename T> | |||
__global__ void backward_kernel_data1( | |||
const int nthreads, const T* diff, const T* data1, const T* data2, | |||
T* grad1, const int bchannels, const int bheight, const int bwidth, | |||
const int tchannels, const int theight, const int twidth, | |||
const int kernel_size, const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, const bool is_multiply) { | |||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
int x = idx % bwidth; | |||
int y = (idx / bwidth) % bheight; | |||
int c = (idx / bwidth / bheight) % bchannels; | |||
int n = idx / bwidth / bheight / bchannels; | |||
T tmp1 = data1[idx]; | |||
// Get X,Y ranges and clamp | |||
// round_off is a trick to enable integer division with ceil, even for | |||
// negative numbers We use a large offset, for the inner part not to | |||
// become negative. | |||
const int round_off = ROUND_OFF; | |||
const int round_off_s1 = stride1 * round_off; | |||
// we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) | |||
// for diff_x_min, diff_y_min, x,y at the position of right-down | |||
// ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 | |||
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||
round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||
round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
// floor (l - max_displacement + pad_size) / stride1 | |||
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||
round_off; | |||
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||
round_off; | |||
T sum = T(0.f); | |||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||
(ymin <= theight - 1)) { | |||
xmin = max(0, xmin); | |||
xmax = min(twidth - 1, xmax); | |||
ymin = max(0, ymin); | |||
ymax = min(theight - 1, ymax); | |||
for (int p = -neighborhood_grid_radius; | |||
p <= neighborhood_grid_radius; p++) { | |||
for (int o = -neighborhood_grid_radius; | |||
o <= neighborhood_grid_radius; o++) { | |||
// Get bottom1 data: | |||
int s2o = stride2 * o; | |||
int s2p = stride2 * p; | |||
int x2 = x + s2o, y2 = y + s2p; | |||
int idx2 = | |||
((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||
T tmp2 = T(0.f); | |||
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||
tmp2 = data2[idx2]; | |||
} | |||
int op = (p + neighborhood_grid_radius) * | |||
neighborhood_grid_width + | |||
(o + neighborhood_grid_radius); | |||
int diff_channels_offset = (n * tchannels + op); | |||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||
int idxtopdiff = | |||
(diff_channels_offset * theight + diff_y) * | |||
twidth + | |||
diff_x; | |||
if (is_multiply) { | |||
sum += diff[idxtopdiff] * tmp2; | |||
} else { | |||
T sign = (tmp1 >= tmp2) ? T(1.f) : T(-1.f); | |||
sum += diff[idxtopdiff] * sign; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
grad1[idx] = sum / sumelems; | |||
} | |||
} | |||
template <typename T> | |||
__global__ void backward_kernel_data2( | |||
const int nthreads, const T* diff, const T* data1, const T* data2, | |||
T* grad2, const int bchannels, const int bheight, const int bwidth, | |||
const int tchannels, const int theight, const int twidth, | |||
const int kernel_size, const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, const bool is_multiply) { | |||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
int x = idx % bwidth; | |||
int y = (idx / bwidth) % bheight; | |||
int c = (idx / bwidth / bheight) % bchannels; | |||
int n = idx / bwidth / bheight / bchannels; | |||
T tmp2 = data2[idx]; | |||
T sum = T(0.f); | |||
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||
p++) { | |||
for (int o = -neighborhood_grid_radius; | |||
o <= neighborhood_grid_radius; o++) { | |||
int s2o = o * stride2; | |||
int s2p = p * stride2; | |||
int x1 = x - s2o; | |||
int y1 = y - s2p; | |||
const int round_off = ROUND_OFF; | |||
const int round_off_s1 = stride1 * round_off; | |||
int xmin = (x1 + pad_size - 2 * kernel_radius - | |||
max_displacement + round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int ymin = (y1 + pad_size - 2 * kernel_radius - | |||
max_displacement + round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||
stride1 - | |||
round_off; | |||
int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||
stride1 - | |||
round_off; | |||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||
(ymin <= theight - 1)) { | |||
xmin = max(0, xmin); | |||
xmax = min(twidth - 1, xmax); | |||
ymin = max(0, ymin); | |||
ymax = min(theight - 1, ymax); | |||
int idx1 = | |||
((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||
T tmp1 = T(0.f); | |||
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||
tmp1 = data1[idx1]; | |||
} | |||
int op = (p + neighborhood_grid_radius) * | |||
neighborhood_grid_width + | |||
(o + neighborhood_grid_radius); | |||
int diff_channels_offset = (n * tchannels + op); | |||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||
int idxtopdiff = | |||
(diff_channels_offset * theight + diff_y) * | |||
twidth + | |||
diff_x; | |||
if (is_multiply) { | |||
sum += diff[idxtopdiff] * tmp1; | |||
} else { | |||
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||
sum += diff[idxtopdiff] * sign; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
grad2[idx] = sum / sumelems; | |||
} | |||
} | |||
template <typename T> | |||
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||
const int bchannels, const int bheight, const int bwidth, | |||
const int tchannels, const int theight, const int twidth, | |||
const int kernel_size, const int max_displacement, | |||
const int stride1, const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream) { | |||
int threads_block = query_blocksize_for_kernel(forward_kernel<T>); | |||
forward_kernel<T> | |||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||
nthreads, data1, data2, dst, bchannels, bheight, bwidth, | |||
tchannels, theight, twidth, kernel_size, max_displacement, | |||
stride1, stride2, pad_size, is_multiply); | |||
after_kernel_launch(); | |||
} | |||
template <typename T> | |||
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||
const T* data2, T* grad1, const int bchannels, | |||
const int bheight, const int bwidth, | |||
const int tchannels, const int theight, | |||
const int twidth, const int kernel_size, | |||
const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream) { | |||
int threads_block = query_blocksize_for_kernel(backward_kernel_data1<T>); | |||
backward_kernel_data1<T> | |||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||
nthreads, diff, data1, data2, grad1, bchannels, bheight, | |||
bwidth, tchannels, theight, twidth, kernel_size, | |||
max_displacement, stride1, stride2, pad_size, is_multiply); | |||
after_kernel_launch(); | |||
} | |||
template <typename T> | |||
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||
const T* data2, T* grad2, const int bchannels, | |||
const int bheight, const int bwidth, | |||
const int tchannels, const int theight, | |||
const int twidth, const int kernel_size, | |||
const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream) { | |||
int threads_block = query_blocksize_for_kernel(backward_kernel_data2<T>); | |||
backward_kernel_data2<T> | |||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||
nthreads, diff, data1, data2, grad2, bchannels, bheight, | |||
bwidth, tchannels, theight, twidth, kernel_size, | |||
max_displacement, stride1, stride2, pad_size, is_multiply); | |||
after_kernel_launch(); | |||
} | |||
#define INST(T) \ | |||
template void forward_proxy<T>( \ | |||
const int, const T*, const T*, T* dst, const int, const int, \ | |||
const int, const int, const int, const int, const int, const int, \ | |||
const int, const int, const int, const bool, cudaStream_t); \ | |||
template void backward_proxy_data1<T>( \ | |||
const int, const T*, const T*, const T*, T*, const int, const int, \ | |||
const int, const int, const int, const int, const int, const int, \ | |||
const int, const int, const int, const bool, cudaStream_t); \ | |||
template void backward_proxy_data2<T>( \ | |||
const int, const T*, const T*, const T*, T*, const int, const int, \ | |||
const int, const int, const int, const int, const int, const int, \ | |||
const int, const int, const int, const bool, cudaStream_t); | |||
INST(dt_float32) | |||
INST(dt_float16) | |||
INST(dt_bfloat16) | |||
#undef INST | |||
} // namespace roi_align | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,51 @@ | |||
/** | |||
* \file dnn/src/cuda/correlation/correlation.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <cuda_runtime_api.h> | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace correlation { | |||
template <typename T> | |||
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||
const int bchannels, const int bheight, const int bwidth, | |||
const int tchannels, const int theight, const int twidth, | |||
const int kernel_size, const int max_displacement, | |||
const int stride1, const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream); | |||
template <typename T> | |||
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||
const T* data2, T* grad1, const int bchannels, | |||
const int bheight, const int bwidth, | |||
const int tchannels, const int theight, | |||
const int twidth, const int kernel_size, | |||
const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream); | |||
template <typename T> | |||
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||
const T* data2, T* grad2, const int bchannels, | |||
const int bheight, const int bwidth, | |||
const int tchannels, const int theight, | |||
const int twidth, const int kernel_size, | |||
const int max_displacement, const int stride1, | |||
const int stride2, const int pad_size, | |||
const bool is_multiply, cudaStream_t stream); | |||
} // namespace correlation | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,129 @@ | |||
/** | |||
* \file dnn/src/naive/correlation/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/correlation/opr_impl.h" | |||
#include "src/cuda/correlation/correlation_cuda.cuh" | |||
#include "src/cuda/utils.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||
auto p = param(); | |||
auto stream = cuda_stream(handle()); | |||
int nthreads = dst.layout.total_nr_elems(); | |||
int stride1 = p.stride1; | |||
int stride2 = p.stride2; | |||
int kernel_size = p.kernel_size; | |||
int max_displacement = p.max_displacement; | |||
int pad_size = p.pad_size; | |||
bool is_multiply = p.is_multiply; | |||
int tchannels = dst.layout[1]; | |||
int theight = dst.layout[2], twidth = dst.layout[3]; | |||
int bchannels = data1.layout[1]; | |||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||
using namespace ::megdnn::cuda::correlation; | |||
#define cb(DType) \ | |||
if (data1.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
forward_proxy<T>(nthreads, data1.ptr<T>(), data2.ptr<T>(), \ | |||
dst.ptr<T>(), bchannels, bheight, bwidth, tchannels, \ | |||
theight, twidth, kernel_size, max_displacement, \ | |||
stride1, stride2, pad_size, is_multiply, stream); \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
} | |||
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad1, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||
workspace.size); | |||
auto stream = cuda_stream(handle()); | |||
int nthreads = grad1.layout.total_nr_elems(); | |||
int stride1 = param().stride1; | |||
int stride2 = param().stride2; | |||
int kernel_size = param().kernel_size; | |||
int max_displacement = param().max_displacement; | |||
int pad_size = param().pad_size; | |||
bool is_multiply = param().is_multiply; | |||
int tchannels = diff.layout[1]; | |||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||
int bchannels = data1.layout[1]; | |||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||
using namespace ::megdnn::cuda::correlation; | |||
#define cb(DType) \ | |||
if (diff.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
backward_proxy_data1<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||
data2.ptr<T>(), grad1.ptr<T>(), bchannels, \ | |||
bheight, bwidth, tchannels, theight, twidth, \ | |||
kernel_size, max_displacement, stride1, \ | |||
stride2, pad_size, is_multiply, stream); \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
} | |||
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad2, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||
workspace.size); | |||
auto p = param(); | |||
auto stream = cuda_stream(handle()); | |||
int nthreads = grad2.layout.total_nr_elems(); | |||
int stride1 = p.stride1; | |||
int stride2 = p.stride2; | |||
int kernel_size = p.kernel_size; | |||
int max_displacement = p.max_displacement; | |||
int pad_size = p.pad_size; | |||
bool is_multiply = p.is_multiply; | |||
int tchannels = diff.layout[1]; | |||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||
int bchannels = data1.layout[1]; | |||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||
using namespace ::megdnn::cuda::correlation; | |||
#define cb(DType) \ | |||
if (diff.layout.dtype == DType()) { \ | |||
using T = typename DTypeTrait<DType>::ctype; \ | |||
backward_proxy_data2<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||
data2.ptr<T>(), grad2.ptr<T>(), bchannels, \ | |||
bheight, bwidth, tchannels, theight, twidth, \ | |||
kernel_size, max_displacement, stride1, \ | |||
stride2, pad_size, is_multiply, stream); \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
} | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,61 @@ | |||
/** | |||
* \file dnn/src/naive/correlation/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
#include "src/cuda/cudnn_wrapper.h" | |||
namespace megdnn { | |||
namespace cuda { | |||
class CorrelationForwardImpl final : public CorrelationForward { | |||
public: | |||
using CorrelationForward::CorrelationForward; | |||
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout& data1, | |||
const TensorLayout& data2, | |||
const TensorLayout& dst) override { | |||
return 0; | |||
} | |||
}; | |||
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||
public: | |||
using CorrelationBackwardData1::CorrelationBackwardData1; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||
public: | |||
using CorrelationBackwardData2::CorrelationBackwardData2; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -24,6 +24,7 @@ | |||
#include "src/cuda/convolution/opr_impl.h" | |||
#include "src/cuda/convolution3d/opr_impl.h" | |||
#include "src/cuda/convpooling/opr_impl.h" | |||
#include "src/cuda/correlation/opr_impl.h" | |||
#include "src/cuda/cumsum/opr_impl.h" | |||
#include "src/cuda/cvt_color/opr_impl.h" | |||
#include "src/cuda/dct/opr_impl.h" | |||
@@ -0,0 +1,384 @@ | |||
/** | |||
* \file dnn/src/naive/correlation/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/naive/correlation/opr_impl.h" | |||
#include <algorithm> | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
#define ROUND_OFF 50000 | |||
using namespace megdnn; | |||
using namespace naive; | |||
using namespace std; | |||
namespace { | |||
using Param = megdnn::Correlation::Param; | |||
template <typename T> | |||
void forward(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, const Param& param) { | |||
// data1 treat as no-padding tensor | |||
int total_nr_elems = dst.layout.total_nr_elems(); | |||
int stride1 = param.stride1, stride2 = param.stride2; | |||
int kernel_size = param.kernel_size; | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int max_displacement = param.max_displacement; | |||
int pad_size = param.pad_size; | |||
int tchannels = dst.layout[1]; | |||
int theight = dst.layout[2], twidth = dst.layout[3]; | |||
int bchannels = data1.layout[1]; | |||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||
int x = idx % twidth; | |||
int y = (idx / twidth) % theight; | |||
int c = (idx / twidth / theight) % tchannels; | |||
int n = idx / twidth / theight / tchannels; | |||
// get src center position in image1 | |||
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||
// get offset of center in image2 | |||
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||
stride2; | |||
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||
stride2; | |||
int x2 = x1 + s2o; | |||
int y2 = y1 + s2p; | |||
// compute kernel correlation | |||
float sum = 0.; | |||
for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||
for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||
int in_x1 = x1 + i; | |||
int in_y1 = y1 + j; | |||
int in_x2 = x2 + i; | |||
int in_y2 = y2 + j; | |||
for (int channel = 0; channel < bchannels; channel++) { | |||
float tmp1 = 0.; | |||
float tmp2 = 0.; | |||
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||
in_y1 < bheight) { | |||
int idx1 = | |||
((n * bchannels + channel) * bheight + in_y1) * | |||
bwidth + | |||
in_x1; | |||
tmp1 = data1.ptr<T>()[idx1]; | |||
} | |||
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||
in_y2 < bheight) { | |||
int idx2 = | |||
((n * bchannels + channel) * bheight + in_y2) * | |||
bwidth + | |||
in_x2; | |||
tmp2 = data2.ptr<T>()[idx2]; | |||
} | |||
if (param.is_multiply) { | |||
sum += tmp1 * tmp2; | |||
} else { | |||
sum += fabsf(tmp1 - tmp2); | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
dst.ptr<T>()[idx] = sum / sumelems; | |||
} | |||
} | |||
template <typename T> | |||
void backward_data1(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||
const Param& param) { | |||
// data1 treat as no-padding tensor | |||
// int total_nr_elems = diff.layout.total_nr_elems(); | |||
int total_nr_elems = grad1.layout.total_nr_elems(); | |||
int stride1 = param.stride1, stride2 = param.stride2; | |||
int kernel_size = param.kernel_size; | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int max_displacement = param.max_displacement; | |||
int pad_size = param.pad_size; | |||
int tchannels = diff.layout[1]; | |||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||
int bchannels = grad1.layout[1]; | |||
int bheight = grad1.layout[2], bwidth = grad1.layout[3]; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||
// idx for grad1 | |||
int x = idx % bwidth; | |||
int y = (idx / bwidth) % bheight; | |||
int c = (idx / bwidth / bheight) % bchannels; | |||
int n = idx / bwidth / bheight / bchannels; | |||
float tmp1 = data1.ptr<T>()[idx]; | |||
// Get X,Y ranges and clamp | |||
// round_off is a trick to enable integer division with ceil, even for | |||
// negative numbers We use a large offset, for the inner part not to | |||
// become negative. | |||
const int round_off = ROUND_OFF; | |||
const int round_off_s1 = stride1 * round_off; | |||
// we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) | |||
// for diff_x_min, diff_y_min, x,y at the position of right-down | |||
// ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 | |||
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||
round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||
round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
// floor (l - max_displacement + pad_size) / stride1 | |||
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||
round_off; | |||
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||
round_off; | |||
float sum = 0.; | |||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||
(ymin <= theight - 1)) { | |||
xmin = max(0, xmin); | |||
xmax = min(twidth - 1, xmax); | |||
ymin = max(0, ymin); | |||
ymax = min(theight - 1, ymax); | |||
for (int p = -neighborhood_grid_radius; | |||
p <= neighborhood_grid_radius; p++) { | |||
for (int o = -neighborhood_grid_radius; | |||
o <= neighborhood_grid_radius; o++) { | |||
// Get bottom1 data: | |||
int s2o = stride2 * o; | |||
int s2p = stride2 * p; | |||
int x2 = x + s2p, y2 = y + s2o; | |||
int idx2 = | |||
((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||
float tmp2 = 0.; | |||
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||
tmp2 = data2.ptr<T>()[idx2]; | |||
} | |||
int op = (p + neighborhood_grid_radius) * | |||
neighborhood_grid_width + | |||
(o + neighborhood_grid_radius); | |||
int diff_channels_offset = (n * tchannels + op); | |||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||
int idxtopdiff = | |||
(diff_channels_offset * theight + diff_y) * | |||
twidth + | |||
diff_x; | |||
if (param.is_multiply) { | |||
sum += diff.ptr<T>()[idxtopdiff] * tmp2; | |||
} else { | |||
T sign = (tmp1 > tmp2) ? T(1.) : T(-1.); | |||
sum += diff.ptr<T>()[idxtopdiff] * sign; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
grad1.ptr<T>()[idx] = sum / sumelems; | |||
} | |||
} | |||
template <typename T> | |||
void backward_data2(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||
const Param& param) { | |||
// data1 treat as no-padding tensor | |||
int total_nr_elems = grad2.layout.total_nr_elems(); | |||
int stride1 = param.stride1, stride2 = param.stride2; | |||
int kernel_size = param.kernel_size; | |||
int kernel_radius = (kernel_size - 1) / 2; | |||
int max_displacement = param.max_displacement; | |||
int pad_size = param.pad_size; | |||
int tchannels = diff.layout[1]; | |||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||
int bchannels = grad2.layout[1]; | |||
int bheight = grad2.layout[2], bwidth = grad2.layout[3]; | |||
int neighborhood_grid_radius = max_displacement / stride2; | |||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||
int x = idx % bwidth; | |||
int y = (idx / bwidth) % bheight; | |||
int c = (idx / bwidth / bheight) % bchannels; | |||
int n = idx / bwidth / bheight / bchannels; | |||
T tmp2 = data2.ptr<T>()[idx]; | |||
T sum = T(0.f); | |||
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||
p++) { | |||
for (int o = -neighborhood_grid_radius; | |||
o <= neighborhood_grid_radius; o++) { | |||
int s2o = o * stride2; | |||
int s2p = p * stride2; | |||
int x1 = x - s2o; | |||
int y1 = y - s2p; | |||
const int round_off = ROUND_OFF; | |||
const int round_off_s1 = stride1 * round_off; | |||
int xmin = (x1 + pad_size - 2 * kernel_radius - | |||
max_displacement + round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int ymin = (y1 + pad_size - 2 * kernel_radius - | |||
max_displacement + round_off_s1 - 1) / | |||
stride1 + | |||
1 - round_off; | |||
int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||
stride1 - | |||
round_off; | |||
int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||
stride1 - | |||
round_off; | |||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||
(ymin <= theight - 1)) { | |||
xmin = max(0, xmin); | |||
xmax = min(twidth - 1, xmax); | |||
ymin = max(0, ymin); | |||
ymax = min(theight - 1, ymax); | |||
int idx1 = | |||
((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||
T tmp1 = T(0.f); | |||
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||
tmp1 = data1.ptr<T>()[idx1]; | |||
} | |||
int op = (p + neighborhood_grid_radius) * | |||
neighborhood_grid_width + | |||
(o + neighborhood_grid_radius); | |||
int diff_channels_offset = (n * tchannels + op); | |||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||
int idxtopdiff = | |||
(diff_channels_offset * theight + diff_y) * | |||
twidth + | |||
diff_x; | |||
if (param.is_multiply) { | |||
sum += diff.ptr<T>()[idxtopdiff] * tmp1; | |||
} else { | |||
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||
sum += diff.ptr<T>()[idxtopdiff] * sign; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
const int sumelems = | |||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||
grad2.ptr<T>()[idx] = sum / sumelems; | |||
} | |||
} | |||
} // namespace | |||
namespace megdnn { | |||
namespace naive { | |||
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||
#define cb(DType) \ | |||
if (data1.layout.dtype == DType()) { \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
forward<typename DTypeTrait<DType>::ctype>(data1, data2, dst, \ | |||
param())); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad1, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||
workspace.size); | |||
#define cb(DType) \ | |||
if (diff.layout.dtype == DType()) { \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
backward_data1<typename DTypeTrait<DType>::ctype>( \ | |||
diff, data1, data2, grad1, param())); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||
_megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, | |||
_megdnn_tensor_out grad2, | |||
_megdnn_workspace workspace) { | |||
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||
workspace.size); | |||
#define cb(DType) \ | |||
if (diff.layout.dtype == DType()) { \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
backward_data2<typename DTypeTrait<DType>::ctype>( \ | |||
diff, data1, data2, grad2, param())); \ | |||
return; \ | |||
} | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||
#undef cb | |||
megdnn_throw("bad dtype"); | |||
} | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,58 @@ | |||
/** | |||
* \file dnn/src/naive/correlation/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace naive { | |||
class CorrelationForwardImpl final : public CorrelationForward { | |||
public: | |||
using CorrelationForward::CorrelationForward; | |||
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||
public: | |||
using CorrelationBackwardData1::CorrelationBackwardData1; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||
public: | |||
using CorrelationBackwardData2::CorrelationBackwardData2; | |||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace naive | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -30,6 +30,7 @@ | |||
#include "src/naive/convpooling/opr_impl.h" | |||
#include "src/naive/cumsum/opr_impl.h" | |||
#include "src/naive/cvt_color/opr_impl.h" | |||
#include "src/naive/correlation/opr_impl.h" | |||
#include "src/naive/dct/opr_impl.h" | |||
#include "src/naive/deformable_conv/opr_impl.h" | |||
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h" | |||
@@ -0,0 +1,73 @@ | |||
/** | |||
* \file dnn/test/common/correlation.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/opr_param_defs.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace correlation { | |||
struct TestArg { | |||
param::Correlation param; | |||
TensorShape data1, data2; | |||
TestArg(param::Correlation param, TensorShape data1, TensorShape data2) | |||
: param(param), data1(data1), data2(data2) {} | |||
}; | |||
inline static std::vector<TestArg> get_args() { | |||
std::vector<TestArg> args; | |||
param::Correlation cur_param; | |||
for (size_t batch_size : {2}) { | |||
for (size_t channel : {2}) { | |||
for (size_t height : {160}) { | |||
for (size_t width : {160}) { | |||
cur_param.is_multiply = true; | |||
cur_param.kernel_size = 3; | |||
cur_param.max_displacement = 3; | |||
cur_param.pad_size = 0; | |||
cur_param.stride1 = 1; | |||
cur_param.stride2 = 1; | |||
cur_param.format = megdnn::param::Correlation::Format::NCHW; | |||
args.emplace_back( | |||
cur_param, | |||
TensorShape{batch_size, channel, height, width}, | |||
TensorShape{batch_size, channel, height, width}); | |||
// cur_param.is_multiply = false; | |||
// cur_param.kernel_size = 1; | |||
// cur_param.max_displacement = 2; | |||
// cur_param.pad_size = 1; | |||
// cur_param.stride1 = 1; | |||
// cur_param.stride2 = 1; | |||
// cur_param.format = | |||
// megdnn::param::Correlation::Format::NCHW; | |||
// args.emplace_back( | |||
// cur_param, | |||
// TensorShape{batch_size, channel, height, width}, | |||
// TensorShape{batch_size, channel, height, width}); | |||
} | |||
} | |||
} | |||
} | |||
return args; | |||
} | |||
} // namespace correlation | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,160 @@ | |||
/** | |||
* \file dnn/test/cuda/correlation.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/cuda/fixture.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/correlation.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(CUDA, CORRELATION_FORWARD) { | |||
using namespace correlation; | |||
std::vector<TestArg> args = get_args(); | |||
Checker<Correlation> checker(handle_cuda()); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.execs({arg.data1, arg.data2, {}}); | |||
} | |||
} | |||
TEST_F(CUDA, CORRELATION_BACKWARDDATA1) { | |||
ConstValue const_0{0}; | |||
using Param = CorrelationBackwardData1::Param; | |||
Param param; | |||
param.is_multiply = true; | |||
param.format = Param::Format::NCHW; | |||
param.stride1 = 2; | |||
param.stride2 = 2; | |||
param.kernel_size = 3; | |||
param.pad_size = 4; | |||
Checker<CorrelationBackwardData1> checker(handle_cuda()); | |||
checker.set_epsilon(1e-2); | |||
uint32_t pad_size = param.pad_size; | |||
uint32_t kernel_size = param.kernel_size; | |||
uint32_t stride1 = param.stride1; | |||
uint32_t stride2 = param.stride2; | |||
uint32_t max_displacement = param.max_displacement; | |||
auto run = [&](DType dtype) { | |||
for (size_t N : {1, 3}) | |||
for (size_t C : {1, 3}) | |||
for (size_t OH : {10, 100}) | |||
for (size_t OW : {10, 100}) { | |||
int paddedbottomheight = OH + 2 * pad_size; | |||
int paddedbottomwidth = OW + 2 * pad_size; | |||
uint32_t kernel_radius = (kernel_size - 1) / 2; | |||
uint32_t border_size = max_displacement + kernel_radius; | |||
uint32_t top_width = | |||
ceil(static_cast<float>(paddedbottomwidth - | |||
border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t top_height = | |||
ceil(static_cast<float>(paddedbottomheight - | |||
border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t neighborhood_grid_radius = | |||
max_displacement / stride2; | |||
uint32_t neighborhood_grid_width = | |||
neighborhood_grid_radius * 2 + 1; | |||
uint32_t top_channels = neighborhood_grid_width * | |||
neighborhood_grid_width; | |||
checker.set_param(param) | |||
.set_dtype(0, dtype) | |||
.set_dtype(1, dtype) | |||
.set_dtype(2, dtype) | |||
.set_dtype(3, dtype) | |||
.execs({{N, top_channels, top_height, | |||
top_width}, | |||
{N, C, OH, OW}, | |||
{N, C, OH, OW}, | |||
{N, C, OH, OW}}); | |||
} | |||
}; | |||
run(dtype::Float32()); | |||
run(dtype::Float16()); | |||
checker.set_epsilon(5e-2); | |||
run(dtype::BFloat16()); | |||
} | |||
TEST_F(CUDA, CORRELATION_BACKWARDDATA2) { | |||
ConstValue const_0{0}; | |||
using Param = CorrelationBackwardData2::Param; | |||
Param param; | |||
param.is_multiply = true; | |||
param.format = Param::Format::NCHW; | |||
param.stride1 = 2; | |||
param.stride2 = 2; | |||
param.kernel_size = 3; | |||
param.pad_size = 4; | |||
Checker<CorrelationBackwardData2> checker(handle_cuda()); | |||
checker.set_epsilon(1e-2); | |||
uint32_t pad_size = param.pad_size; | |||
uint32_t kernel_size = param.kernel_size; | |||
uint32_t stride1 = param.stride1; | |||
uint32_t stride2 = param.stride2; | |||
uint32_t max_displacement = param.max_displacement; | |||
auto run = [&](DType dtype) { | |||
for (size_t N : {1, 3}) | |||
for (size_t C : {1, 3}) | |||
for (size_t OH : {10, 100}) | |||
for (size_t OW : {10, 100}) { | |||
int paddedbottomheight = OH + 2 * pad_size; | |||
int paddedbottomwidth = OW + 2 * pad_size; | |||
uint32_t kernel_radius = (kernel_size - 1) / 2; | |||
uint32_t border_size = max_displacement + kernel_radius; | |||
uint32_t top_width = | |||
ceil(static_cast<float>(paddedbottomwidth - | |||
border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t top_height = | |||
ceil(static_cast<float>(paddedbottomheight - | |||
border_size * 2) / | |||
static_cast<float>(stride1)); | |||
uint32_t neighborhood_grid_radius = | |||
max_displacement / stride2; | |||
uint32_t neighborhood_grid_width = | |||
neighborhood_grid_radius * 2 + 1; | |||
uint32_t top_channels = neighborhood_grid_width * | |||
neighborhood_grid_width; | |||
checker.set_param(param) | |||
.set_dtype(0, dtype) | |||
.set_dtype(1, dtype) | |||
.set_dtype(2, dtype) | |||
.set_dtype(3, dtype) | |||
.execs({{N, top_channels, top_height, | |||
top_width}, | |||
{N, C, OH, OW}, | |||
{N, C, OH, OW}, | |||
{N, C, OH, OW}}); | |||
} | |||
}; | |||
run(dtype::Float32()); | |||
run(dtype::Float16()); | |||
checker.set_epsilon(5e-2); | |||
run(dtype::BFloat16()); | |||
} | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |