feat(dnn/cuda): add cuda remap

GitOrigin-RevId: 40a2a2ce24
Branch: release-0.6
Megvii Engine Team, 5 years ago
commit dedb7a3f14
7 changed files with 593 additions and 60 deletions:

  1. +33   -0    dnn/src/cuda/remap/common.h
  2. +93   -0    dnn/src/cuda/remap/forward.cpp
  3. +238  -0    dnn/src/cuda/remap/forward.cu
  4. +0    -26   dnn/src/cuda/remap/opr_impl.cpp
  5. +4    -1    dnn/src/cuda/remap/opr_impl.h
  6. +22   -33   dnn/src/naive/remap/opr_impl.cpp
  7. +203  -0    dnn/test/cuda/remap.cpp

+33 -0   dnn/src/cuda/remap/common.h

@@ -0,0 +1,33 @@
/**
 * \file dnn/src/cuda/remap/common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once
#include <cuda_runtime_api.h>
#include "megcore_cdefs.h"
#include "src/common/cv/enums.h"
#include "src/common/opr_param_defs_enumv.cuh"

namespace megdnn {
namespace cuda {
namespace remap {

// all these kernels use LINEAR interpolation

template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
                   int C, int IH, int IW, int OH, int OW, float scalar,
                   int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream);

} // namespace remap
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen
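
All of these kernels perform bilinear sampling: with col = floor(map_x), row = floor(map_y) and the fractional parts v = map_x - col, u = map_y - row, each output element is

    dst(oh, ow, c) = a00 * (1 - u) * (1 - v) + a01 * (1 - u) * v
                   + a10 * (1 - v) * u + a11 * u * v

where a00..a11 are the four source neighbours from (row, col) to (row + 1, col + 1), fetched through the border-mode handler defined in forward.cu below.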

+93 -0   dnn/src/cuda/remap/forward.cpp

@@ -0,0 +1,93 @@
/**
 * \file dnn/src/cuda/remap/forward.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megdnn/config/config.h"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/remap/common.h"
#include "src/cuda/remap/opr_impl.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
                     _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
    auto stream = cuda_stream(this->handle());
    int N, C, IH, IW, OH, OW;
    ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
    OH = map_xy.layout.shape[1];
    OW = map_xy.layout.shape[2];

    megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
                  "only support LINEAR InterpolationMode");

    if (param().format == param::Remap::Format::NCHW) {
        N = src.layout.shape[0];
        C = src.layout.shape[1];
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        S_IN = src.layout.stride[0];
        S_IC = src.layout.stride[1];
        S_IH = src.layout.stride[2];
        S_IW = src.layout.stride[3];
    } else if (param().format == param::Remap::Format::NHWC) {
        N = src.layout.shape[0];
        C = src.layout.shape[3];
        IH = src.layout.shape[1];
        IW = src.layout.shape[2];
    } else {
        megdnn_throw("unsupported format in cuda remap");
    }

#define cb(dt, _format, bmode)                                           \
    if (param().format == param::Remap::Format::_format &&               \
        param().border_type == param::Remap::BorderMode::bmode) {        \
        using ctype = DTypeTrait<dt>::ctype;                             \
        remap::forward_proxy<ctype, param_enumv::Remap::Format::_format, \
                             ::BorderMode::BORDER_##bmode>(              \
                src.compatible_ptr<ctype>(),                             \
                map_xy.compatible_ptr<dt_float32>(),                     \
                dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW,       \
                param().scalar, S_IN, S_IC, S_IH, S_IW, stream);         \
        break;                                                           \
    }

#define support_dtype(dt)                                      \
    case DTypeTrait<dt>::enumv: {                              \
        cb(dt, NCHW, CONSTANT);                                \
        cb(dt, NCHW, REPLICATE);                               \
        cb(dt, NCHW, REFLECT);                                 \
        cb(dt, NCHW, REFLECT_101);                             \
        cb(dt, NCHW, WRAP);                                    \
        cb(dt, NHWC, CONSTANT);                                \
        cb(dt, NHWC, REPLICATE);                               \
        cb(dt, NHWC, REFLECT);                                 \
        cb(dt, NHWC, REFLECT_101);                             \
        cb(dt, NHWC, WRAP);                                    \
        megdnn_throw("unsupported border type in remap cuda"); \
    }

    switch (src.layout.dtype.enumv()) {
        support_dtype(dtype::Float32)
        MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16))
        support_dtype(dtype::Int8)
        support_dtype(dtype::Uint8)
        default:
            megdnn_throw("unsupported dtype in remap cuda");
    }

#undef support_dtype
#undef cb
}

// vim: syntax=cpp.doxygen
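
To make the double macro above concrete, here is roughly what one branch expands to inside the switch (a sketch of the expansion for illustration, not code that appears in the commit):

    // Sketch: cb(dtype::Float32, NCHW, CONSTANT) expanded inside the
    // support_dtype(dtype::Float32) case of the switch above.
    case DTypeTrait<dtype::Float32>::enumv: {
        if (param().format == param::Remap::Format::NCHW &&
            param().border_type == param::Remap::BorderMode::CONSTANT) {
            using ctype = DTypeTrait<dtype::Float32>::ctype;
            remap::forward_proxy<ctype, param_enumv::Remap::Format::NCHW,
                                 ::BorderMode::BORDER_CONSTANT>(
                    src.compatible_ptr<ctype>(),
                    map_xy.compatible_ptr<dt_float32>(),
                    dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW,
                    param().scalar, S_IN, S_IC, S_IH, S_IW, stream);
            break;  // leaves the enclosing switch, not a loop
        }
        // ... the nine remaining format/border-mode branches ...
        megdnn_throw("unsupported border type in remap cuda");
    }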

+238 -0  dnn/src/cuda/remap/forward.cu

@@ -0,0 +1,238 @@
/**
 * \file dnn/src/cuda/remap/forward.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include <cuda.h>
#include <cuda_runtime.h>
#include "src/common/rounding_converter.cuh"
#include "src/cuda/cv/kernel_common.cuh"
#include "src/cuda/remap/common.h"
#include "src/cuda/utils.cuh"

using namespace megdnn;
using namespace cuda;
using namespace remap;
using namespace rounding;

namespace {

template <typename ctype>
struct DirectSrcVisitor {
    const ctype* ptr;

    __device__ __forceinline__ const ctype* get(int batch, int im_size) {
        return ptr + batch * im_size;
    }

    void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; }
};

template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
                                 int w, int c);

template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
        int height, int width, int channel, int h, int w, int c) {
    return channel * h * w + height * w + width;
}

template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NHWC>(
        int height, int width, int channel, int h, int w, int c) {
    return height * w * c + width * c + channel;
}

template <typename ctype, const uint32_t format, ::BorderMode bmode>
struct GetSrcData {
    __device__ static inline ctype get(const ctype* src, int height, int width,
                                       int channel, int h, int w, int c,
                                       float) {
        height = megcv::border_interpolate<bmode>(height, h);
        width = megcv::border_interpolate<bmode>(width, w);
        return src[get_offset<format>(height, width, channel, h, w, c)];
    }
};

template <typename ctype, const uint32_t format>
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
    __device__ static inline ctype get(const ctype* src, int height, int width,
                                       int channel, int h, int w, int c,
                                       float scalar) {
        RoundingConverter<ctype> round_converter;
        return (height >= 0 && height < h && width >= 0 && width < w)
                       ? src[get_offset<format>(height, width, channel, h, w,
                                                c)]
                       : round_converter(scalar);
    }
};
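
The non-CONSTANT border modes delegate to megcv::border_interpolate to fold an out-of-range coordinate back into [0, len). Assuming the usual OpenCV-style border semantics (kernel_common.cuh is not shown in this commit, so this is an assumption), the REPLICATE case behaves like this host-side sketch:

    // Hypothetical host-side analogue of border_interpolate for
    // BORDER_REPLICATE: clamp to the nearest valid index.
    inline int replicate_border(int p, int len) {
        return p < 0 ? 0 : (p >= len ? len - 1 : p);
    }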

template <typename ctype, typename SrcVisitor, ::BorderMode bmode>
__global__ void kern_general(SrcVisitor src, const float* map_xy,
                             ctype* __restrict dst, int C, int IH, int IW,
                             int OH, int OW, int S_IN, int S_IC, int S_IH,
                             int S_IW, float scalar) {
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;
    map_xy += blockIdx.z * 2 * OH * OW;
    RoundingConverter<ctype> round_converter;

    if (ow < OW && oh < OH) {
        float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
        float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
        int col = (int)floor(index_col);
        int row = (int)floor(index_row);
        float v = index_col - col;
        float u = index_row - row;
        for (int c = 0; c < C; ++c) {
            ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
                                   bmode>::get(sptr, row + 0, col + 0, c, IH,
                                               IW, C, scalar);
            ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
                                   bmode>::get(sptr, row + 0, col + 1, c, IH,
                                               IW, C, scalar);
            ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
                                   bmode>::get(sptr, row + 1, col + 0, c, IH,
                                               IW, C, scalar);
            ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW,
                                   bmode>::get(sptr, row + 1, col + 1, c, IH,
                                               IW, C, scalar);
            dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW,
                                                             C)] =
                    round_converter(a00 * (1.f - u) * (1.f - v) +
                                    a01 * (1.f - u) * v + a10 * (1.f - v) * u +
                                    a11 * u * v);
        }
    }
}

template <typename ctype, typename SrcVisitor, ::BorderMode bmode>
__global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
                                  ctype* __restrict dst, int C, int IH, int IW,
                                  int OH, int OW, float scalar) {
    int ow = blockIdx.x * blockDim.x + threadIdx.x;
    int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
    dst += blockIdx.z * C * OH * OW;
    map_xy += blockIdx.z * 2 * OH * OW;
    RoundingConverter<ctype> round_converter;

    if (ow < OW && oh < OH) {
        float index_col = map_xy[oh * OW * 2 + ow * 2 + 0];
        float index_row = map_xy[oh * OW * 2 + ow * 2 + 1];
        int col = (int)floor(index_col);
        int row = (int)floor(index_row);
        float v = index_col - col;
        float u = index_row - row;
        for (int c = 0; c < C; ++c) {
            ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC,
                                   bmode>::get(sptr, row + 0, col + 0, c, IH,
                                               IW, C, scalar);
            ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC,
                                   bmode>::get(sptr, row + 0, col + 1, c, IH,
                                               IW, C, scalar);
            ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC,
                                   bmode>::get(sptr, row + 1, col + 0, c, IH,
                                               IW, C, scalar);
            ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC,
                                   bmode>::get(sptr, row + 1, col + 1, c, IH,
                                               IW, C, scalar);
            dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW,
                                                             C)] =
                    round_converter(a00 * (1.f - u) * (1.f - v) +
                                    a01 * (1.f - u) * v + a10 * (1.f - v) * u +
                                    a11 * u * v);
        }
    }
}

template <typename ctype, typename SrcVisitor, const uint32_t format,
          ::BorderMode bmode>
void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
                           int N, int C, int IH, int IW, int OH, int OW,
                           float scalar, int S_IN, int S_IC, int S_IH,
                           int S_IW, cudaStream_t stream) {
    const int BX = 32, BY = 16;

    const int max_batch_size = 65535;
    while (N) {
        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
        dim3 threads(BX, BY);
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);

        if (format == param_enumv::Remap::Format::NCHW) {
            kern_general<ctype, SrcVisitor, bmode>
                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
                                                     IW, OH, OW, S_IN, S_IC,
                                                     S_IH, S_IW, scalar);
        } else if (format == param_enumv::Remap::Format::NHWC) {
            kern_general_nhwc<ctype, SrcVisitor, bmode>
                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
                                                     IW, OH, OW, scalar);
        }

        N -= curr_batch_size;
        src.move_batch(curr_batch_size, C * IH * IW);
        dst += curr_batch_size * C * OH * OW;
    }
}
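
The while loop chunks the batch because gridDim.z is limited to 65535 on CUDA devices; each iteration launches at most that many images and then advances the visitor and the dst pointer. A trivial illustration of the resulting launch count (helper written for this note, not part of the commit):

    // How many kernel launches the loop above performs for a batch of N.
    inline int num_launches(int N) {
        const int max_batch_size = 65535;
        return (N + max_batch_size - 1) / max_batch_size;  // N = 100000 -> 2
    }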

} // anonymous namespace

namespace megdnn {
namespace cuda {
namespace remap {

template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
                   int C, int IH, int IW, int OH, int OW, float scalar,
                   int S_IN, int S_IC, int S_IH, int S_IW,
                   cudaStream_t stream) {
    DirectSrcVisitor<ctype> visitor;
    visitor.ptr = src;
    using SrcVisitor = DirectSrcVisitor<ctype>;
    dispatch_with_visitor<ctype, SrcVisitor, format, bmode>(
            visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC,
            S_IH, S_IW, stream);
    after_kernel_launch();
}

#define INST(ctype, format, bmode)                                           \
    template void forward_proxy<ctype, param_enumv::Remap::Format::format,   \
                                ::BorderMode::bmode>(                         \
            const ctype* src, const float*, ctype*, int, int, int, int, int, \
            int, float, int, int, int, int, cudaStream_t);

#define FOR_FORMAT_BMODE(ctype)           \
    INST(ctype, NCHW, BORDER_CONSTANT)    \
    INST(ctype, NCHW, BORDER_REPLICATE)   \
    INST(ctype, NCHW, BORDER_REFLECT)     \
    INST(ctype, NCHW, BORDER_REFLECT_101) \
    INST(ctype, NCHW, BORDER_WRAP)        \
    INST(ctype, NHWC, BORDER_CONSTANT)    \
    INST(ctype, NHWC, BORDER_REPLICATE)   \
    INST(ctype, NHWC, BORDER_REFLECT)     \
    INST(ctype, NHWC, BORDER_REFLECT_101) \
    INST(ctype, NHWC, BORDER_WRAP)

FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
FOR_FORMAT_BMODE(int8_t)
FOR_FORMAT_BMODE(uint8_t)

#undef FOR_FORMAT_BMODE
#undef INST
} // namespace remap
} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen

+0 -26   dnn/src/cuda/remap/opr_impl.cpp

@@ -1,26 +0,0 @@
/**
 * \file dnn/src/opencl/cuda/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "src/cuda/remap/opr_impl.h"
#include "megdnn/config/config.h"
#include "src/common/utils.h"

using namespace megdnn;
using namespace cuda;

void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
                     _megdnn_tensor_in dst, _megdnn_workspace workspace) {
    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
    megdnn_throw("megdnn currently do not support remap in cuda");
}

// vim: syntax=cpp.doxygen

+4 -1    dnn/src/cuda/remap/opr_impl.h

@@ -1,5 +1,5 @@
 /**
- * \file dnn/src/opencl/cuda/opr_impl.h
+ * \file dnn/src/cuda/remap/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -16,13 +16,16 @@ namespace megdnn {
namespace cuda {
class RemapImpl final : public Remap {
    using Remap::Remap;

    void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out,
              _megdnn_workspace) override;

    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};

} // namespace cuda
} // namespace megdnn
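
For reference, a minimal sketch of driving this operator through a megdnn handle (tensor construction and synchronization elided; handle, src, map_xy, dst and workspace are assumed to exist):

    // Hypothetical caller; shapes follow the convention used by the tests:
    // src (N, C, IH, IW), map_xy (N, OH, OW, 2) in float32, dst (N, C, OH, OW).
    auto remap = handle->create_operator<megdnn::Remap>();
    remap->param().imode = megdnn::param::Remap::InterpolationMode::LINEAR;
    remap->param().format = megdnn::param::Remap::Format::NCHW;
    remap->param().border_type = megdnn::param::Remap::BorderMode::CONSTANT;
    remap->exec(src, map_xy, dst, workspace);  // workspace may be empty: size is 0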



+22 -33  dnn/src/naive/remap/opr_impl.cpp

@@ -37,42 +37,29 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width,
 }
 
 template <typename DataType, param::Remap::Format format,
-          param::Remap::BorderMode bodertype>
+          param::Remap::BorderMode bordertype>
 struct GetSrcData {
-    static inline DataType get(const DataType*, int, int, int, int, int, int,
-                               int);
+    static inline DataType get(const DataType* src, int height, int width,
+                               int channel, int h, int w, int c, float,
+                               std::function<DataType(float)>) {
+        height = megcv::border_interpolate<bordertype>(height, h);
+        width = megcv::border_interpolate<bordertype>(width, w);
+        return src[get_offset<format>(height, width, channel, h, w, c)];
+    }
 };
 
 template <typename DataType, param::Remap::Format format>
 struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> {
     static inline DataType get(const DataType* src, int height, int width,
-                               int channel, int h, int w, int c, float scalar) {
+                               int channel, int h, int w, int c, float scalar,
+                               std::function<DataType(float)> round) {
         return (height >= 0 && height < h && width >= 0 && width < w)
                        ? src[get_offset<format>(height, width, channel, h, w,
                                                 c)]
-                       : static_cast<DataType>(std::round(scalar));
+                       : static_cast<DataType>(round(scalar));
     }
 };
 
-#define cb(bmode)                                                            \
-    template <typename DataType, param::Remap::Format format>                \
-    struct GetSrcData<DataType, format, param::Remap::BorderMode::bmode> {   \
-        static inline DataType get(const DataType* src, int height,          \
-                                   int width, int channel, int h, int w,     \
-                                   int c, float) {                           \
-            height = megcv::border_interpolate<                              \
-                    param::Remap::BorderMode::bmode>(height, h);             \
-            width = megcv::border_interpolate<                               \
-                    param::Remap::BorderMode::bmode>(width, w);              \
-            return src[get_offset<format>(height, width, channel, h, w, c)]; \
-        }                                                                    \
-    };
-
-cb(REPLICATE);
-cb(REFLECT);
-cb(REFLECT_101);
-cb(WRAP);
-#undef cb
-
 template <typename DataType, param::Remap::Format format,
           param::Remap::BorderMode bordertype>
 void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
@@ -92,20 +79,20 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst,
             for (int c = 0; c < C; ++c) {
                 DataType a00 =
                         GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 0, col + 0, c, IH, IW, C,
-                                scalar);
+                                src, row + 0, col + 0, c, IH, IW, C, scalar,
+                                round);
                 DataType a01 =
                         GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 0, col + 1, c, IH, IW, C,
-                                scalar);
+                                src, row + 0, col + 1, c, IH, IW, C, scalar,
+                                round);
                 DataType a10 =
                         GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 1, col + 0, c, IH, IW, C,
-                                scalar);
+                                src, row + 1, col + 0, c, IH, IW, C, scalar,
+                                round);
                 DataType a11 =
                         GetSrcData<DataType, format, bordertype>::get(
-                                src, row + 1, col + 1, c, IH, IW, C,
-                                scalar);
+                                src, row + 1, col + 1, c, IH, IW, C, scalar,
+                                round);
 
                 dst[get_offset<format>(h, w, c, OH, OW, C)] =
                         static_cast<DataType>(
@@ -139,11 +126,13 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
         C = src.layout.shape[1];
         IH = src.layout.shape[2];
         IW = src.layout.shape[3];
-    } else {
+    } else if (param().format == param::Remap::Format::NHWC) {
         N = src.layout.shape[0];
         C = src.layout.shape[3];
         IH = src.layout.shape[1];
         IW = src.layout.shape[2];
+    } else {
+        megdnn_throw("unsupported format");
     }
     OH = map_xy.layout.shape[1];
     OW = map_xy.layout.shape[2];
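
The rounding step is now passed in as a std::function instead of being hard-coded to std::round. The functor actually supplied to remap_LINEAR lies outside this hunk, but it plausibly resembles this sketch (hypothetical, shown only to illustrate the new signature):

    // Round-to-nearest functor matching std::function<DataType(float)>;
    // a float specialization could simply cast instead.
    std::function<DataType(float)> round = [](float v) {
        return static_cast<DataType>(std::round(v));
    };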


+203 -0  dnn/test/cuda/remap.cpp

@@ -0,0 +1,203 @@
/**
 * \file dnn/test/cuda/remap.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "test/common/remap.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/cuda/benchmark.h"
#include "test/cuda/fixture.h"

namespace megdnn {
namespace test {
namespace remap {

TEST_F(CUDA, REMAP_NCHW_FLOAT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nchw_args();
    UniformFloatRNG float_rng(0, 255);
#define cb(data_type, data_rng)                                              \
    for (auto arg : args) {                                                  \
        UniformFloatRNG map_rng(                                             \
                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
        checker.set_dtype(0, data_type)                                      \
                .set_dtype(1, dtype::Float32())                              \
                .set_dtype(2, data_type)                                     \
                .set_rng(0, &data_rng)                                       \
                .set_rng(1, &map_rng)                                        \
                .set_rng(2, &data_rng)                                       \
                .set_param(arg.param)                                        \
                .execs({arg.src, arg.map_xy, arg.dst});                      \
    }
    cb(dtype::Float32(), float_rng);
    cb(dtype::Float16(), float_rng);
#undef cb
}

TEST_F(CUDA, REMAP_NCHW_INT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nchw_args();
    UniformIntRNG uint8_rng(0, 255);
    UniformIntRNG int8_rng(-128, 127);

#define cb(data_type, data_rng)                                              \
    for (auto arg : args) {                                                  \
        UniformFloatRNG map_rng(                                             \
                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
        checker.set_dtype(0, data_type)                                      \
                .set_dtype(1, dtype::Float32())                              \
                .set_dtype(2, data_type)                                     \
                .set_rng(0, &data_rng)                                       \
                .set_rng(1, &map_rng)                                        \
                .set_rng(2, &data_rng)                                       \
                .set_epsilon(1)                                              \
                .set_param(arg.param)                                        \
                .execs({arg.src, arg.map_xy, arg.dst});                      \
    }
    cb(dtype::Int8(), int8_rng);
    cb(dtype::Uint8(), uint8_rng);
#undef cb
}

TEST_F(CUDA, REMAP_NHWC_FLOAT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nhwc_args();
    UniformFloatRNG float_rng(0, 255);
#define cb(data_type, data_rng)                                              \
    for (auto arg : args) {                                                  \
        UniformFloatRNG map_rng(                                             \
                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
        checker.set_dtype(0, data_type)                                      \
                .set_dtype(1, dtype::Float32())                              \
                .set_dtype(2, data_type)                                     \
                .set_rng(0, &data_rng)                                       \
                .set_rng(1, &map_rng)                                        \
                .set_rng(2, &data_rng)                                       \
                .set_param(arg.param)                                        \
                .execs({arg.src, arg.map_xy, arg.dst});                      \
    }
    cb(dtype::Float32(), float_rng);
    cb(dtype::Float16(), float_rng);
#undef cb
}

TEST_F(CUDA, REMAP_NHWC_INT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nhwc_args();
    UniformIntRNG uint8_rng(0, 255);
    UniformIntRNG int8_rng(-128, 127);

#define cb(data_type, data_rng)                                              \
    for (auto arg : args) {                                                  \
        UniformFloatRNG map_rng(                                             \
                -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \
        checker.set_dtype(0, data_type)                                      \
                .set_dtype(1, dtype::Float32())                              \
                .set_dtype(2, data_type)                                     \
                .set_rng(0, &data_rng)                                       \
                .set_rng(1, &map_rng)                                        \
                .set_rng(2, &data_rng)                                       \
                .set_epsilon(1)                                              \
                .set_param(arg.param)                                        \
                .execs({arg.src, arg.map_xy, arg.dst});                      \
    }
    cb(dtype::Int8(), int8_rng);
    cb(dtype::Uint8(), uint8_rng);
#undef cb
}

#if MEGDNN_WITH_BENCHMARK

TEST_F(CUDA, BENCHMARK_REMAP) {
    using Param = param::Remap;
    auto run = [&](const TensorShapeArray& shapes, Param param, DType dtype) {
        auto handle_cpu = create_cpu_handle(2);
        Benchmarker<Remap> benchmarker_naive(handle_cpu.get());
        CUBenchmarker<Remap> benchmarker_cuda(handle_cuda());
        UniformIntRNG rng(0, 0xff);
        UniformFloatRNG map_rng(
                -2, std::max(shapes[1].shape[1], shapes[1].shape[2]) + 2);
        benchmarker_naive.set_rng(0, &rng);
        benchmarker_cuda.set_rng(0, &rng);
        benchmarker_naive.set_rng(1, &map_rng);
        benchmarker_cuda.set_rng(1, &map_rng);
        benchmarker_naive.set_rng(2, &rng);
        benchmarker_cuda.set_rng(2, &rng);

        benchmarker_naive.set_dtype(1, dtype::Float32());
        benchmarker_cuda.set_dtype(1, dtype::Float32());
        benchmarker_naive.set_dtype(0, dtype).set_dtype(2, dtype);
        benchmarker_cuda.set_dtype(0, dtype).set_dtype(2, dtype);

        size_t RUN = 10;
        auto t1 = benchmarker_naive.set_display(false)
                          .set_times(RUN)
                          .set_param(param)
                          .execs(shapes);
        auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs(
                shapes);
        const TensorShape dst_layout = shapes[2];

        float calc_amount = dst_layout.total_nr_elems();
        printf("naive={%.3fms, %.3fMflops}, "
               "cuda={%.3fms, %.3fMflops}\n",
               t1 / RUN, calc_amount / (t1 / RUN * 1000), t2,
               calc_amount / (t2 * 1000));
    };
    Param param;
    param.imode = param::Remap::InterpolationMode::LINEAR;
    param.format = param::Remap::Format::NHWC;
    param.border_type = param::Remap::BorderMode::CONSTANT;
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Float32{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Float16{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Uint8{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Int8{});
    param.border_type = param::Remap::BorderMode::REPLICATE;
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Float32{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Float16{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Uint8{});
    run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param,
        dtype::Int8{});
    param.format = param::Remap::Format::NCHW;
    param.border_type = param::Remap::BorderMode::CONSTANT;
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Float32{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Float16{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Uint8{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Int8{});
    param.border_type = param::Remap::BorderMode::REPLICATE;
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Float32{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Float16{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Uint8{});
    run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param,
        dtype::Int8{});
}

#endif

} // namespace remap
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
