From dedb7a3f141ba1b68969693e62e31dc653897c12 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 23 Jun 2020 21:33:20 +0800
Subject: [PATCH] feat(dnn/cuda): add cuda remap

GitOrigin-RevId: 40a2a2ce24bf59a9d74c53f6f0c769313de94c51
---
 dnn/src/cuda/remap/common.h      |  33 ++++++
 dnn/src/cuda/remap/forward.cpp   |  93 +++++++++++++++
 dnn/src/cuda/remap/forward.cu    | 238 +++++++++++++++++++++++++++++++++
 dnn/src/cuda/remap/opr_impl.cpp  |  26 -----
 dnn/src/cuda/remap/opr_impl.h    |   5 +-
 dnn/src/naive/remap/opr_impl.cpp |  55 ++++-----
 dnn/test/cuda/remap.cpp          | 203 +++++++++++++++++++++++++++++
 7 files changed, 593 insertions(+), 60 deletions(-)
 create mode 100644 dnn/src/cuda/remap/common.h
 create mode 100644 dnn/src/cuda/remap/forward.cpp
 create mode 100644 dnn/src/cuda/remap/forward.cu
 delete mode 100644 dnn/src/cuda/remap/opr_impl.cpp
 create mode 100644 dnn/test/cuda/remap.cpp

diff --git a/dnn/src/cuda/remap/common.h b/dnn/src/cuda/remap/common.h
new file mode 100644
index 00000000..82593c9e
--- /dev/null
+++ b/dnn/src/cuda/remap/common.h
@@ -0,0 +1,33 @@
+/**
+ * \file dnn/src/cuda/remap/common.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#pragma once
+#include <cuda_runtime.h>
+#include "megcore_cdefs.h"
+#include "src/common/cv/enums.h"
+#include "src/common/opr_param_defs_enumv.cuh"
+
+namespace megdnn {
+namespace cuda {
+namespace remap {
+
+// all these kernels use LINEAR interpolation
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
+                   int C, int IH, int IW, int OH, int OW, float scalar,
+                   int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream);
+
+}  // namespace remap
+}  // namespace cuda
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
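Note: the contract of forward_proxy above is the usual OpenCV-style remap: for
every output pixel, map_xy holds a float (x, y) coordinate into the source
image, the four surrounding source pixels are blended bilinearly, and
out-of-range reads follow the border mode (for CONSTANT they return `scalar`).
A minimal host-side sketch of that sampling rule, assuming NHWC layout and the
CONSTANT border mode; the function and variable names here are illustrative
only and not part of the patch:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Reference for one output pixel of a LINEAR remap (NHWC, CONSTANT border):
    // xy = {x, y}; u/v are the fractional parts along y/x, matching the kernels.
    float remap_reference_at(const std::vector<float>& src, const float* xy,
                             int IH, int IW, int C, int c, float scalar) {
        float x = xy[0], y = xy[1];
        int col = static_cast<int>(std::floor(x));
        int row = static_cast<int>(std::floor(y));
        float v = x - col, u = y - row;
        auto at = [&](int r, int cc) {
            if (r < 0 || r >= IH || cc < 0 || cc >= IW)
                return scalar;                     // CONSTANT border
            return src[(r * IW + cc) * C + c];     // NHWC indexing
        };
        return at(row, col) * (1 - u) * (1 - v) + at(row, col + 1) * (1 - u) * v +
               at(row + 1, col) * u * (1 - v) + at(row + 1, col + 1) * u * v;
    }

    int main() {
        // 2x2 single-channel image; sampling at (0.5, 0.5) averages all four.
        std::vector<float> src = {1.f, 2.f, 3.f, 4.f};
        float xy[2] = {0.5f, 0.5f};
        std::printf("%.2f\n", remap_reference_at(src, xy, 2, 2, 1, 0, 0.f));
    }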
diff --git a/dnn/src/cuda/remap/forward.cpp b/dnn/src/cuda/remap/forward.cpp
new file mode 100644
index 00000000..7ccf1ef5
--- /dev/null
+++ b/dnn/src/cuda/remap/forward.cpp
@@ -0,0 +1,93 @@
+/**
+ * \file dnn/src/cuda/remap/forward.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megdnn/config/config.h"
+#include "src/common/opr_param_defs_enumv.cuh"
+#include "src/cuda/remap/common.h"
+#include "src/cuda/remap/opr_impl.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+
+void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
+                     _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
+    auto stream = cuda_stream(this->handle());
+    int N, C, IH, IW, OH, OW;
+    ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
+    OH = map_xy.layout.shape[1];
+    OW = map_xy.layout.shape[2];
+
+    megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
+                  "only support LINEAR interpolation mode");
+
+    if (param().format == param::Remap::Format::NCHW) {
+        N = src.layout.shape[0];
+        C = src.layout.shape[1];
+        IH = src.layout.shape[2];
+        IW = src.layout.shape[3];
+        S_IN = src.layout.stride[0];
+        S_IC = src.layout.stride[1];
+        S_IH = src.layout.stride[2];
+        S_IW = src.layout.stride[3];
+    } else if (param().format == param::Remap::Format::NHWC) {
+        N = src.layout.shape[0];
+        C = src.layout.shape[3];
+        IH = src.layout.shape[1];
+        IW = src.layout.shape[2];
+    } else {
+        megdnn_throw("unsupported format, cuda remap");
+    }
+
+#define cb(dt, _format, bmode)                                               \
+    if (param().format == param::Remap::Format::_format &&                   \
+        param().border_type == param::Remap::BorderMode::bmode) {            \
+        using ctype = DTypeTrait<dt>::ctype;                                 \
+        remap::forward_proxy<ctype, param_enumv::Remap::Format::_format,     \
+                             ::BorderMode::BORDER_##bmode>(                  \
+                src.compatible_ptr<ctype>(), map_xy.compatible_ptr<float>(), \
+                dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW,           \
+                param().scalar, S_IN, S_IC, S_IH, S_IW, stream);             \
+        break;                                                               \
+    }
+
+#define support_dtype(dt)                                      \
+    case DTypeTrait<dt>::enumv: {                              \
+        cb(dt, NCHW, CONSTANT);                                \
+        cb(dt, NCHW, REPLICATE);                               \
+        cb(dt, NCHW, REFLECT);                                 \
+        cb(dt, NCHW, REFLECT_101);                             \
+        cb(dt, NCHW, WRAP);                                    \
+        cb(dt, NHWC, CONSTANT);                                \
+        cb(dt, NHWC, REPLICATE);                               \
+        cb(dt, NHWC, REFLECT);                                 \
+        cb(dt, NHWC, REFLECT_101);                             \
+        cb(dt, NHWC, WRAP);                                    \
+        megdnn_throw("unsupported border type in remap cuda"); \
+    }
+
+    switch (src.layout.dtype.enumv()) {
+        support_dtype(dtype::Float32)
+        MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16))
+        support_dtype(dtype::Int8)
+        support_dtype(dtype::Uint8)
+        default:
+            megdnn_throw("unsupported dtype in remap cuda");
+    }
+
+#undef support_dtype
+#undef cb
+}
+
+// vim: syntax=cpp.doxygen
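Note: exec() above relies on a standard dispatch trick: the runtime dtype,
format, and border-mode values are inspected once, and control is forwarded
into forward_proxy instantiations whose template arguments are compile-time
constants, so the kernels themselves carry no per-pixel mode branching. A
self-contained sketch of that pattern with generic names (this is not megdnn's
API, just the shape of the technique the cb()/support_dtype() macros expand to):

    #include <cstdio>
    #include <stdexcept>

    enum class Border { CONSTANT, REPLICATE };

    // Compile-time specialized worker: the border policy is a template
    // argument, so the branch is resolved at instantiation time.
    template <typename T, Border bmode>
    void kernel_proxy(const T* data, int n) {
        std::printf("running %s kernel on %d elements\n",
                    bmode == Border::CONSTANT ? "CONSTANT" : "REPLICATE", n);
        (void)data;
    }

    // Runtime dispatcher: one switch turns runtime parameters into the
    // matching instantiation, mirroring the macro ladder above.
    template <typename T>
    void dispatch(const T* data, int n, Border bmode) {
        switch (bmode) {
            case Border::CONSTANT:
                kernel_proxy<T, Border::CONSTANT>(data, n);
                return;
            case Border::REPLICATE:
                kernel_proxy<T, Border::REPLICATE>(data, n);
                return;
        }
        throw std::runtime_error("unsupported border mode");
    }

    int main() {
        float buf[4] = {0, 1, 2, 3};
        dispatch(buf, 4, Border::REPLICATE);
    }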
diff --git a/dnn/src/cuda/remap/forward.cu b/dnn/src/cuda/remap/forward.cu
new file mode 100644
index 00000000..fcd6f19f
--- /dev/null
+++ b/dnn/src/cuda/remap/forward.cu
@@ -0,0 +1,238 @@
+/**
+ * \file dnn/src/cuda/remap/forward.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "src/common/rounding_converter.cuh"
+#include "src/cuda/cv/kernel_common.cuh"
+#include "src/cuda/remap/common.h"
+#include "src/cuda/utils.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace remap;
+using namespace rounding;
+
+namespace {
+
+template <typename ctype>
+struct DirectSrcVisitor {
+    const ctype* ptr;
+
+    __device__ __forceinline__ const ctype* get(int batch, int im_size) {
+        return ptr + batch * im_size;
+    }
+
+    void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; }
+};
+
+template <const uint32_t format>
+__device__ inline int get_offset(int height, int width, int channel, int h,
+                                 int w, int c);
+
+template <>
+__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
+        int height, int width, int channel, int h, int w, int c) {
+    return channel * h * w + height * w + width;
+}
+
+template <>
+__device__ inline int get_offset<param_enumv::Remap::Format::NHWC>(
+        int height, int width, int channel, int h, int w, int c) {
+    return height * w * c + width * c + channel;
+}
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+struct GetSrcData {
+    __device__ static inline ctype get(const ctype* src, int height, int width,
+                                       int channel, int h, int w, int c,
+                                       float) {
+        height = megcv::border_interpolate<bmode>(height, h);
+        width = megcv::border_interpolate<bmode>(width, w);
+        return src[get_offset<format>(height, width, channel, h, w, c)];
+    }
+};
+
+template <typename ctype, const uint32_t format>
+struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
+    __device__ static inline ctype get(const ctype* src, int height, int width,
+                                       int channel, int h, int w, int c,
+                                       float scalar) {
+        RoundingConverter<ctype> round_converter;
+        return (height >= 0 && height < h && width >= 0 && width < w)
+                       ? src[get_offset<format>(height, width, channel, h, w,
+                                                c)]
+                       : round_converter(scalar);
+    }
+};
+
+template <typename ctype, typename SrcVisitor, const uint32_t format,
+          ::BorderMode bmode>
+__global__ void kern_general(SrcVisitor src, const float* map_xy,
+                             ctype* __restrict dst, int C, int IH, int IW,
+                             int OH, int OW, int S_IN, int S_IC, int S_IH,
+                             int S_IW, float scalar) {
+    int ow = blockIdx.x * blockDim.x + threadIdx.x;
+    int oh = blockIdx.y * blockDim.y + threadIdx.y;
+    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
+    dst += blockIdx.z * C * OH * OW;
+    map_xy += blockIdx.z * 2 * OH * OW;
+    RoundingConverter<ctype> round_converter;
+
+    if (ow < OW && oh < OH) {
+        float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
+        float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
+        int col = (int)floor(index_col);
+        int row = (int)floor(index_row);
+        float v = index_col - col;
+        float u = index_row - row;
+        for (int c = 0; c < C; ++c) {
+            ctype a00 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 0, col + 0, c, IH, IW, C, scalar);
+            ctype a01 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 0, col + 1, c, IH, IW, C, scalar);
+            ctype a10 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 1, col + 0, c, IH, IW, C, scalar);
+            ctype a11 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 1, col + 1, c, IH, IW, C, scalar);
+            dst[get_offset<format>(oh, ow, c, OH, OW, C)] =
+                    round_converter(a00 * (1.f - u) * (1.f - v) +
+                                    a01 * (1.f - u) * v + a10 * (1.f - v) * u +
+                                    a11 * u * v);
+        }
+    }
+}
+
+template <typename ctype, typename SrcVisitor, const uint32_t format,
+          ::BorderMode bmode>
+__global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
+                                  ctype* __restrict dst, int C, int IH, int IW,
+                                  int OH, int OW, float scalar) {
+    int ow = blockIdx.x * blockDim.x + threadIdx.x;
+    int oh = blockIdx.y * blockDim.y + threadIdx.y;
+    const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
+    dst += blockIdx.z * C * OH * OW;
+    map_xy += blockIdx.z * 2 * OH * OW;
+    RoundingConverter<ctype> round_converter;
+
+    if (ow < OW && oh < OH) {
+        float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
+        float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
+        int col = (int)floor(index_col);
+        int row = (int)floor(index_row);
+        float v = index_col - col;
+        float u = index_row - row;
+        for (int c = 0; c < C; ++c) {
+            ctype a00 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 0, col + 0, c, IH, IW, C, scalar);
+            ctype a01 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 0, col + 1, c, IH, IW, C, scalar);
+            ctype a10 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 1, col + 0, c, IH, IW, C, scalar);
+            ctype a11 = GetSrcData<ctype, format, bmode>::get(
+                    sptr, row + 1, col + 1, c, IH, IW, C, scalar);
+            dst[get_offset<format>(oh, ow, c, OH, OW, C)] =
+                    round_converter(a00 * (1.f - u) * (1.f - v) +
+                                    a01 * (1.f - u) * v + a10 * (1.f - v) * u +
+                                    a11 * u * v);
+        }
+    }
+}
+
+template <typename ctype, typename SrcVisitor, const uint32_t format,
+          ::BorderMode bmode>
+void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
+                           int N, int C, int IH, int IW, int OH, int OW,
+                           float scalar, int S_IN, int S_IC, int S_IH,
+                           int S_IW, cudaStream_t stream) {
+    const int BX = 32, BY = 16;
+
+    const int max_batch_size = 65535;
+    while (N) {
+        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
+        dim3 threads(BX, BY);
+        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);
+
+        if (format == param_enumv::Remap::Format::NCHW) {
+            kern_general<ctype, SrcVisitor, format, bmode>
+                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
+                                                     IW, OH, OW, S_IN, S_IC,
+                                                     S_IH, S_IW, scalar);
+        } else if (format == param_enumv::Remap::Format::NHWC) {
+            kern_general_nhwc<ctype, SrcVisitor, format, bmode>
+                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
+                                                     IW, OH, OW, scalar);
+        }
+
+        N -= curr_batch_size;
+        src.move_batch(curr_batch_size, C * IH * IW);
+        dst += curr_batch_size * C * OH * OW;
+    }
+}
+
+}  // anonymous namespace
+
+namespace megdnn {
+namespace cuda {
+namespace remap {
+
+template <typename ctype, const uint32_t format, ::BorderMode bmode>
+void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
+                   int C, int IH, int IW, int OH, int OW, float scalar,
+                   int S_IN, int S_IC, int S_IH, int S_IW,
+                   cudaStream_t stream) {
+    DirectSrcVisitor<ctype> visitor;
+    visitor.ptr = src;
+    using SrcVisitor = DirectSrcVisitor<ctype>;
+    dispatch_with_visitor<ctype, SrcVisitor, format, bmode>(
+            visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC,
+            S_IH, S_IW, stream);
+    after_kernel_launch();
+}
+
+#define INST(ctype, format, bmode)                                           \
+    template void forward_proxy<ctype, param_enumv::Remap::Format::format,   \
+                                ::BorderMode::bmode>(                        \
+            const ctype* src, const float*, ctype*, int, int, int, int, int, \
+            int, float, int, int, int, int, cudaStream_t);
+
+#define FOR_FORMAT_BMODE(ctype)           \
+    INST(ctype, NCHW, BORDER_CONSTANT)    \
+    INST(ctype, NCHW, BORDER_REPLICATE)   \
+    INST(ctype, NCHW, BORDER_REFLECT)     \
+    INST(ctype, NCHW, BORDER_REFLECT_101) \
+    INST(ctype, NCHW, BORDER_WRAP)        \
+    INST(ctype, NHWC, BORDER_CONSTANT)    \
+    INST(ctype, NHWC, BORDER_REPLICATE)   \
+    INST(ctype, NHWC, BORDER_REFLECT)     \
+    INST(ctype, NHWC, BORDER_REFLECT_101) \
+    INST(ctype, NHWC, BORDER_WRAP)
+
+FOR_FORMAT_BMODE(float)
+MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
+FOR_FORMAT_BMODE(int8_t)
+FOR_FORMAT_BMODE(uint8_t)
+
+#undef FOR_FORMAT_BMODE
+#undef INST
+}  // namespace remap
+}  // namespace cuda
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
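Note: dispatch_with_visitor above chunks the batch because gridDim.z on CUDA
devices is capped at 65535; each loop iteration launches at most that many
images on blockIdx.z, then advances the source visitor and destination pointer.
A host-only sketch of the same chunking logic (illustrative names, no CUDA
needed to run it):

    #include <algorithm>
    #include <cstdio>

    // Mirrors the batching loop in dispatch_with_visitor: process N images in
    // chunks no larger than the gridDim.z limit, advancing offsets per chunk.
    void launch_in_batches(int N, int im_size_src, int im_size_dst) {
        const int max_batch_size = 65535;  // CUDA gridDim.z limit
        long long src_off = 0, dst_off = 0;
        while (N > 0) {
            int curr = std::min(N, max_batch_size);
            std::printf("launch: %d images, src offset %lld, dst offset %lld\n",
                        curr, src_off, dst_off);
            N -= curr;
            src_off += static_cast<long long>(curr) * im_size_src;
            dst_off += static_cast<long long>(curr) * im_size_dst;
        }
    }

    int main() { launch_in_batches(70000, 3 * 32 * 32, 3 * 64 * 64); }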
diff --git a/dnn/src/cuda/remap/opr_impl.cpp b/dnn/src/cuda/remap/opr_impl.cpp
deleted file mode 100644
index 191f4cda..00000000
--- a/dnn/src/cuda/remap/opr_impl.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-/**
- * \file dnn/src/opencl/cuda/opr_impl.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
-
-#include "src/cuda/remap/opr_impl.h"
-#include "megdnn/config/config.h"
-#include "src/common/utils.h"
-
-using namespace megdnn;
-using namespace cuda;
-
-void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
-                     _megdnn_tensor_in dst, _megdnn_workspace workspace) {
-    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
-    megdnn_throw("megdnn currently do not support remap in cuda");
-}
-
-// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/remap/opr_impl.h b/dnn/src/cuda/remap/opr_impl.h
index f4fd4f31..a812e217 100644
--- a/dnn/src/cuda/remap/opr_impl.h
+++ b/dnn/src/cuda/remap/opr_impl.h
@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/opencl/cuda/opr_impl.h
+ * \file dnn/src/cuda/remap/opr_impl.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -16,13 +16,16 @@
 namespace megdnn {
 namespace cuda {
 class RemapImpl final : public Remap {
     using Remap::Remap;
+
     void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out,
               _megdnn_workspace) override;
+
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                   const TensorLayout&) override {
         return 0;
     }
 };
+
 }  // namespace cuda
 }  // namespace megdnn
diff --git a/dnn/src/naive/remap/opr_impl.cpp b/dnn/src/naive/remap/opr_impl.cpp
index 0fe67a41..42e72aa3 100644
--- a/dnn/src/naive/remap/opr_impl.cpp
+++ b/dnn/src/naive/remap/opr_impl.cpp
@@ -37,42 +37,29 @@ inline int get_offset(int height, int width,
 }
 
 template <typename DataType, param::Remap::Format format,
           param::Remap::BorderMode bordertype>
 struct GetSrcData {
-    static inline DataType get(const DataType*, int, int, int, int, int, int,
-                               int);
+    static inline DataType get(const DataType* src, int height, int width,
+                               int channel, int h, int w, int c, float,
+                               std::function<float(float)>) {
+        height = megcv::border_interpolate<bordertype>(height, h);
+        width = megcv::border_interpolate<bordertype>(width, w);
+        return src[get_offset<format>(height, width, channel, h, w, c)];
+    }
 };
 
 template <typename DataType, param::Remap::Format format>
 struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> {
     static inline DataType get(const DataType* src, int height, int width,
-                               int channel, int h, int w, int c, float scalar) {
+                               int channel, int h, int w, int c, float scalar,
+                               std::function<float(float)> round) {
         return (height >= 0 && height < h && width >= 0 && width < w)
                        ? src[get_offset<format>(height, width, channel, h, w, c)]
-                       : static_cast<DataType>(std::round(scalar));
+                       : static_cast<DataType>(round(scalar));
     }
 };
 
-#define cb(bmode)                                                              \
-    template <typename DataType, param::Remap::Format format>                  \
-    struct GetSrcData<DataType, format, param::Remap::BorderMode::bmode> {     \
-        static inline DataType get(const DataType* src, int height, int width, \
-                                   int channel, int h, int w, int c, float) {  \
-            height = megcv::border_interpolate<                                \
-                    param::Remap::BorderMode::bmode>(height, h);               \
-            width = megcv::border_interpolate<                                 \
-                    param::Remap::BorderMode::bmode>(width, w);                \
-            return src[get_offset<format>(height, width, channel, h, w, c)];   \
-        }                                                                       \
-    };
-
-cb(REPLICATE);
-cb(REFLECT);
-cb(REFLECT_101);
-cb(WRAP);
-#undef cb
-
 template <typename DataType, param::Remap::Format format, param::Remap::BorderMode bordertype>
 void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
@@ -92,20 +79,20 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
         for (int c = 0; c < C; ++c) {
             DataType a00 = GetSrcData<DataType, format, bordertype>::get(
-                    src, row + 0, col + 0, c, IH, IW, C,
-                    scalar);
+                    src, row + 0, col + 0, c, IH, IW, C, scalar,
+                    round);
             DataType a01 = GetSrcData<DataType, format, bordertype>::get(
-                    src, row + 0, col + 1, c, IH, IW, C,
-                    scalar);
+                    src, row + 0, col + 1, c, IH, IW, C, scalar,
+                    round);
             DataType a10 = GetSrcData<DataType, format, bordertype>::get(
-                    src, row + 1, col + 0, c, IH, IW, C,
-                    scalar);
+                    src, row + 1, col + 0, c, IH, IW, C, scalar,
+                    round);
             DataType a11 = GetSrcData<DataType, format, bordertype>::get(
-                    src, row + 1, col + 1, c, IH, IW, C,
-                    scalar);
+                    src, row + 1, col + 1, c, IH, IW, C, scalar,
+                    round);
             dst[get_offset<format>(h, w, c, OH, OW, C)] =
                     static_cast<DataType>(
@@ -139,11 +126,13 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
         C = src.layout.shape[1];
         IH = src.layout.shape[2];
         IW = src.layout.shape[3];
-    } else {
+    } else if (param().format == param::Remap::Format::NHWC) {
         N = src.layout.shape[0];
         C = src.layout.shape[3];
         IH = src.layout.shape[1];
         IW = src.layout.shape[2];
+    } else {
+        megdnn_throw("unsupported format");
     }
     OH = map_xy.layout.shape[1];
     OW = map_xy.layout.shape[2];
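Note: the naive refactor above folds the four per-mode GetSrcData macro
specializations into one template that defers to
megcv::border_interpolate<bordertype>, which maps an out-of-range index back
into [0, len) according to the border mode. A rough standalone illustration of
two of those rules, REPLICATE and REFLECT_101 (simplified; the real helpers
live under src/common/cv and src/cuda/cv):

    #include <cstdio>

    // REPLICATE clamps to the nearest valid index; REFLECT_101 mirrors around
    // the edge without repeating the border pixel (OpenCV-style
    // "gfedcb|abcdefgh|gfedcba").
    int replicate(int p, int len) { return p < 0 ? 0 : (p >= len ? len - 1 : p); }

    int reflect_101(int p, int len) {
        if (len == 1) return 0;
        while (p < 0 || p >= len) {
            if (p < 0) p = -p;
            if (p >= len) p = 2 * (len - 1) - p;
        }
        return p;
    }

    int main() {
        std::printf("replicate(-2, 5)   = %d\n", replicate(-2, 5));    // 0
        std::printf("reflect_101(-2, 5) = %d\n", reflect_101(-2, 5));  // 2
        std::printf("reflect_101(6, 5)  = %d\n", reflect_101(6, 5));   // 2
    }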
diff --git a/dnn/test/cuda/remap.cpp b/dnn/test/cuda/remap.cpp
new file mode 100644
index 00000000..dcbc1104
--- /dev/null
+++ b/dnn/test/cuda/remap.cpp
@@ -0,0 +1,203 @@
+/**
+ * \file dnn/test/cuda/remap.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "test/common/remap.h"
+#include "test/common/benchmarker.h"
+#include "test/common/checker.h"
+#include "test/common/rng.h"
+#include "test/cuda/benchmark.h"
+#include "test/cuda/fixture.h"
+
+namespace megdnn {
+namespace test {
+namespace remap {
+
+TEST_F(CUDA, REMAP_NCHW_FLOAT) {
+    Checker<Remap> checker(handle_cuda());
+    std::vector<TestArg> args = get_nchw_args();
+    UniformFloatRNG float_rng(0, 255);
+#define cb(data_type, data_rng)                                              \
+    for (auto arg : args) {                                                  \
+        UniformFloatRNG map_rng(                                             \
+                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+        checker.set_dtype(0, data_type)                                      \
+                .set_dtype(1, dtype::Float32())                              \
+                .set_dtype(2, data_type)                                     \
+                .set_rng(0, &data_rng)                                       \
+                .set_rng(1, &map_rng)                                        \
+                .set_rng(2, &data_rng)                                       \
+                .set_param(arg.param)                                        \
+                .execs({arg.src, arg.map_xy, arg.dst});                      \
+    }
+    cb(dtype::Float32(), float_rng);
+    cb(dtype::Float16(), float_rng);
+#undef cb
+}
+
+TEST_F(CUDA, REMAP_NCHW_INT) {
+    Checker<Remap> checker(handle_cuda());
+    std::vector<TestArg> args = get_nchw_args();
+    UniformIntRNG uint8_rng(0, 255);
+    UniformIntRNG int8_rng(-128, 127);
+
+#define cb(data_type, data_rng)                                              \
+    for (auto arg : args) {                                                  \
+        UniformFloatRNG map_rng(                                             \
+                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+        checker.set_dtype(0, data_type)                                      \
+                .set_dtype(1, dtype::Float32())                              \
+                .set_dtype(2, data_type)                                     \
+                .set_rng(0, &data_rng)                                       \
+                .set_rng(1, &map_rng)                                        \
+                .set_rng(2, &data_rng)                                       \
+                .set_epsilon(1)                                              \
+                .set_param(arg.param)                                        \
+                .execs({arg.src, arg.map_xy, arg.dst});                      \
+    }
+    cb(dtype::Int8(), int8_rng);
+    cb(dtype::Uint8(), uint8_rng);
+#undef cb
+}
+
+TEST_F(CUDA, REMAP_NHWC_FLOAT) {
+    Checker<Remap> checker(handle_cuda());
+    std::vector<TestArg> args = get_nhwc_args();
+    UniformFloatRNG float_rng(0, 255);
+#define cb(data_type, data_rng)                                              \
+    for (auto arg : args) {                                                  \
+        UniformFloatRNG map_rng(                                             \
+                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+        checker.set_dtype(0, data_type)                                      \
+                .set_dtype(1, dtype::Float32())                              \
+                .set_dtype(2, data_type)                                     \
+                .set_rng(0, &data_rng)                                       \
+                .set_rng(1, &map_rng)                                        \
+                .set_rng(2, &data_rng)                                       \
+                .set_param(arg.param)                                        \
+                .execs({arg.src, arg.map_xy, arg.dst});                      \
+    }
+    cb(dtype::Float32(), float_rng);
+    cb(dtype::Float16(), float_rng);
+#undef cb
+}
+
+TEST_F(CUDA, REMAP_NHWC_INT) {
+    Checker<Remap> checker(handle_cuda());
+    std::vector<TestArg> args = get_nhwc_args();
+    UniformIntRNG uint8_rng(0, 255);
+    UniformIntRNG int8_rng(-128, 127);
+
+#define cb(data_type, data_rng)                                              \
+    for (auto arg : args) {                                                  \
+        UniformFloatRNG map_rng(                                             \
+                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
+        checker.set_dtype(0, data_type)                                      \
+                .set_dtype(1, dtype::Float32())                              \
+                .set_dtype(2, data_type)                                     \
+                .set_rng(0, &data_rng)                                       \
+                .set_rng(1, &map_rng)                                        \
+                .set_rng(2, &data_rng)                                       \
+                .set_epsilon(1)                                              \
+                .set_param(arg.param)                                        \
+                .execs({arg.src, arg.map_xy, arg.dst});                      \
+    }
+    cb(dtype::Int8(), int8_rng);
+    cb(dtype::Uint8(), uint8_rng);
+#undef cb
+}
+
+#if MEGDNN_WITH_BENCHMARK
+
+TEST_F(CUDA, BENCHMARK_REMAP) {
+    using Param = param::Remap;
+    auto run = [&](const TensorShapeArray& shapes, Param param, DType dtype) {
+        auto handle_cpu = create_cpu_handle(2);
+        Benchmarker<Remap> benchmarker_naive(handle_cpu.get());
+        CUBenchmarker<Remap> benchmarker_cuda(handle_cuda());
+        UniformIntRNG rng(0, 0xff);
+        UniformFloatRNG map_rng(
+                -2, std::max(shapes[1].shape[1], shapes[1].shape[2]) + 2);
+        benchmarker_naive.set_rng(0, &rng);
+        benchmarker_cuda.set_rng(0, &rng);
+        benchmarker_naive.set_rng(1, &map_rng);
+        benchmarker_cuda.set_rng(1, &map_rng);
+        benchmarker_naive.set_rng(2, &rng);
+        benchmarker_cuda.set_rng(2, &rng);
+
+        benchmarker_naive.set_dtype(1, dtype::Float32());
+        benchmarker_cuda.set_dtype(1, dtype::Float32());
+        benchmarker_naive.set_dtype(0, dtype).set_dtype(2, dtype);
+        benchmarker_cuda.set_dtype(0, dtype).set_dtype(2, dtype);
+
+        size_t RUN = 10;
+        auto t1 = benchmarker_naive.set_display(false)
+                          .set_times(RUN)
+                          .set_param(param)
+                          .execs(shapes);
+        auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs(
+                shapes);
+        const TensorShape dst_layout = shapes[2];
+        float calc_amount = dst_layout.total_nr_elems();
+        printf("naive={%.3fms, %.3fMflops}, "
+               "cuda={%.3fms, %.3fMflops}\n",
+               t1 / RUN, calc_amount / (t1 / RUN * 1000), t2,
+               calc_amount / (t2 * 1000));
+    };
+    Param param;
+    param.imode = param::Remap::InterpolationMode::LINEAR;
+    param.format = param::Remap::Format::NHWC;
+    param.border_type = param::Remap::BorderMode::CONSTANT;
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Float32{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Float16{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Uint8{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Int8{});
+    param.border_type = param::Remap::BorderMode::REPLICATE;
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Float32{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Float16{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Uint8{});
+    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
+        dtype::Int8{});
+    param.format = param::Remap::Format::NCHW;
+    param.border_type = param::Remap::BorderMode::CONSTANT;
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Float32{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Float16{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Uint8{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Int8{});
+    param.border_type = param::Remap::BorderMode::REPLICATE;
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Float32{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Float16{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Uint8{});
+    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
+        dtype::Int8{});
+}
+
+#endif
+
+}  // namespace remap
+}  // namespace test
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
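Note: the integer tests use set_epsilon(1), i.e. they tolerate an absolute
difference of 1 between the CUDA result and the naive reference. That is the
kind of slack you would expect when the two backends round half-way values
differently (the CUDA path converts through rounding_converter.cuh, while the
old naive path used std::round); whether that is the exact cause here is an
assumption. A quick standalone demonstration of two reasonable rounding rules
disagreeing by exactly 1:

    #include <cmath>
    #include <cstdio>

    // Round-half-away-from-zero (std::round) vs round-half-to-even
    // (std::nearbyint under the default rounding mode) differ by exactly 1
    // on some .5 inputs.
    int main() {
        for (float v : {126.5f, 127.5f, 2.5f}) {
            std::printf("v=%.1f  round=%g  nearbyint(even)=%g\n", v,
                        std::round(v), std::nearbyint(v));
        }
    }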