diff --git a/dnn/include/megdnn/oprs/cv.h b/dnn/include/megdnn/oprs/cv.h
index c20f8945..6d75f541 100644
--- a/dnn/include/megdnn/oprs/cv.h
+++ b/dnn/include/megdnn/oprs/cv.h
@@ -270,6 +270,41 @@ protected:
};
using Remap = RemapForward;
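+
+/*!
+ * \brief gradient of RemapForward w.r.t. the input image src
+ *
+ * Takes map_xy and diff (the gradient of the forward output) and computes
+ * grad, laid out like the forward input src.
+ */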
+class RemapBackwardData : public RemapBase {
+ DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);
+
+public:
+ virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
+ _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
+
+ virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad) = 0;
+
+protected:
+ void check_exec(const TensorLayout& map_xy, const TensorLayout& diff,
+ const TensorLayout& grad, size_t workspace_in_bytes);
+};
+
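+/*!
+ * \brief gradient of RemapForward w.r.t. the sampling map map_xy
+ *
+ * Inputs are src, map_xy and diff; the output grad has the same layout as
+ * map_xy.
+ */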
+class RemapBackwardMat : public RemapBase {
+ DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);
+
+public:
+ virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) = 0;
+
+ virtual size_t get_workspace_in_bytes(const TensorLayout& src,
+ const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad) = 0;
+
+protected:
+ void check_exec(const TensorLayout& src, const TensorLayout& map_xy,
+ const TensorLayout& diff, const TensorLayout& grad,
+ size_t workspace_in_bytes);
+};
+
class SeparableFilterBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
DEF_OPR_PARAM(SeparableFilter);
diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h
index f94cc0c3..354bbd0a 100644
--- a/dnn/src/common/handle_impl.h
+++ b/dnn/src/common/handle_impl.h
@@ -197,6 +197,8 @@ private:
cb(ROIAlignBackward) \
cb(BatchConvBiasForward) \
cb(Remap) \
+ cb(RemapBackwardData) \
+ cb(RemapBackwardMat) \
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp
index ff76f866..6cf3e57c 100644
--- a/dnn/src/common/remap.cpp
+++ b/dnn/src/common/remap.cpp
@@ -50,6 +50,7 @@ void RemapBase::check_layout_fwd(const TensorLayout& src,
megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str());
megdnn_assert(map_xy.shape[3] == 2);
megdnn_assert(map_xy.shape[0] == src.shape[0]);
+ megdnn_assert_contiguous(src);
    // map_xy only supports float32 type
// map_xy always in NHWC format
@@ -85,6 +86,34 @@ void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy,
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
+void RemapBackwardData::check_exec(const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad,
+ size_t workspace_in_bytes) {
+ check_layout_fwd(grad, map_xy, diff);
+ megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+ || grad.dtype == dtype::BFloat16()),
+ "Backward Remap only supports Float32/BFloat16.");
+ auto required_workspace_in_bytes =
+ get_workspace_in_bytes(map_xy, diff, grad);
+ megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
+void RemapBackwardMat::check_exec(const TensorLayout& src,
+ const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad,
+ size_t workspace_in_bytes) {
+ check_layout_fwd(src, map_xy, diff);
+ megdnn_assert_eq_layout(map_xy, grad);
+ megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
+ || grad.dtype == dtype::BFloat16()),
+ "Backward Remap only supports Float32/BFloat16.");
+ auto required_workspace_in_bytes =
+ get_workspace_in_bytes(src, map_xy, diff, grad);
+ megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
+}
+
} // namespace megdnn
// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/backward_data.cpp b/dnn/src/cuda/remap/backward_data.cpp
new file mode 100644
index 00000000..d0ef6c75
--- /dev/null
+++ b/dnn/src/cuda/remap/backward_data.cpp
@@ -0,0 +1,71 @@
+/**
+ * \file dnn/src/cuda/remap/backward_data.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/cuda/remap/common.h"
+#include "src/cuda/remap/opr_impl.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff,
+ _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) {
+ check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
+ megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
+ "only support LINEAR interpolationMode");
+ megdnn_assert(param().format == param::Remap::Format::NCHW,
+ "only support NCHW format for remap backward");
+ auto stream = cuda_stream(this->handle());
+ int N, C, IH, IW, OH, OW;
+ N = grad.layout.shape[0];
+ C = grad.layout.shape[1];
+ IH = grad.layout.shape[2];
+ IW = grad.layout.shape[3];
+ OH = map_xy.layout.shape[1];
+ OW = map_xy.layout.shape[2];
+
+#define cb(dt, _format, bmode) \
+    if (param().format == param::Remap::Format::_format && \
+        param().border_type == param::Remap::BorderMode::bmode) { \
+        using ctype = DTypeTrait<dt>::ctype; \
+        remap::backwarddata_proxy<ctype, param_enumv::Remap::Format::_format, \
+                                  ::BorderMode::bmode>( \
+                grad.compatible_ptr<ctype>(), \
+                map_xy.compatible_ptr<dt_float32>(), \
+                diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \
+        break; \
+    }
+
+#define support_dtype(dt) \
+    case DTypeTrait<dt>::enumv: { \
+ cb(dt, NCHW, CONSTANT); \
+ cb(dt, NCHW, REPLICATE); \
+ cb(dt, NCHW, REFLECT); \
+ cb(dt, NCHW, REFLECT_101); \
+ cb(dt, NCHW, WRAP); \
+ megdnn_throw("unsupported border type in remap cuda"); \
+ }
+
+ switch (grad.layout.dtype.enumv()) {
+ support_dtype(dtype::Float32);
+ support_dtype(dtype::BFloat16);
+ default:
+ megdnn_throw("unsupported dtype in remap backward cuda\n");
+ }
+
+#undef support_dtype
+#undef cb
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/backward_data.cu b/dnn/src/cuda/remap/backward_data.cu
new file mode 100644
index 00000000..b6fe472b
--- /dev/null
+++ b/dnn/src/cuda/remap/backward_data.cu
@@ -0,0 +1,169 @@
+/**
+ * \file dnn/src/cuda/remap/backward_data.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include <cuda_runtime.h>
+#include "src/common/rounding_converter.cuh"
+#include "src/cuda/cv/kernel_common.cuh"
+#include "src/cuda/remap/common.h"
+#include "src/cuda/utils.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace remap;
+using namespace rounding;
+
+namespace {
+
+template <const uint32_t format>
+__device__ inline int get_offset(int height, int width, int channel, int h,
+ int w, int c);
+
+template <>
+__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
+ int height, int width, int channel, int h, int w, int c) {
+ return channel * h * w + height * w + width;
+}
+
+template <const uint32_t format, ::BorderMode bmode>
+struct GetSrcData {
+ __device__ static inline int get_index(int height, int width, int channel,
+ int h, int w, int c) {
+ height = megcv::border_interpolate(height, h);
+ width = megcv::border_interpolate(width, w);
+        return get_offset<format>(height, width, channel, h, w, c);
+ }
+};
+
+template <const uint32_t format>
+struct GetSrcData<format, ::BorderMode::BORDER_CONSTANT> {
+ __device__ static inline int get_index(int height, int width, int channel,
+ int h, int w, int c) {
+ return (height >= 0 && height < h && width >= 0 && width < w)
+                       ? get_offset<format>(height, width, channel, h, w, c)
+ : -1;
+ }
+};
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+__global__ void kern_general(ctype* __restrict grad, const float* map_xy,
+ const ctype* diff, int C, int IH, int IW, int OH,
+ int OW) {
+ int ow = blockIdx.x * blockDim.x + threadIdx.x;
+ int oh = blockIdx.y * blockDim.y + threadIdx.y;
+ grad += blockIdx.z * C * IH * IW;
+ diff += blockIdx.z * C * OH * OW;
+ map_xy += blockIdx.z * 2 * OH * OW;
+    RoundingConverter<ctype> round_converter;
+
+ if (ow < OW && oh < OH) {
+ float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
+ float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
+        int col = static_cast<int>(floor(index_col));
+        int row = static_cast<int>(floor(index_row));
+        float v = index_col - col; // alphaw
+        float u = index_row - row; // alphah
+ const float one = 1.f;
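+        // forward bilinear sampling computes
+        //   dst = (1-u)(1-v)*a00 + (1-u)v*a01 + u(1-v)*a10 + u*v*a11,
+        // so the data gradient scatters `diff` back to the same four source
+        // taps with the same weights; atomic_add is required because many
+        // output pixels may sample the same source pixel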
+ for (int c = 0; c < C; ++c) {
+            float hidden = static_cast<float>(
+                    diff[get_offset<format>(oh, ow, c, OH, OW, C)]);
+
+            int a00 = GetSrcData<format, bmode>::get_index(
+ row + 0, col + 0, c, IH, IW, C);
+ if (a00 != -1) {
+ atomic_add(grad + a00,
+ round_converter((one - u) * (one - v) * hidden));
+ }
+
+            int a01 = GetSrcData<format, bmode>::get_index(
+ row + 0, col + 1, c, IH, IW, C);
+ if (a01 != -1) {
+ atomic_add(grad + a01, round_converter((one - u) * v * hidden));
+ }
+
+            int a10 = GetSrcData<format, bmode>::get_index(
+ row + 1, col + 0, c, IH, IW, C);
+ if (a10 != -1) {
+ atomic_add(grad + a10, round_converter(u * (one - v) * hidden));
+ }
+
+            int a11 = GetSrcData<format, bmode>::get_index(
+                    row + 1, col + 1, c, IH, IW, C);
+ if (a11 != -1) {
+ atomic_add(grad + a11, round_converter(u * v * hidden));
+ }
+ }
+ }
+}
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void dispatch_backwarddata(ctype* grad, const float* map_xy, const ctype* diff,
+ int N, int C, int IH, int IW, int OH, int OW,
+ cudaStream_t stream) {
+ const int BX = 32, BY = 16;
+ const int max_batch_size = 65535;
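+    // the CUDA grid z-dimension is capped at 65535, so larger batches are
+    // processed in chunks, advancing the tensor pointers between launches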
+ while (N) {
+ size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
+ dim3 threads(BX, BY);
+ dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
+
+ cuda_check(cudaMemsetAsync(
+ grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW,
+ stream));
+        kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
+ grad, map_xy, diff, C, IH, IW, OH, OW);
+
+ N -= curr_batch_size;
+ grad += curr_batch_size * C * IH * IW;
+ diff += curr_batch_size * C * OH * OW;
+ map_xy += curr_batch_size * 2 * OH * OW;
+ }
+}
+
+} // anonymous namespace
+
+namespace megdnn {
+namespace cuda {
+namespace remap {
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff,
+ int N, int C, int IH, int IW, int OH, int OW,
+ cudaStream_t stream) {
+    dispatch_backwarddata<ctype, format, bmode>(grad, map_xy, diff, N, C, IH,
+ IW, OH, OW, stream);
+ after_kernel_launch();
+}
+
+#define INST(ctype, format, bmode) \
+ template void backwarddata_proxy< \
+ ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
+ ctype*, const float*, const ctype*, int, int, int, int, int, int, \
+ cudaStream_t);
+
+#define FOR_FORMAT_BMODE(ctype) \
+ INST(ctype, NCHW, BORDER_CONSTANT) \
+ INST(ctype, NCHW, BORDER_REPLICATE) \
+ INST(ctype, NCHW, BORDER_REFLECT) \
+ INST(ctype, NCHW, BORDER_REFLECT_101) \
+ INST(ctype, NCHW, BORDER_WRAP)
+
+FOR_FORMAT_BMODE(float)
+MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+
+#undef FOR_FORMAT_BMODE
+#undef INST
+
+} // namespace remap
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/backward_mat.cpp b/dnn/src/cuda/remap/backward_mat.cpp
new file mode 100644
index 00000000..aeccf19e
--- /dev/null
+++ b/dnn/src/cuda/remap/backward_mat.cpp
@@ -0,0 +1,73 @@
+/**
+ * \file dnn/src/cuda/remap/backward_mat.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/cuda/remap/common.h"
+#include "src/cuda/remap/opr_impl.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) {
+ check_exec(src.layout, map_xy.layout, diff.layout, grad.layout,
+ workspace.size);
+ megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
+ "only support LINEAR interpolationMode");
+ megdnn_assert(param().format == param::Remap::Format::NCHW,
+ "only support NCHW format for remap backward");
+ auto stream = cuda_stream(this->handle());
+ int N, C, IH, IW, OH, OW;
+ N = src.layout.shape[0];
+ C = src.layout.shape[1];
+ IH = src.layout.shape[2];
+ IW = src.layout.shape[3];
+ OH = map_xy.layout.shape[1];
+ OW = map_xy.layout.shape[2];
+
+#define cb(dt, _format, bmode) \
+ if (param().format == param::Remap::Format::_format && \
+ param().border_type == param::Remap::BorderMode::bmode) { \
+        using ctype = DTypeTrait<dt>::ctype; \
+        remap::backwardmat_proxy<ctype, param_enumv::Remap::Format::_format, \
+                                 ::BorderMode::bmode>( \
+                src.compatible_ptr<ctype>(), \
+                map_xy.compatible_ptr<dt_float32>(), \
+                diff.compatible_ptr<ctype>(), \
+                grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \
+                param().scalar, stream); \
+ break; \
+ }
+
+#define support_dtype(dt) \
+    case DTypeTrait<dt>::enumv: { \
+ cb(dt, NCHW, CONSTANT); \
+ cb(dt, NCHW, REPLICATE); \
+ cb(dt, NCHW, REFLECT); \
+ cb(dt, NCHW, REFLECT_101); \
+ cb(dt, NCHW, WRAP); \
+ megdnn_throw("unsupported border type in remap cuda"); \
+ }
+
+ switch (src.layout.dtype.enumv()) {
+ support_dtype(dtype::Float32);
+ support_dtype(dtype::BFloat16);
+ default:
+ megdnn_throw("unsupported dtype in remap backward cuda\n");
+ }
+
+#undef support_dtype
+#undef cb
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/backward_mat.cu b/dnn/src/cuda/remap/backward_mat.cu
new file mode 100644
index 00000000..eb2a345c
--- /dev/null
+++ b/dnn/src/cuda/remap/backward_mat.cu
@@ -0,0 +1,170 @@
+/**
+ * \file dnn/src/cuda/remap/backward_mat.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include <cuda_runtime.h>
+#include "src/common/rounding_converter.cuh"
+#include "src/cuda/cv/kernel_common.cuh"
+#include "src/cuda/remap/common.h"
+#include "src/cuda/utils.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace remap;
+using namespace rounding;
+
+namespace {
+
+template <const uint32_t format>
+__device__ inline int get_offset(int height, int width, int channel, int h,
+ int w, int c);
+
+template <>
+__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
+ int height, int width, int channel, int h, int w, int c) {
+ return channel * h * w + height * w + width;
+}
+
+template <const uint32_t format, ::BorderMode bmode>
+struct GetSrcData {
+ __device__ static inline int get_index(int height, int width, int channel,
+ int h, int w, int c) {
+ height = megcv::border_interpolate(height, h);
+ width = megcv::border_interpolate(width, w);
+        return get_offset<format>(height, width, channel, h, w, c);
+ }
+};
+
+template <const uint32_t format>
+struct GetSrcData<format, ::BorderMode::BORDER_CONSTANT> {
+ __device__ static inline int get_index(int height, int width, int channel,
+ int h, int w, int c) {
+ return (height >= 0 && height < h && width >= 0 && width < w)
+                       ? get_offset<format>(height, width, channel, h, w, c)
+ : -1;
+ }
+};
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+__global__ void kern_general(const ctype* src, const float* map_xy,
+ const ctype* diff, float* __restrict grad, int C,
+ int IH, int IW, int OH, int OW, float scalar) {
+ int ow = blockIdx.x * blockDim.x + threadIdx.x;
+ int oh = blockIdx.y * blockDim.y + threadIdx.y;
+ src += blockIdx.z * C * IH * IW;
+ diff += blockIdx.z * C * OH * OW;
+ map_xy += blockIdx.z * 2 * OH * OW;
+ grad += blockIdx.z * 2 * OH * OW;
+    RoundingConverter<float> round_converter;
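+    // each thread owns one (oh, ow) output location, so the per-channel
+    // accumulation into grad needs no atomics, unlike the data-gradient
+    // kernel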
+
+ if (ow < OW && oh < OH) {
+ float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
+ float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
+        int col = static_cast<int>(floor(index_col));
+        int row = static_cast<int>(floor(index_row));
+ float v = index_col - col; // alphaw
+ float u = index_row - row; // alphah
+ const float one = 1.f;
+ for (int c = 0; c < C; ++c) {
+            float hidden = static_cast<float>(
+                    diff[get_offset<format>(oh, ow, c, OH, OW, C)]);
+ float du = 0.f, dv = 0.f;
+
+            int a00 = GetSrcData<format, bmode>::get_index(
+                    row + 0, col + 0, c, IH, IW, C);
+            int a01 = GetSrcData<format, bmode>::get_index(
+                    row + 0, col + 1, c, IH, IW, C);
+            int a10 = GetSrcData<format, bmode>::get_index(
+                    row + 1, col + 0, c, IH, IW, C);
+            int a11 = GetSrcData<format, bmode>::get_index(
+                    row + 1, col + 1, c, IH, IW, C);
+
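+            // derivatives of the bilinear blend w.r.t. the map coordinates:
+            //   d(dst)/dv = (1-u)*(a01 - a00) + u*(a11 - a10)
+            //   d(dst)/du = (1-v)*(a10 - a00) + v*(a11 - a01)
+            // out-of-range taps (index -1) read the constant `scalar`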
+ dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
+ dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
+ dv -= ((a10 != -1) ? src[a10] : scalar) * u;
+ dv += ((a11 != -1) ? src[a11] : scalar) * u;
+
+ du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
+ du -= ((a01 != -1) ? src[a01] : scalar) * v;
+ du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
+ du += ((a11 != -1) ? src[a11] : scalar) * v;
+
+ grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv);
+ grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du);
+ }
+ }
+}
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void dispatch_backwardmat(const ctype* src, const float* map_xy,
+ const ctype* diff, float* grad, int N, int C, int IH,
+ int IW, int OH, int OW, float scalar,
+ cudaStream_t stream) {
+ const int BX = 32, BY = 16;
+ const int max_batch_size = 65535;
+ while (N) {
+ size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
+ dim3 threads(BX, BY);
+ dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
+
+ cuda_check(cudaMemsetAsync(
+ grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2,
+ stream));
+        kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
+ src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar);
+
+ N -= curr_batch_size;
+ src += curr_batch_size * C * IH * IW;
+ diff += curr_batch_size * C * OH * OW;
+ map_xy += curr_batch_size * 2 * OH * OW;
+ grad += curr_batch_size * 2 * OH * OW;
+ }
+}
+
+} // anonymous namespace
+
+namespace megdnn {
+namespace cuda {
+namespace remap {
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff,
+ float* grad, int N, int C, int IH, int IW, int OH,
+ int OW, float scalar, cudaStream_t stream) {
+    dispatch_backwardmat<ctype, format, bmode>(src, map_xy, diff, grad, N, C,
+ IH, IW, OH, OW, scalar, stream);
+ after_kernel_launch();
+}
+
+#define INST(ctype, format, bmode) \
+    template void backwardmat_proxy< \
+            ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \
+ const ctype*, const float*, const ctype*, float*, int, int, int, \
+ int, int, int, float, cudaStream_t);
+
+#define FOR_FORMAT_BMODE(ctype) \
+ INST(ctype, NCHW, BORDER_CONSTANT) \
+ INST(ctype, NCHW, BORDER_REPLICATE) \
+ INST(ctype, NCHW, BORDER_REFLECT) \
+ INST(ctype, NCHW, BORDER_REFLECT_101) \
+ INST(ctype, NCHW, BORDER_WRAP)
+
+FOR_FORMAT_BMODE(float)
+MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
+
+#undef FOR_FORMAT_BMODE
+#undef INST
+
+} // namespace remap
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/common.h b/dnn/src/cuda/remap/common.h
index 82593c9e..602fb8b7 100644
--- a/dnn/src/cuda/remap/common.h
+++ b/dnn/src/cuda/remap/common.h
@@ -24,7 +24,17 @@ namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, float scalar,
- int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream);
+ cudaStream_t stream);
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff,
+ int N, int C, int IH, int IW, int OH, int OW,
+ cudaStream_t stream);
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff,
+ float* grad, int N, int C, int IH, int IW, int OH,
+ int OW, float scalar, cudaStream_t stream);
} // namespace remap
} // namespace cuda
diff --git a/dnn/src/cuda/remap/forward.cpp b/dnn/src/cuda/remap/forward.cpp
index 7ccf1ef5..312f7935 100644
--- a/dnn/src/cuda/remap/forward.cpp
+++ b/dnn/src/cuda/remap/forward.cpp
@@ -22,9 +22,10 @@ using namespace cuda;
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
                     _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
+    megdnn_assert(map_xy.layout.dtype.enumv() ==
+                  DTypeTrait<dtype::Float32>::enumv);
auto stream = cuda_stream(this->handle());
int N, C, IH, IW, OH, OW;
- ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
OH = map_xy.layout.shape[1];
OW = map_xy.layout.shape[2];
@@ -36,10 +37,6 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
- S_IN = src.layout.stride[0];
- S_IC = src.layout.stride[1];
- S_IH = src.layout.stride[2];
- S_IW = src.layout.stride[3];
} else if (param().format == param::Remap::Format::NHWC) {
N = src.layout.shape[0];
C = src.layout.shape[3];
@@ -58,7 +55,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
                src.compatible_ptr<ctype>(), \
                map_xy.compatible_ptr<dt_float32>(), \
                dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \
- param().scalar, S_IN, S_IC, S_IH, S_IW, stream); \
+ param().scalar, stream); \
break; \
}
@@ -78,15 +75,16 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
}
switch (src.layout.dtype.enumv()) {
- support_dtype(dtype::Float32)
- MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16))
- support_dtype(dtype::Int8)
- support_dtype(dtype::Uint8)
+ support_dtype(dtype::Float32);
+ MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16));
+ MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+ support_dtype(dtype::Int8);
+ support_dtype(dtype::Uint8);
default:
megdnn_throw("unsupported dtype in remap cuda");
}
-#undef supported_dtype
+#undef support_dtype
#undef cb
}
diff --git a/dnn/src/cuda/remap/forward.cu b/dnn/src/cuda/remap/forward.cu
index fcd6f19f..b417306a 100644
--- a/dnn/src/cuda/remap/forward.cu
+++ b/dnn/src/cuda/remap/forward.cu
@@ -23,17 +23,6 @@ using namespace rounding;
namespace {
-template <typename ctype>
-struct DirectSrcVisitor {
- const ctype* ptr;
-
- __device__ __forceinline__ const ctype* get(int batch, int im_size) {
- return ptr + batch * im_size;
- }
-
- void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; }
-};
-
template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
int w, int c);
@@ -74,14 +63,13 @@ struct GetSrcData {
}
};
-template <typename ctype, typename SrcVisitor, const uint32_t format,
-          ::BorderMode bmode>
-__global__ void kern_general(SrcVisitor src, const float* map_xy,
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+__global__ void kern_general(const ctype* __restrict sptr, const float* map_xy,
ctype* __restrict dst, int C, int IH, int IW,
- int OH, int OW, int S_IN, int S_IC, int S_IH,
- int S_IW, float scalar) {
+ int OH, int OW, float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
- const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
+ sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter round_converter;
@@ -89,8 +77,8 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy,
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
- int col = (int)floor(index_col);
- int row = (int)floor(index_row);
+        int col = static_cast<int>(floor(index_col));
+        int row = static_cast<int>(floor(index_row));
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
@@ -106,22 +94,25 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy,
            ctype a11 = GetSrcData<format, bmode>::get(sptr, row + 1, col + 1, c, IH,
IW, C, scalar);
-            dst[get_offset<format>(oh, ow, c, OH, OW,
- C)] =
- round_converter(a00 * (1.f - u) * (1.f - v) +
- a01 * (1.f - u) * v + a10 * (1.f - v) * u +
- a11 * u * v);
+ /* in remap, we use float as the type of intermediate result */
+            float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
+                           static_cast<float>(a01) * (1.f - u) * v +
+                           static_cast<float>(a10) * (1.f - v) * u +
+                           static_cast<float>(a11) * u * v;
+            dst[get_offset<format>(
+                    oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
-template <typename ctype, typename SrcVisitor, ::BorderMode bmode>
-__global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
- ctype* __restrict dst, int C, int IH, int IW,
- int OH, int OW, float scalar) {
+template <typename ctype, ::BorderMode bmode>
+__global__ void kern_general_nhwc(const ctype* __restrict sptr,
+ const float* map_xy, ctype* __restrict dst,
+ int C, int IH, int IW, int OH, int OW,
+ float scalar) {
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
- const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
+ sptr += blockIdx.z * C * IH * IW;
dst += blockIdx.z * C * OH * OW;
map_xy += blockIdx.z * 2 * OH * OW;
RoundingConverter round_converter;
@@ -129,8 +120,8 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
if (ow < OW && oh < OH) {
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
- int col = (int)floor(index_col);
- int row = (int)floor(index_row);
+        int col = static_cast<int>(floor(index_col));
+        int row = static_cast<int>(floor(index_row));
float v = index_col - col;
float u = index_row - row;
for (int c = 0; c < C; ++c) {
@@ -146,21 +137,21 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
            ctype a11 = GetSrcData<param_enumv::Remap::Format::NHWC,
                                   bmode>::get(sptr, row + 1, col + 1, c, IH,
                                               IW, C, scalar);
-            dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH,
-                                                             OW, C)] =
- round_converter(a00 * (1.f - u) * (1.f - v) +
- a01 * (1.f - u) * v + a10 * (1.f - v) * u +
- a11 * u * v);
+ /* in remap, we use float as the type of intermediate result */
+            float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) +
+                           static_cast<float>(a01) * (1.f - u) * v +
+                           static_cast<float>(a10) * (1.f - v) * u +
+                           static_cast<float>(a11) * u * v;
+            dst[get_offset<param_enumv::Remap::Format::NHWC>(
+                    oh, ow, c, OH, OW, C)] = round_converter(result);
}
}
}
-template <typename ctype, typename SrcVisitor, const uint32_t format,
-          ::BorderMode bmode>
-void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
- int N, int C, int IH, int IW, int OH, int OW,
- float scalar, int S_IN, int S_IC, int S_IH, int S_IW,
- cudaStream_t stream) {
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void dispatch_forward(const ctype* src, const float* map_xy, ctype* dst, int N,
+ int C, int IH, int IW, int OH, int OW, float scalar,
+ cudaStream_t stream) {
const int BX = 32, BY = 16;
const int max_batch_size = 65535;
@@ -170,19 +161,17 @@ void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
if (format == param_enumv::Remap::Format::NCHW) {
-            kern_general<ctype, SrcVisitor, format, bmode>
-                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
-                                                     IW, OH, OW, S_IN, S_IC,
-                                                     S_IH, S_IW, scalar);
+            kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>(
+ src, map_xy, dst, C, IH, IW, OH, OW, scalar);
} else if (format == param_enumv::Remap::Format::NHWC) {
-            kern_general_nhwc<ctype, SrcVisitor, bmode>
-                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
-                                                     IW, OH, OW, scalar);
+            kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>(
+ src, map_xy, dst, C, IH, IW, OH, OW, scalar);
}
N -= curr_batch_size;
- src.move_batch(curr_batch_size, C * IH * IW);
+ src += curr_batch_size * C * IH * IW;
dst += curr_batch_size * C * OH * OW;
+ map_xy += curr_batch_size * OH * OW * 2;
}
}
@@ -195,22 +184,17 @@ namespace remap {
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
int C, int IH, int IW, int OH, int OW, float scalar,
- int S_IN, int S_IC, int S_IH, int S_IW,
cudaStream_t stream) {
- DirectSrcVisitor visitor;
- visitor.ptr = src;
- using SrcVisitor = DirectSrcVisitor;
-    dispatch_with_visitor<ctype, SrcVisitor, format, bmode>(
- visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC,
- S_IH, S_IW, stream);
+    dispatch_forward<ctype, format, bmode>(src, map_xy, dst, N, C, IH, IW, OH,
+ OW, scalar, stream);
after_kernel_launch();
}
-#define INST(ctype, format, bmode) \
-    template void forward_proxy<ctype, param_enumv::Remap::Format::format, \
-                                ::BorderMode::bmode>( \
- const ctype* src, const float*, ctype*, int, int, int, int, int, \
- int, float, int, int, int, int, cudaStream_t);
+#define INST(ctype, format, bmode) \
+    template void forward_proxy<ctype, param_enumv::Remap::Format::format, \
+                                ::BorderMode::bmode>( \
+ const ctype*, const float*, ctype*, int, int, int, int, int, int, \
+ float, cudaStream_t);
#define FOR_FORMAT_BMODE(ctype) \
INST(ctype, NCHW, BORDER_CONSTANT) \
@@ -226,11 +210,13 @@ void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
+MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16))
FOR_FORMAT_BMODE(int8_t)
FOR_FORMAT_BMODE(uint8_t)
-#undef FOR_BMODE
+#undef FOR_FORMAT_BMODE
#undef INST
+
} // namespace remap
} // namespace cuda
} // namespace megdnn
diff --git a/dnn/src/cuda/remap/opr_impl.h b/dnn/src/cuda/remap/opr_impl.h
index a812e217..bdc5f66a 100644
--- a/dnn/src/cuda/remap/opr_impl.h
+++ b/dnn/src/cuda/remap/opr_impl.h
@@ -15,13 +15,41 @@
namespace megdnn {
namespace cuda {
class RemapImpl final : public Remap {
+public:
using Remap::Remap;
- void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out,
- _megdnn_workspace) override;
+ void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
- size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
- const TensorLayout&) override {
+ size_t get_workspace_in_bytes(const TensorLayout& src,
+ const TensorLayout& map_xy,
+ const TensorLayout& dst) override {
+ return 0;
+ }
+};
+
+class RemapBackwardDataImpl final : public RemapBackwardData {
+public:
+ using RemapBackwardData::RemapBackwardData;
+ void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
+ _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
+ size_t get_workspace_in_bytes(const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad) override {
+ return 0;
+ }
+};
+
+class RemapBackwardMatImpl final : public RemapBackwardMat {
+public:
+ using RemapBackwardMat::RemapBackwardMat;
+ void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) override;
+ size_t get_workspace_in_bytes(const TensorLayout& src,
+ const TensorLayout& map_xy,
+ const TensorLayout& diff,
+ const TensorLayout& grad) override {
return 0;
}
};
diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp
index 42e72aa3..26758156 100644
--- a/dnn/src/naive/remap/opr_impl.cpp
+++ b/dnn/src/naive/remap/opr_impl.cpp
@@ -12,11 +12,13 @@
#include "src/naive/remap/opr_impl.h"
#include "src/common/cv/helper.h"
+#include "src/common/rounding_converter.cuh"
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
+using namespace rounding;
namespace {
template <param::Remap::Format format>
@@ -36,35 +38,46 @@ inline int get_offset(int height, int width,
return height * w * c + width * c + channel;
}
-template <typename DataType, param::Remap::Format format,
-          param::Remap::BorderMode bordertype>
+template <typename ctype, param::Remap::Format format,
+          param::Remap::BorderMode bordertype>
struct GetSrcData {
- static inline DataType get(const DataType* src, int height, int width,
- int channel, int h, int w, int c, float,
-                               std::function<DataType(float)>) {
+ static inline ctype get(const ctype* src, int height, int width,
+ int channel, int h, int w, int c, float) {
height = megcv::border_interpolate(height, h);
width = megcv::border_interpolate(width, w);
        return src[get_offset<format>(height, width, channel, h, w, c)];
}
+ static inline int get_index(int height, int width, int channel, int h,
+ int w, int c) {
+ height = megcv::border_interpolate(height, h);
+ width = megcv::border_interpolate(width, w);
+        return get_offset<format>(height, width, channel, h, w, c);
+ }
};
-template <typename DataType, param::Remap::Format format>
-struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> {
- static inline DataType get(const DataType* src, int height, int width,
- int channel, int h, int w, int c, float scalar,
-                               std::function<DataType(float)> round) {
+template <typename ctype, param::Remap::Format format>
+struct GetSrcData<ctype, format, param::Remap::BorderMode::CONSTANT> {
+ static inline ctype get(const ctype* src, int height, int width,
+ int channel, int h, int w, int c, float scalar) {
+        RoundingConverter<ctype> round;
return (height >= 0 && height < h && width >= 0 && width < w)
                       ? src[get_offset<format>(height, width, channel, h, w,
                                                c)]
-                       : static_cast<DataType>(round(scalar));
+ : round(scalar);
+ }
+ static inline int get_index(int height, int width, int channel, int h,
+ int w, int c) {
+ return (height >= 0 && height < h && width >= 0 && width < w)
+                       ? get_offset<format>(height, width, channel, h, w, c)
+ : -1;
}
};
-template <typename DataType, param::Remap::Format format,
-          param::Remap::BorderMode bordertype>
-void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
- int N, int C, int IH, int IW, int OH, int OW, float scalar,
-                  std::function<DataType(float)> round) {
+template <typename ctype, param::Remap::Format format,
+          param::Remap::BorderMode bordertype>
+void remap_LINEAR(const ctype* src, const float* map_xy, ctype* dst, int N,
+ int C, int IH, int IW, int OH, int OW, float scalar) {
+    RoundingConverter<ctype> round_converter;
for (int n = 0; n < N;
++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) {
for (int h = 0; h < OH; ++h) {
@@ -73,47 +86,131 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
float index_row = map_xy[h * OW * 2 + w * 2 + 1];
                int col = static_cast<int>(floor(index_col));
                int row = static_cast<int>(floor(index_row));
- float v = index_col - col;
- float u = index_row - row;
- float one = 1.f;
+ float v = index_col - col; // alphaw
+ float u = index_row - row; // alphah
+ const float one = 1.f;
for (int c = 0; c < C; ++c) {
-                DataType a00 =
-                        GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 0, col + 0, c, IH, IW, C, scalar,
-                                round);
-                DataType a01 =
-                        GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 0, col + 1, c, IH, IW, C, scalar,
-                                round);
-                DataType a10 =
-                        GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 1, col + 0, c, IH, IW, C, scalar,
-                                round);
-                DataType a11 =
-                        GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 1, col + 1, c, IH, IW, C, scalar,
-                                round);
+                ctype a00 = GetSrcData<ctype, format, bordertype>::get(
+                        src, row + 0, col + 0, c, IH, IW, C, scalar);
+                ctype a01 = GetSrcData<ctype, format, bordertype>::get(
+                        src, row + 0, col + 1, c, IH, IW, C, scalar);
+                ctype a10 = GetSrcData<ctype, format, bordertype>::get(
+                        src, row + 1, col + 0, c, IH, IW, C, scalar);
+                ctype a11 = GetSrcData<ctype, format, bordertype>::get(
+                        src, row + 1, col + 1, c, IH, IW, C, scalar);
                dst[get_offset<format>(h, w, c, OH, OW, C)] =
-                        static_cast<DataType>(
- round(a00 * (one - u) * (one - v) +
- a01 * (one - u) * v +
- a10 * (one - v) * u + a11 * u * v));
+ round_converter(a00 * (one - v) * (one - u) +
+ a01 * (one - u) * v +
+ a10 * (one - v) * u + a11 * u * v);
}
}
}
}
}
-template <typename DataType, DTypeCategory cat>
-struct Round {
- static inline DataType round(float x) { return std::round(x); }
-};
+template <typename ctype, param::Remap::Format format,
+          param::Remap::BorderMode bordertype>
+void remap_LINEAR_backwarddata(ctype* grad, const float* map_xy,
+ const ctype* diff, int N, int C, int IH, int IW,
+ int OH, int OW) {
+    RoundingConverter<ctype> round_converter;
+ std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW);
+ for (int n = 0; n < N;
+ ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) {
+ for (int h = 0; h < OH; ++h) {
+ for (int w = 0; w < OW; ++w) {
+ float index_col = map_xy[h * OW * 2 + w * 2 + 0];
+ float index_row = map_xy[h * OW * 2 + w * 2 + 1];
+                int col = static_cast<int>(floor(index_col));
+                int row = static_cast<int>(floor(index_row));
+ float v = index_col - col; // alphaw
+ float u = index_row - row; // alphah
+ const float one = 1.f;
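+                // scatter `diff` back through the four bilinear taps; taps
+                // with index -1 fell outside the image and receive nothing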
+ for (int c = 0; c < C; ++c) {
+                    ctype hidden =
+                            diff[get_offset<format>(h, w, c, OH, OW, C)];
-template <typename DataType>
-struct Round<DataType, DTypeCategory::FLOAT> {
-    static inline DataType round(float x) { return static_cast<DataType>(x); }
-};
+                    int a00 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 0, col + 0, c, IH, IW, C);
+ if (a00 != -1) {
+ grad[a00] +=
+ round_converter((one - v) * (one - u) * hidden);
+ }
+
+                    int a01 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 0, col + 1, c, IH, IW, C);
+ if (a01 != -1) {
+ grad[a01] += round_converter((one - u) * v * hidden);
+ }
+
+                    int a10 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 1, col + 0, c, IH, IW, C);
+ if (a10 != -1) {
+ grad[a10] += round_converter(u * (one - v) * hidden);
+ }
+
+                    int a11 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 1, col + 1, c, IH, IW, C);
+ if (a11 != -1) {
+ grad[a11] += round_converter(v * u * hidden);
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename ctype, param::Remap::Format format,
+          param::Remap::BorderMode bordertype>
+void remap_LINEAR_backwardmat(const ctype* src, const float* map_xy,
+ const ctype* diff, float* grad, int N, int C,
+ int IH, int IW, int OH, int OW, float scalar) {
+    RoundingConverter<float> round_converter;
+ std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW);
+ for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW,
+ map_xy += OH * OW * 2, grad += OH * OW * 2) {
+ for (int h = 0; h < OH; ++h) {
+ for (int w = 0; w < OW; ++w) {
+ float index_col = map_xy[h * OW * 2 + w * 2 + 0];
+ float index_row = map_xy[h * OW * 2 + w * 2 + 1];
+                int col = static_cast<int>(floor(index_col));
+                int row = static_cast<int>(floor(index_row));
+ float v = index_col - col; // alphaw
+ float u = index_row - row; // alphah
+ const float one = 1.f;
+ for (int c = 0; c < C; ++c) {
+                    float hidden = static_cast<float>(
+                            diff[get_offset<format>(h, w, c, OH, OW, C)]);
+ float du = 0.f, dv = 0.f;
+
+                    int a00 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 0, col + 0, c, IH, IW, C);
+                    int a01 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 0, col + 1, c, IH, IW, C);
+                    int a10 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 1, col + 0, c, IH, IW, C);
+                    int a11 = GetSrcData<ctype, format, bordertype>::
+                            get_index(row + 1, col + 1, c, IH, IW, C);
+
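+                    // dv, du: derivatives of the bilinear blend w.r.t. the
+                    // x (col) and y (row) map coordinates; out-of-range taps
+                    // (index -1) read the constant `scalar`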
+ dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u);
+ dv += ((a01 != -1) ? src[a01] : scalar) * (one - u);
+ dv -= ((a10 != -1) ? src[a10] : scalar) * u;
+ dv += ((a11 != -1) ? src[a11] : scalar) * u;
+
+ du -= ((a00 != -1) ? src[a00] : scalar) * (one - v);
+ du -= ((a01 != -1) ? src[a01] : scalar) * v;
+ du += ((a10 != -1) ? src[a10] : scalar) * (one - v);
+ du += ((a11 != -1) ? src[a11] : scalar) * v;
+
+ grad[h * OW * 2 + w * 2 + 0] +=
+ round_converter(hidden * dv);
+ grad[h * OW * 2 + w * 2 + 1] +=
+ round_converter(hidden * du);
+ }
+ }
+ }
+ }
+}
} // namespace
@@ -148,8 +245,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
                        src.compatible_ptr<ctype>(), \
                        map_xy.compatible_ptr<dt_float32>(), \
                        dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \
- param().scalar, \
-                        Round<ctype, DTypeTrait<dt>::category>::round))); \
+ param().scalar))); \
break; \
}
@@ -172,6 +268,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
support_dtype(dtype::Float32);
MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16));
+ MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
support_dtype(dtype::Int8);
support_dtype(dtype::Uint8);
#undef cb
@@ -181,3 +278,109 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
megdnn_throw("unsupported dtype in remap naive\n");
}
}
+
+void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff,
+ _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) {
+ check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size);
+ megdnn_assert(param().format == param::Remap::Format::NCHW,
+ "only support NCHW format for remap backward");
+ int N, C, IH, IW, OH, OW;
+ N = grad.layout.shape[0];
+ C = grad.layout.shape[1];
+ IH = grad.layout.shape[2];
+ IW = grad.layout.shape[3];
+ OH = map_xy.layout.shape[1];
+ OW = map_xy.layout.shape[2];
+ switch (diff.layout.dtype.enumv()) {
+#define cb(dt, fmt, border, interpolation) \
+ if (param().format == param::Remap::Format::fmt && \
+ param().border_type == param::Remap::BorderMode::border && \
+ param().imode == param::Remap::InterpolationMode::interpolation) { \
+        using ctype = DTypeTrait<dt>::ctype; \
+ MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwarddata< \
+ ctype, param::Remap::Format::fmt, \
+ param::Remap::BorderMode::border>( \
+                grad.compatible_ptr<ctype>(), \
+                map_xy.compatible_ptr<dt_float32>(), \
+                diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW))); \
+ break; \
+ }
+
+#define support_dtype(dt) \
+    case DTypeTrait<dt>::enumv: { \
+ cb(dt, NCHW, CONSTANT, LINEAR); \
+ cb(dt, NCHW, REPLICATE, LINEAR); \
+ cb(dt, NCHW, REFLECT, LINEAR); \
+ cb(dt, NCHW, REFLECT_101, LINEAR); \
+ cb(dt, NCHW, WRAP, LINEAR); \
+ megdnn_throw( \
+ "format, border type or imode is incorrect in remap navie " \
+ "with dtype = " #dt); \
+ }
+
+ support_dtype(dtype::Float32);
+ MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+#undef cb
+#undef support_dtype
+
+ default:
+ megdnn_throw("unsupported dtype in remap backward naive\n");
+ }
+}
+
+void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) {
+ check_exec(src.layout, map_xy.layout, diff.layout, grad.layout,
+ workspace.size);
+ megdnn_assert(param().format == param::Remap::Format::NCHW,
+ "only support NCHW format for remap backward");
+ int N, C, IH, IW, OH, OW;
+ N = src.layout.shape[0];
+ C = src.layout.shape[1];
+ IH = src.layout.shape[2];
+ IW = src.layout.shape[3];
+ OH = map_xy.layout.shape[1];
+ OW = map_xy.layout.shape[2];
+ switch (src.layout.dtype.enumv()) {
+#define cb(dt, fmt, border, interpolation) \
+ if (param().format == param::Remap::Format::fmt && \
+ param().border_type == param::Remap::BorderMode::border && \
+ param().imode == param::Remap::InterpolationMode::interpolation) { \
+        using ctype = DTypeTrait<dt>::ctype; \
+ MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwardmat< \
+ ctype, param::Remap::Format::fmt, \
+ param::Remap::BorderMode::border>( \
+                src.compatible_ptr<ctype>(), \
+                map_xy.compatible_ptr<dt_float32>(), \
+                diff.compatible_ptr<ctype>(), \
+                grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \
+ param().scalar))); \
+ break; \
+ }
+
+#define support_dtype(dt) \
+    case DTypeTrait<dt>::enumv: { \
+ cb(dt, NCHW, CONSTANT, LINEAR); \
+ cb(dt, NCHW, REPLICATE, LINEAR); \
+ cb(dt, NCHW, REFLECT, LINEAR); \
+ cb(dt, NCHW, REFLECT_101, LINEAR); \
+ cb(dt, NCHW, WRAP, LINEAR); \
+ megdnn_throw( \
+ "format, border type or imode is incorrect in remap navie " \
+ "with dtype = " #dt); \
+ }
+
+ support_dtype(dtype::Float32);
+ MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
+#undef cb
+#undef support_dtype
+
+ default:
+ megdnn_throw("unsupported dtype in remap backward naive\n");
+ }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/naive/remap/opr_impl.h b/dnn/src/naive/remap/opr_impl.h
index 5423d3cf..a5f20d4c 100644
--- a/dnn/src/naive/remap/opr_impl.h
+++ b/dnn/src/naive/remap/opr_impl.h
@@ -23,6 +23,33 @@ class RemapImpl final : public Remap {
return 0;
}
};
+
+class RemapBackwardDataImpl final : public RemapBackwardData {
+public:
+ using RemapBackwardData::RemapBackwardData;
+ void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
+ _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
+ size_t get_workspace_in_bytes(const TensorLayout&,
+ const TensorLayout&,
+ const TensorLayout&) override {
+ return 0;
+ }
+};
+
+class RemapBackwardMatImpl final : public RemapBackwardMat {
+public:
+ using RemapBackwardMat::RemapBackwardMat;
+ void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+ _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+ _megdnn_workspace workspace) override;
+ size_t get_workspace_in_bytes(const TensorLayout&,
+ const TensorLayout&,
+ const TensorLayout&,
+ const TensorLayout&) override {
+ return 0;
+ }
+};
+
} // namespace naive
} // namespace megdnn
diff --git a/dnn/test/common/opr_trait.h b/dnn/test/common/opr_trait.h
index 7ff1c6c9..22f0f349 100644
--- a/dnn/test/common/opr_trait.h
+++ b/dnn/test/common/opr_trait.h
@@ -106,6 +106,8 @@ DEF(DeformablePSROIPoolingForward, 5, true, true);
DEF(DeformablePSROIPoolingBackward, 7, true, false);
DEF(BatchConvBiasForward, 5, true, true);
DEF(Remap, 3, true, true);
+DEF(RemapBackwardData, 3, true, false);
+DEF(RemapBackwardMat, 4, true, false);
} // namespace test
} // namespace megdnn
diff --git a/dnn/test/common/remap.h b/dnn/test/common/remap.h
index 9267364f..f2c22fcd 100644
--- a/dnn/test/common/remap.h
+++ b/dnn/test/common/remap.h
@@ -46,6 +46,9 @@ static inline std::vector<TestArg> get_nchw_args() {
for (auto border_type : border_mode_vec) {
param.format = fmt;
param.border_type = border_type;
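+            // batch > 65535 exercises the chunked CUDA kernel-launch path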
+ args.emplace_back(param, TensorShape{70000, 1, 2, 2},
+ TensorShape{70000, 2, 2, 2}, TensorShape{70000, 1, 2, 2});
+
args.emplace_back(param, TensorShape{1, 1, 2, 2},
TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2});
@@ -90,6 +93,9 @@ static inline std::vector<TestArg> get_nhwc_args() {
param.format = fmt;
param.border_type = border_type;
param.scalar = 12.f;
+ args.emplace_back(param, TensorShape{70000, 2, 2, 1},
+ TensorShape{70000, 2, 2, 2}, TensorShape{70000, 2, 2, 1});
+
args.emplace_back(param, TensorShape{1, 2, 2, 1},
TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1});
diff --git a/dnn/test/cuda/remap.cpp b/dnn/test/cuda/remap.cpp
index dcbc1104..1bb8f2f7 100644
--- a/dnn/test/cuda/remap.cpp
+++ b/dnn/test/cuda/remap.cpp
@@ -40,6 +40,22 @@ TEST_F(CUDA, REMAP_NCHW_FLOAT) {
cb(dtype::Float32(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(0, data_type) \
+ .set_dtype(1, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_rng(0, &data_rng) \
+ .set_rng(1, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_param(arg.param) \
+ .set_epsilon(1e-2) \
+ .execs({arg.src, arg.map_xy, arg.dst}); \
+ }
+ cb(dtype::BFloat16(), float_rng);
+#undef cb
}
TEST_F(CUDA, REMAP_NCHW_INT) {
@@ -87,6 +103,22 @@ TEST_F(CUDA, REMAP_NHWC_FLOAT) {
cb(dtype::Float32(), float_rng);
cb(dtype::Float16(), float_rng);
#undef cb
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(0, data_type) \
+ .set_dtype(1, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_rng(0, &data_rng) \
+ .set_rng(1, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_param(arg.param) \
+ .set_epsilon(1e-2) \
+ .execs({arg.src, arg.map_xy, arg.dst}); \
+ }
+ cb(dtype::BFloat16(), float_rng);
+#undef cb
}
TEST_F(CUDA, REMAP_NHWC_INT) {
@@ -114,6 +146,85 @@ TEST_F(CUDA, REMAP_NHWC_INT) {
#undef cb
}
+TEST_F(CUDA, REMAP_BACKWARD_DATA) {
+    Checker<RemapBackwardData> checker(handle_cuda());
+    std::vector<TestArg> args = get_nchw_args();
+ UniformFloatRNG float_rng(0, 255);
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(1, data_type) \
+ .set_dtype(0, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_rng(1, &data_rng) \
+ .set_rng(0, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_param(arg.param) \
+ .execs({arg.map_xy, arg.dst, arg.src}); \
+ }
+ cb(dtype::Float32(), float_rng);
+#undef cb
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(1, data_type) \
+ .set_dtype(0, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_rng(1, &data_rng) \
+ .set_rng(0, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_param(arg.param) \
+ .set_epsilon(1e-1) \
+ .execs({arg.map_xy, arg.dst, arg.src}); \
+ }
+ cb(dtype::BFloat16(), float_rng);
+#undef cb
+}
+
+TEST_F(CUDA, REMAP_BACKWARD_MAT) {
+    Checker<RemapBackwardMat> checker(handle_cuda());
+    std::vector<TestArg> args = get_nchw_args();
+ UniformFloatRNG float_rng(0, 255);
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(0, data_type) \
+ .set_dtype(1, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_dtype(3, dtype::Float32()) \
+ .set_rng(0, &data_rng) \
+ .set_rng(1, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_rng(3, &map_rng) \
+ .set_param(arg.param) \
+ .set_epsilon(2e-2) \
+ .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
+ }
+ cb(dtype::Float32(), float_rng);
+#undef cb
+#define cb(data_type, data_rng) \
+ for (auto arg : args) { \
+ UniformFloatRNG map_rng( \
+ -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+ checker.set_dtype(0, data_type) \
+ .set_dtype(1, dtype::Float32()) \
+ .set_dtype(2, data_type) \
+ .set_dtype(3, dtype::Float32()) \
+ .set_rng(0, &data_rng) \
+ .set_rng(1, &map_rng) \
+ .set_rng(2, &data_rng) \
+ .set_rng(3, &map_rng) \
+ .set_param(arg.param) \
+ .set_epsilon(1e-1) \
+ .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \
+ }
+ cb(dtype::BFloat16(), float_rng);
+#undef cb
+}
+
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, BENCHMARK_REMAP) {
@@ -144,13 +255,31 @@ TEST_F(CUDA, BENCHMARK_REMAP) {
.execs(shapes);
auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs(
shapes);
+
+ int size = 0;
+ if (dtype == dtype::Float32{}) {
+ size = sizeof(float);
+ printf("float32: ");
+ } else if (dtype == dtype::Float16{}) {
+ size = sizeof(dt_float16);
+ printf("float16: ");
+ } else if (dtype == dtype::Int8{}) {
+ size = sizeof(dt_int8);
+ printf("int8: ");
+ } else if (dtype == dtype::Uint8{}) {
+ size = sizeof(dt_uint8);
+ printf("uint8: ");
+ }
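+        // approximate memory traffic: four source reads and one write per
+        // output element, plus the map_xy reads, reported in GiB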
+ const TensorShape map_xy = shapes[1];
const TensorShape dst_layout = shapes[2];
- float calc_amount = dst_layout.total_nr_elems();
- printf("naive={%.3fms, %.3fMflops}, "
- "cuda={%.3fms, %.3fMflops}\n",
- t1 / RUN, calc_amount / (t1 / RUN * 1000), t2,
- calc_amount / (t2 * 1000));
+ float calc_amount = (dst_layout.total_nr_elems() * (4.f + 1.f) * size +
+ map_xy.total_nr_elems() * sizeof(float)) /
+ (1024 * 1024 * 1024);
+ printf("naive={%.3fms, %.3fGBPS}, "
+ "cuda={%.3fms, %.3fGBPS}\n",
+ t1 / RUN, calc_amount / (t1 / RUN) * 1e3, t2,
+ calc_amount / t2 * 1e3);
};
Param param;
param.imode = param::Remap::InterpolationMode::LINEAR;
diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py
index a2d25ac4..6220c599 100644
--- a/python_module/megengine/functional/__init__.py
+++ b/python_module/megengine/functional/__init__.py
@@ -84,6 +84,7 @@ from .nn import (
max_pool2d,
one_hot,
prelu,
+ remap,
roi_align,
roi_pooling,
softmax,
diff --git a/python_module/megengine/functional/nn.py b/python_module/megengine/functional/nn.py
index e579bea2..44438d83 100644
--- a/python_module/megengine/functional/nn.py
+++ b/python_module/megengine/functional/nn.py
@@ -706,6 +706,61 @@ def warp_perspective(
@wrap_io_tensor
+def remap(
+ inp: Tensor,
+ map_xy: Tensor,
+ border_mode: str = "REPLICATE",
+ scalar: float = 0.0,
+ interp_mode: str = "LINEAR",
+) -> Tensor:
+ r"""
+ Applies remap transformation to batched 2D images.
+
+    The input images are transformed to the output images according to the
+    coordinate tensor map_xy. The output's H and W are the same as map_xy's.
+
+ :param inp: input image
+    :param map_xy: (batch, oh, ow, 2) coordinate map; the last dimension
+        holds the (x, y) sampling position for each output pixel
+ :param border_mode: pixel extrapolation method. Default: ``"REPLICATE"``
+ :param scalar: value used in case of a constant border. Default: ``0``
+ :param interp_mode: interpolation methods. Default: ``"LINEAR"``
+
+ Examples:
+
+ .. testcode::
+
+ import numpy as np
+ from megengine import tensor
+ import megengine.functional as F
+ inp_shape = (1, 1, 4, 4)
+ inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape))
+ map_xy_shape = (1, 2, 2, 2)
+ map_xy = tensor(np.array([[[1., 0.],[0., 1.]],
+ [[0., 1.],[0., 1.]]],
+ dtype=np.float32).reshape(map_xy_shape))
+ out = F.remap(inp, map_xy)
+ print(out.numpy())
+
+ Outputs:
+
+ .. testoutput::
+
+ [[[[1. 4.]
+ [4. 4.]]]]
+
+ """
+
+ return mgb.opr.remap(
+ inp,
+ map_xy,
+ border_type=border_mode,
+ scalar=scalar,
+ imode=interp_mode,
+ format="NCHW",
+ )
+
+
+@wrap_io_tensor
def eye(
n: int,
m: Optional[int] = None,
diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp
index 33fb1e23..ed6c3dd5 100644
--- a/src/opr/impl/imgproc.cpp
+++ b/src/opr/impl/imgproc.cpp
@@ -443,4 +443,29 @@ void RemapForward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
}
+#ifdef MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(RemapForward) {
+ mgb_assert(opr.input().size() == 2);
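+    // wrt_idx 0: gradient w.r.t. the input image (RemapBackwardData);
+    // wrt_idx 1: gradient w.r.t. the sampling map (RemapBackwardMat)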
+ if (wrt_idx == 0) {
+ SymbolVar grad =
+ RemapBackwardData::make(opr.input(1), out_grad[0],
+ opr.input(0), opr.param());
+ return grad.node();
+ } else if (wrt_idx == 1) {
+ SymbolVar grad =
+ RemapBackwardMat::make(opr.input(0), opr.input(1),
+ out_grad[0], opr.param());
+ return grad.node();
+ } else
+ return InvalidGrad::make(opr, wrt_idx);
+}
+#endif
+
+/* ====================== RemapBackward ====================== */
+
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData);
+MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false);
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat);
+MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true);
+
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/opr/impl/imgproc.sereg.h b/src/opr/impl/imgproc.sereg.h
index 949a9740..a8bfd2ca 100644
--- a/src/opr/impl/imgproc.sereg.h
+++ b/src/opr/impl/imgproc.sereg.h
@@ -97,6 +97,8 @@ namespace opr {
MGB_SEREG_OPR(ResizeBackward, 2);
MGB_SEREG_OPR(Remap, 2);
+ MGB_SEREG_OPR(RemapBackwardData, 3);
+ MGB_SEREG_OPR(RemapBackwardMat, 3);
//! current warp affine version
using WarpAffineV1 = opr::WarpAffine;
diff --git a/src/opr/include/megbrain/opr/imgproc.h b/src/opr/include/megbrain/opr/imgproc.h
index 611d7bc1..36e85372 100644
--- a/src/opr/include/megbrain/opr/imgproc.h
+++ b/src/opr/include/megbrain/opr/imgproc.h
@@ -74,7 +74,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
-}; // namespace opr
+};
using WarpPerspective = WarpPerspectiveForward;
MGB_DEFINE_OPR_CLASS(
@@ -98,7 +98,7 @@ static SymbolVar make(SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff,
const OperatorNodeConfig& config = {});
void scn_do_execute() override;
-}; // namespace mgb
+};
MGB_DEFINE_OPR_CLASS(
WarpPerspectiveBackwardMat,
@@ -119,8 +119,7 @@ static SymbolVar make(SymbolVar src, SymbolVar mat, SymbolVar mat_idx,
const OperatorNodeConfig& config = {});
void scn_do_execute() override;
-}
-;
+};
/* ============================= shape infer ============================== */
//! param: src, dst
@@ -164,8 +163,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
-}
-;
+};
using Resize = ResizeForward;
MGB_DEFINE_OPR_CLASS(ResizeBackward,
@@ -177,8 +175,7 @@ ResizeBackward(VarNode* out_diff, VarNode* in_for_shape, const Param& param,
static SymbolVar make(SymbolVar out_diff, SymbolVar in_for_shape,
const Param& param = {},
const OperatorNodeConfig& config = {});
-}
-;
+};
MGB_DEFINE_OPR_CLASS(RemapForward,
                     intl::MegDNNOprWrapperFwd<megdnn::Remap>) // {
@@ -192,10 +189,31 @@ static SymbolVar make(SymbolVar in_tensor, SymbolVar map,
private:
void init_output_dtype() override;
-}
-;
+};
using Remap = RemapForward;
+MGB_DEFINE_OPR_CLASS(RemapBackwardData,
+                     intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardData>) // {
+public:
+    RemapBackwardData(VarNode* map, VarNode* out_diff, VarNode* in_for_shape,
+                      const Param& param, const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar map, SymbolVar out_diff,
+                          SymbolVar in_for_shape, const Param& param = {},
+                          const OperatorNodeConfig& config = {});
+};
+
+MGB_DEFINE_OPR_CLASS(RemapBackwardMat,
+                     intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardMat>) // {
+public:
+    RemapBackwardMat(VarNode* src, VarNode* map, VarNode* out_diff,
+                     const Param& param, const OperatorNodeConfig& config);
+
+    static SymbolVar make(SymbolVar src, SymbolVar map, SymbolVar out_diff,
+                          const Param& param = {},
+                          const OperatorNodeConfig& config = {});
+};
+
/*!
* \brief apply affine transformation to batched 2D images
*
@@ -238,8 +256,7 @@ size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void record_execute_deps(ExecDependencyArray& deps) override;
-}
-;
+};
using WarpAffine = WarpAffineForward;
} // opr
diff --git a/src/opr/test/imgproc.cpp b/src/opr/test/imgproc.cpp
index 68d9e835..ba4824e5 100644
--- a/src/opr/test/imgproc.cpp
+++ b/src/opr/test/imgproc.cpp
@@ -640,11 +640,11 @@ TEST(TestOprImgproc, WarpAffineForward) {
}
TEST(TestOprImgproc, Remap_NCHW) {
- constexpr size_t N = 2, C = 8;
+ constexpr size_t N = 2, C = 8, OH = 10, OW = 10;
opr::Remap::Param param;
using Checker = AutoOprChecker<2, 1>;
- TensorShape out_shp{N, C, 10, 10};
+ TensorShape out_shp{N, C, OH, OW};
param.format = opr::Remap::Param::Format::NCHW;
auto make_graph = [&](const Checker::SymInpArray &inputs) ->
Checker::SymOutArray {
@@ -657,12 +657,34 @@ TEST(TestOprImgproc, Remap_NCHW) {
opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {});
};
+ std::mt19937 rng(next_rand_seed());
+ auto rand_real = [&](double lo, double hi) {
+ auto real = rng() / (std::mt19937::max() + 1.0) * (hi - lo) + lo;
+ if(std::abs(std::round(real) - real) <= 1e-2)
+ return real + 1e-1;
+ return real;
+ };
+ auto rand_real2 = [&](double range) {
+ return rand_real(-range, range);
+ };
+ auto gen_mat = [&](HostTensorND& mat) {
+        auto ptr = mat.ptr<float>();
+        for (size_t i = 0; i < N; ++i) {
+            for (size_t j = 0; j < OH * OW * 2; j++) {
+                //! the map gradient is undefined where the sampling
+                //! coordinate is exactly an integer, so avoid such points
+                ptr[j] = static_cast<float>(rand_real2(20));
+            }
+            ptr += OH * OW * 2;
+        }
+        mgb_assert(ptr == mat.ptr<float>() + mat.shape().total_nr_elems());
+ };
+
Checker::RunOptions opt;
Checker(make_graph, fwd, CompNode::load("cpu1"))
- .disable_grad_check()
- .run({TensorShape{N, C, 3, 20}, TensorShape{N, 10, 10, 2}}, opt)
- .run({TensorShape{N, C, 6, 5}, TensorShape{N, 10, 10, 2}}, opt)
- .run({TensorShape{N, C, 20, 20}, TensorShape{N, 10, 10, 2}}, opt);
+ .set_input_generator(1, gen_mat)
+ .run({TensorShape{N, C, 3, 20}, TensorShape{N, OH, OW, 2}}, opt)
+ .run({TensorShape{N, C, 6, 5}, TensorShape{N, OH, OW, 2}}, opt)
+ .run({TensorShape{N, C, 20, 20}, TensorShape{N, OH, OW, 2}}, opt);
}
TEST(TestOprImgproc, Remap_NHWC) {
@@ -690,4 +712,5 @@ TEST(TestOprImgproc, Remap_NHWC) {
.run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt)
.run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt);
}
+
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}