@@ -0,0 +1,33 @@ | |||
/** | |||
* \file dnn/src/cuda/remap/common.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <cuda_runtime_api.h> | |||
#include "megcore_cdefs.h" | |||
#include "src/common/cv/enums.h" | |||
#include "src/common/opr_param_defs_enumv.cuh" | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace remap { | |||
// all these kernels use LINEAR interpolation | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar, | |||
int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream); | |||
} // namespace remap | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,93 @@ | |||
/** | |||
* \file dnn/src/cuda/remap/forward.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/config/config.h" | |||
#include "src/common/opr_param_defs_enumv.cuh" | |||
#include "src/cuda/remap/common.h" | |||
#include "src/cuda/remap/opr_impl.h" | |||
#include "src/cuda/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
_megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||
check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | |||
auto stream = cuda_stream(this->handle()); | |||
int N, C, IH, IW, OH, OW; | |||
ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0; | |||
OH = map_xy.layout.shape[1]; | |||
OW = map_xy.layout.shape[2]; | |||
megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||
"only support LINEAR interpolationMode"); | |||
if (param().format == param::Remap::Format::NCHW) { | |||
N = src.layout.shape[0]; | |||
C = src.layout.shape[1]; | |||
IH = src.layout.shape[2]; | |||
IW = src.layout.shape[3]; | |||
S_IN = src.layout.stride[0]; | |||
S_IC = src.layout.stride[1]; | |||
S_IH = src.layout.stride[2]; | |||
S_IW = src.layout.stride[3]; | |||
} else if (param().format == param::Remap::Format::NHWC) { | |||
N = src.layout.shape[0]; | |||
C = src.layout.shape[3]; | |||
IH = src.layout.shape[1]; | |||
IW = src.layout.shape[2]; | |||
} else { | |||
megdnn_throw("unsupported format, cuda remap"); | |||
} | |||
#define cb(dt, _format, bmode) \ | |||
if (param().format == param::Remap::Format::_format && \ | |||
param().border_type == param::Remap::BorderMode::bmode) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
remap::forward_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||
::BorderMode::BORDER_##bmode>( \ | |||
src.compatible_ptr<ctype>(), \ | |||
map_xy.compatible_ptr<dt_float32>(), \ | |||
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | |||
param().scalar, S_IN, S_IC, S_IH, S_IW, stream); \ | |||
break; \ | |||
} | |||
#define support_dtype(dt) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
cb(dt, NCHW, CONSTANT); \ | |||
cb(dt, NCHW, REPLICATE); \ | |||
cb(dt, NCHW, REFLECT); \ | |||
cb(dt, NCHW, REFLECT_101); \ | |||
cb(dt, NCHW, WRAP); \ | |||
cb(dt, NHWC, CONSTANT); \ | |||
cb(dt, NHWC, REPLICATE); \ | |||
cb(dt, NHWC, REFLECT); \ | |||
cb(dt, NHWC, REFLECT_101); \ | |||
cb(dt, NHWC, WRAP); \ | |||
megdnn_throw("unsupported border type in remap cuda"); \ | |||
} | |||
switch (src.layout.dtype.enumv()) { | |||
support_dtype(dtype::Float32) | |||
MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)) | |||
support_dtype(dtype::Int8) | |||
support_dtype(dtype::Uint8) | |||
default: | |||
megdnn_throw("unsupported dtype in remap cuda"); | |||
} | |||
#undef supported_dtype | |||
#undef cb | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,238 @@ | |||
/** | |||
* \file dnn/src/cuda/remap/forward.cu | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <cuda.h> | |||
#include <cuda_runtime.h> | |||
#include "src/common/rounding_converter.cuh" | |||
#include "src/cuda/cv/kernel_common.cuh" | |||
#include "src/cuda/remap/common.h" | |||
#include "src/cuda/utils.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace remap; | |||
using namespace rounding; | |||
namespace { | |||
template <typename ctype> | |||
struct DirectSrcVisitor { | |||
const ctype* ptr; | |||
__device__ __forceinline__ const ctype* get(int batch, int im_size) { | |||
return ptr + batch * im_size; | |||
} | |||
void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; } | |||
}; | |||
template <const uint32_t format> | |||
__device__ inline int get_offset(int height, int width, int channel, int h, | |||
int w, int c); | |||
template <> | |||
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>( | |||
int height, int width, int channel, int h, int w, int c) { | |||
return channel * h * w + height * w + width; | |||
} | |||
template <> | |||
__device__ inline int get_offset<param_enumv::Remap::Format::NHWC>( | |||
int height, int width, int channel, int h, int w, int c) { | |||
return height * w * c + width * c + channel; | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
struct GetSrcData { | |||
__device__ static inline ctype get(const ctype* src, int height, int width, | |||
int channel, int h, int w, int c, | |||
float) { | |||
height = megcv::border_interpolate<bmode>(height, h); | |||
width = megcv::border_interpolate<bmode>(width, w); | |||
return src[get_offset<format>(height, width, channel, h, w, c)]; | |||
} | |||
}; | |||
template <typename ctype, const uint32_t format> | |||
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
__device__ static inline ctype get(const ctype* src, int height, int width, | |||
int channel, int h, int w, int c, | |||
float scalar) { | |||
RoundingConverter<ctype> round_converter; | |||
return (height >= 0 && height < h && width >= 0 && width < w) | |||
? src[get_offset<format>(height, width, channel, h, w, | |||
c)] | |||
: round_converter(scalar); | |||
} | |||
}; | |||
template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||
__global__ void kern_general(SrcVisitor src, const float* map_xy, | |||
ctype* __restrict dst, int C, int IH, int IW, | |||
int OH, int OW, int S_IN, int S_IC, int S_IH, | |||
int S_IW, float scalar) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
const ctype* __restrict sptr = src.get(blockIdx.z, S_IN); | |||
dst += blockIdx.z * C * OH * OW; | |||
map_xy += blockIdx.z * 2 * OH * OW; | |||
RoundingConverter<ctype> round_converter; | |||
if (ow < OW && oh < OH) { | |||
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
int col = (int)floor(index_col); | |||
int row = (int)floor(index_row); | |||
float v = index_col - col; | |||
float u = index_row - row; | |||
for (int c = 0; c < C; ++c) { | |||
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
bmode>::get(sptr, row + 0, col + 0, c, IH, | |||
IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
bmode>::get(sptr, row + 0, col + 1, c, IH, | |||
IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
bmode>::get(sptr, row + 1, col + 0, c, IH, | |||
IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
bmode>::get(sptr, row + 1, col + 1, c, IH, | |||
IW, C, scalar); | |||
dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, | |||
C)] = | |||
round_converter(a00 * (1.f - u) * (1.f - v) + | |||
a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||
a11 * u * v); | |||
} | |||
} | |||
} | |||
template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||
__global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||
ctype* __restrict dst, int C, int IH, int IW, | |||
int OH, int OW, float scalar) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); | |||
dst += blockIdx.z * C * OH * OW; | |||
map_xy += blockIdx.z * 2 * OH * OW; | |||
RoundingConverter<ctype> round_converter; | |||
if (ow < OW && oh < OH) { | |||
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
int col = (int)floor(index_col); | |||
int row = (int)floor(index_row); | |||
float v = index_col - col; | |||
float u = index_row - row; | |||
for (int c = 0; c < C; ++c) { | |||
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | |||
bmode>::get(sptr, row + 0, col + 0, c, IH, | |||
IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | |||
bmode>::get(sptr, row + 0, col + 1, c, IH, | |||
IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | |||
bmode>::get(sptr, row + 1, col + 0, c, IH, | |||
IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | |||
bmode>::get(sptr, row + 1, col + 1, c, IH, | |||
IW, C, scalar); | |||
dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, | |||
C)] = | |||
round_converter(a00 * (1.f - u) * (1.f - v) + | |||
a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||
a11 * u * v); | |||
} | |||
} | |||
} | |||
template <typename ctype, typename SrcVisitor, const uint32_t format, | |||
::BorderMode bmode> | |||
void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst, | |||
int N, int C, int IH, int IW, int OH, int OW, | |||
float scalar, int S_IN, int S_IC, int S_IH, int S_IW, | |||
cudaStream_t stream) { | |||
const int BX = 32, BY = 16; | |||
const int max_batch_size = 65535; | |||
while (N) { | |||
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||
dim3 threads(BX, BY); | |||
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||
if (format == param_enumv::Remap::Format::NCHW) { | |||
kern_general<ctype, SrcVisitor, bmode> | |||
<<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||
IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, scalar); | |||
} else if (format == param_enumv::Remap::Format::NHWC) { | |||
kern_general_nhwc<ctype, SrcVisitor, bmode> | |||
<<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||
IW, OH, OW, scalar); | |||
} | |||
N -= curr_batch_size; | |||
src.move_batch(curr_batch_size, C * IH * IW); | |||
dst += curr_batch_size * C * OH * OW; | |||
} | |||
} | |||
} // anonymous namespace | |||
namespace megdnn { | |||
namespace cuda { | |||
namespace remap { | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar, | |||
int S_IN, int S_IC, int S_IH, int S_IW, | |||
cudaStream_t stream) { | |||
DirectSrcVisitor<ctype> visitor; | |||
visitor.ptr = src; | |||
using SrcVisitor = DirectSrcVisitor<ctype>; | |||
dispatch_with_visitor<ctype, SrcVisitor, format, bmode>( | |||
visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC, | |||
S_IH, S_IW, stream); | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype, format, bmode) \ | |||
template void forward_proxy<ctype, param_enumv::Remap::Format::format, \ | |||
::BorderMode::bmode>( \ | |||
const ctype* src, const float*, ctype*, int, int, int, int, int, \ | |||
int, float, int, int, int, int, cudaStream_t); | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
INST(ctype, NCHW, BORDER_REFLECT) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
INST(ctype, NCHW, BORDER_WRAP) \ | |||
INST(ctype, NHWC, BORDER_CONSTANT) \ | |||
INST(ctype, NHWC, BORDER_REPLICATE) \ | |||
INST(ctype, NHWC, BORDER_REFLECT) \ | |||
INST(ctype, NHWC, BORDER_REFLECT_101) \ | |||
INST(ctype, NHWC, BORDER_WRAP) | |||
FOR_FORMAT_BMODE(float) | |||
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | |||
FOR_FORMAT_BMODE(int8_t) | |||
FOR_FORMAT_BMODE(uint8_t) | |||
#undef FOR_BMODE | |||
#undef INST | |||
} // namespace remap | |||
} // namespace cuda | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -1,26 +0,0 @@ | |||
/** | |||
* \file dnn/src/opencl/cuda/opr_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/cuda/remap/opr_impl.h" | |||
#include "megdnn/config/config.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
_megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||
check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | |||
megdnn_throw("megdnn currently do not support remap in cuda"); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||
/** | |||
* \file dnn/src/opencl/cuda/opr_impl.h | |||
* \file dnn/src/cuda/remap/opr_impl.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
@@ -16,13 +16,16 @@ namespace megdnn { | |||
namespace cuda { | |||
class RemapImpl final : public Remap { | |||
using Remap::Remap; | |||
void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, | |||
_megdnn_workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
const TensorLayout&) override { | |||
return 0; | |||
} | |||
}; | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -37,42 +37,29 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width, | |||
} | |||
template <typename DataType, param::Remap::Format format, | |||
param::Remap::BorderMode bodertype> | |||
param::Remap::BorderMode bordertype> | |||
struct GetSrcData { | |||
static inline DataType get(const DataType*, int, int, int, int, int, int, | |||
int); | |||
static inline DataType get(const DataType* src, int height, int width, | |||
int channel, int h, int w, int c, float, | |||
std::function<DataType(float)>) { | |||
height = megcv::border_interpolate<bordertype>(height, h); | |||
width = megcv::border_interpolate<bordertype>(width, w); | |||
return src[get_offset<format>(height, width, channel, h, w, c)]; | |||
} | |||
}; | |||
template <typename DataType, param::Remap::Format format> | |||
struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> { | |||
static inline DataType get(const DataType* src, int height, int width, | |||
int channel, int h, int w, int c, float scalar) { | |||
int channel, int h, int w, int c, float scalar, | |||
std::function<DataType(float)> round) { | |||
return (height >= 0 && height < h && width >= 0 && width < w) | |||
? src[get_offset<format>(height, width, channel, h, w, | |||
c)] | |||
: static_cast<DataType>(std::round(scalar)); | |||
: static_cast<DataType>(round(scalar)); | |||
} | |||
}; | |||
#define cb(bmode) \ | |||
template <typename DataType, param::Remap::Format format> \ | |||
struct GetSrcData<DataType, format, param::Remap::BorderMode::bmode> { \ | |||
static inline DataType get(const DataType* src, int height, int width, \ | |||
int channel, int h, int w, int c, float) { \ | |||
height = megcv::border_interpolate< \ | |||
param::Remap::BorderMode::bmode>(height, h); \ | |||
width = megcv::border_interpolate< \ | |||
param::Remap::BorderMode::bmode>(width, w); \ | |||
return src[get_offset<format>(height, width, channel, h, w, c)]; \ | |||
} \ | |||
}; | |||
cb(REPLICATE); | |||
cb(REFLECT); | |||
cb(REFLECT_101); | |||
cb(WRAP); | |||
#undef cb | |||
template <typename DataType, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||
@@ -92,20 +79,20 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||
for (int c = 0; c < C; ++c) { | |||
DataType a00 = | |||
GetSrcData<DataType, format, bordertype>::get( | |||
src, row + 0, col + 0, c, IH, IW, C, | |||
scalar); | |||
src, row + 0, col + 0, c, IH, IW, C, scalar, | |||
round); | |||
DataType a01 = | |||
GetSrcData<DataType, format, bordertype>::get( | |||
src, row + 0, col + 1, c, IH, IW, C, | |||
scalar); | |||
src, row + 0, col + 1, c, IH, IW, C, scalar, | |||
round); | |||
DataType a10 = | |||
GetSrcData<DataType, format, bordertype>::get( | |||
src, row + 1, col + 0, c, IH, IW, C, | |||
scalar); | |||
src, row + 1, col + 0, c, IH, IW, C, scalar, | |||
round); | |||
DataType a11 = | |||
GetSrcData<DataType, format, bordertype>::get( | |||
src, row + 1, col + 1, c, IH, IW, C, | |||
scalar); | |||
src, row + 1, col + 1, c, IH, IW, C, scalar, | |||
round); | |||
dst[get_offset<format>(h, w, c, OH, OW, C)] = | |||
static_cast<DataType>( | |||
@@ -139,11 +126,13 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
C = src.layout.shape[1]; | |||
IH = src.layout.shape[2]; | |||
IW = src.layout.shape[3]; | |||
} else { | |||
} else if (param().format == param::Remap::Format::NHWC) { | |||
N = src.layout.shape[0]; | |||
C = src.layout.shape[3]; | |||
IH = src.layout.shape[1]; | |||
IW = src.layout.shape[2]; | |||
} else { | |||
megdnn_throw("unsupported format"); | |||
} | |||
OH = map_xy.layout.shape[1]; | |||
OW = map_xy.layout.shape[2]; | |||
@@ -0,0 +1,203 @@ | |||
/** | |||
* \file dnn/test/cuda/remap.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/common/remap.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/rng.h" | |||
#include "test/cuda/benchmark.h" | |||
#include "test/cuda/fixture.h" | |||
namespace megdnn { | |||
namespace test { | |||
namespace remap { | |||
TEST_F(CUDA, REMAP_NCHW_FLOAT) { | |||
Checker<Remap> checker(handle_cuda()); | |||
std::vector<TestArg> args = get_nchw_args(); | |||
UniformFloatRNG float_rng(0, 255); | |||
#define cb(data_type, data_rng) \ | |||
for (auto arg : args) { \ | |||
UniformFloatRNG map_rng( \ | |||
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
checker.set_dtype(0, data_type) \ | |||
.set_dtype(1, dtype::Float32()) \ | |||
.set_dtype(2, data_type) \ | |||
.set_rng(0, &data_rng) \ | |||
.set_rng(1, &map_rng) \ | |||
.set_rng(2, &data_rng) \ | |||
.set_param(arg.param) \ | |||
.execs({arg.src, arg.map_xy, arg.dst}); \ | |||
} | |||
cb(dtype::Float32(), float_rng); | |||
cb(dtype::Float16(), float_rng); | |||
#undef cb | |||
} | |||
TEST_F(CUDA, REMAP_NCHW_INT) { | |||
Checker<Remap> checker(handle_cuda()); | |||
std::vector<TestArg> args = get_nchw_args(); | |||
UniformIntRNG uint8_rng(0, 255); | |||
UniformIntRNG int8_rng(-128, 127); | |||
#define cb(data_type, data_rng) \ | |||
for (auto arg : args) { \ | |||
UniformFloatRNG map_rng( \ | |||
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
checker.set_dtype(0, data_type) \ | |||
.set_dtype(1, dtype::Float32()) \ | |||
.set_dtype(2, data_type) \ | |||
.set_rng(0, &data_rng) \ | |||
.set_rng(1, &map_rng) \ | |||
.set_rng(2, &data_rng) \ | |||
.set_epsilon(1) \ | |||
.set_param(arg.param) \ | |||
.execs({arg.src, arg.map_xy, arg.dst}); \ | |||
} | |||
cb(dtype::Int8(), int8_rng); | |||
cb(dtype::Uint8(), uint8_rng); | |||
#undef cb | |||
} | |||
TEST_F(CUDA, REMAP_NHWC_FLOAT) { | |||
Checker<Remap> checker(handle_cuda()); | |||
std::vector<TestArg> args = get_nhwc_args(); | |||
UniformFloatRNG float_rng(0, 255); | |||
#define cb(data_type, data_rng) \ | |||
for (auto arg : args) { \ | |||
UniformFloatRNG map_rng( \ | |||
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
checker.set_dtype(0, data_type) \ | |||
.set_dtype(1, dtype::Float32()) \ | |||
.set_dtype(2, data_type) \ | |||
.set_rng(0, &data_rng) \ | |||
.set_rng(1, &map_rng) \ | |||
.set_rng(2, &data_rng) \ | |||
.set_param(arg.param) \ | |||
.execs({arg.src, arg.map_xy, arg.dst}); \ | |||
} | |||
cb(dtype::Float32(), float_rng); | |||
cb(dtype::Float16(), float_rng); | |||
#undef cb | |||
} | |||
TEST_F(CUDA, REMAP_NHWC_INT) { | |||
Checker<Remap> checker(handle_cuda()); | |||
std::vector<TestArg> args = get_nhwc_args(); | |||
UniformIntRNG uint8_rng(0, 255); | |||
UniformIntRNG int8_rng(-128, 127); | |||
#define cb(data_type, data_rng) \ | |||
for (auto arg : args) { \ | |||
UniformFloatRNG map_rng( \ | |||
-2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
checker.set_dtype(0, data_type) \ | |||
.set_dtype(1, dtype::Float32()) \ | |||
.set_dtype(2, data_type) \ | |||
.set_rng(0, &data_rng) \ | |||
.set_rng(1, &map_rng) \ | |||
.set_rng(2, &data_rng) \ | |||
.set_epsilon(1) \ | |||
.set_param(arg.param) \ | |||
.execs({arg.src, arg.map_xy, arg.dst}); \ | |||
} | |||
cb(dtype::Int8(), int8_rng); | |||
cb(dtype::Uint8(), uint8_rng); | |||
#undef cb | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(CUDA, BENCHMARK_REMAP) { | |||
using Param = param::Remap; | |||
auto run = [&](const TensorShapeArray& shapes, Param param, DType dtype) { | |||
auto handle_cpu = create_cpu_handle(2); | |||
Benchmarker<Remap> benchmarker_naive(handle_cpu.get()); | |||
CUBenchmarker<Remap> benchmarker_cuda(handle_cuda()); | |||
UniformIntRNG rng(0, 0xff); | |||
UniformFloatRNG map_rng( | |||
-2, std::max(shapes[1].shape[1], shapes[1].shape[2]) + 2); | |||
benchmarker_naive.set_rng(0, &rng); | |||
benchmarker_cuda.set_rng(0, &rng); | |||
benchmarker_naive.set_rng(1, &map_rng); | |||
benchmarker_cuda.set_rng(1, &map_rng); | |||
benchmarker_naive.set_rng(2, &rng); | |||
benchmarker_cuda.set_rng(2, &rng); | |||
benchmarker_naive.set_dtype(1, dtype::Float32()); | |||
benchmarker_cuda.set_dtype(1, dtype::Float32()); | |||
benchmarker_naive.set_dtype(0, dtype).set_dtype(2, dtype); | |||
benchmarker_cuda.set_dtype(0, dtype).set_dtype(2, dtype); | |||
size_t RUN = 10; | |||
auto t1 = benchmarker_naive.set_display(false) | |||
.set_times(RUN) | |||
.set_param(param) | |||
.execs(shapes); | |||
auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs( | |||
shapes); | |||
const TensorShape dst_layout = shapes[2]; | |||
float calc_amount = dst_layout.total_nr_elems(); | |||
printf("naive={%.3fms, %.3fMflops}, " | |||
"cuda={%.3fms, %.3fMflops}\n", | |||
t1 / RUN, calc_amount / (t1 / RUN * 1000), t2, | |||
calc_amount / (t2 * 1000)); | |||
}; | |||
Param param; | |||
param.imode = param::Remap::InterpolationMode::LINEAR; | |||
param.format = param::Remap::Format::NHWC; | |||
param.border_type = param::Remap::BorderMode::CONSTANT; | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Float32{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Float16{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Uint8{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Int8{}); | |||
param.border_type = param::Remap::BorderMode::REPLICATE; | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Float32{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Float16{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Uint8{}); | |||
run({{4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}}, param, | |||
dtype::Int8{}); | |||
param.format = param::Remap::Format::NCHW; | |||
param.border_type = param::Remap::BorderMode::CONSTANT; | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Float32{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Float16{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Uint8{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Int8{}); | |||
param.border_type = param::Remap::BorderMode::REPLICATE; | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Float32{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Float16{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Uint8{}); | |||
run({{4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}}, param, | |||
dtype::Int8{}); | |||
} | |||
#endif | |||
} // namespace remap | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |