@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file dnn/src/cuda/remap/common.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cuda_runtime_api.h> | |||||
#include "megcore_cdefs.h" | |||||
#include "src/common/cv/enums.h" | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace remap { | |||||
// all these kernels use LINEAR interpolation | |||||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||||
int C, int IH, int IW, int OH, int OW, float scalar, | |||||
int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream); | |||||
} // namespace remap | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,93 @@ | |||||
/** | |||||
* \file dnn/src/cuda/remap/forward.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/config/config.h" | |||||
#include "src/common/opr_param_defs_enumv.cuh" | |||||
#include "src/cuda/remap/common.h" | |||||
#include "src/cuda/remap/opr_impl.h" | |||||
#include "src/cuda/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||||
_megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||||
check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | |||||
auto stream = cuda_stream(this->handle()); | |||||
int N, C, IH, IW, OH, OW; | |||||
ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0; | |||||
OH = map_xy.layout.shape[1]; | |||||
OW = map_xy.layout.shape[2]; | |||||
megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
"only support LINEAR interpolationMode"); | |||||
if (param().format == param::Remap::Format::NCHW) { | |||||
N = src.layout.shape[0]; | |||||
C = src.layout.shape[1]; | |||||
IH = src.layout.shape[2]; | |||||
IW = src.layout.shape[3]; | |||||
S_IN = src.layout.stride[0]; | |||||
S_IC = src.layout.stride[1]; | |||||
S_IH = src.layout.stride[2]; | |||||
S_IW = src.layout.stride[3]; | |||||
} else if (param().format == param::Remap::Format::NHWC) { | |||||
N = src.layout.shape[0]; | |||||
C = src.layout.shape[3]; | |||||
IH = src.layout.shape[1]; | |||||
IW = src.layout.shape[2]; | |||||
} else { | |||||
megdnn_throw("unsupported format, cuda remap"); | |||||
} | |||||
#define cb(dt, _format, bmode) \ | |||||
if (param().format == param::Remap::Format::_format && \ | |||||
param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
remap::forward_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||||
::BorderMode::BORDER_##bmode>( \ | |||||
src.compatible_ptr<ctype>(), \ | |||||
map_xy.compatible_ptr<dt_float32>(), \ | |||||
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | |||||
param().scalar, S_IN, S_IC, S_IH, S_IW, stream); \ | |||||
break; \ | |||||
} | |||||
#define support_dtype(dt) \ | |||||
case DTypeTrait<dt>::enumv: { \ | |||||
cb(dt, NCHW, CONSTANT); \ | |||||
cb(dt, NCHW, REPLICATE); \ | |||||
cb(dt, NCHW, REFLECT); \ | |||||
cb(dt, NCHW, REFLECT_101); \ | |||||
cb(dt, NCHW, WRAP); \ | |||||
cb(dt, NHWC, CONSTANT); \ | |||||
cb(dt, NHWC, REPLICATE); \ | |||||
cb(dt, NHWC, REFLECT); \ | |||||
cb(dt, NHWC, REFLECT_101); \ | |||||
cb(dt, NHWC, WRAP); \ | |||||
megdnn_throw("unsupported border type in remap cuda"); \ | |||||
} | |||||
switch (src.layout.dtype.enumv()) { | |||||
support_dtype(dtype::Float32) | |||||
MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)) | |||||
support_dtype(dtype::Int8) | |||||
support_dtype(dtype::Uint8) | |||||
default: | |||||
megdnn_throw("unsupported dtype in remap cuda"); | |||||
} | |||||
#undef supported_dtype | |||||
#undef cb | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,238 @@ | |||||
/** | |||||
* \file dnn/src/cuda/remap/forward.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <cuda.h> | |||||
#include <cuda_runtime.h> | |||||
#include "src/common/rounding_converter.cuh" | |||||
#include "src/cuda/cv/kernel_common.cuh" | |||||
#include "src/cuda/remap/common.h" | |||||
#include "src/cuda/utils.cuh" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
using namespace remap; | |||||
using namespace rounding; | |||||
namespace { | |||||
//! Source accessor that reaches the input images through a raw pointer.
template <typename ctype>
struct DirectSrcVisitor {
    const ctype* ptr;

    //! Start of image number \p batch, with images \p im_size elements apart.
    __device__ __forceinline__ const ctype* get(int batch, int im_size) {
        return &ptr[batch * im_size];
    }

    //! Host side: skip \p batch images of \p im_size elements each.
    void move_batch(size_t batch, size_t im_size) {
        ptr = ptr + batch * im_size;
    }
};
/*!
 * \brief Linear element offset of (height, width, channel) inside one image
 *        of size h x w with c channels, for the given tensor format.
 */
template <const uint32_t format>
__device__ inline int get_offset(int height, int width, int channel, int h,
                                 int w, int c);

//! NCHW: channel planes of h * w elements, rows of w elements (c is unused).
template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NCHW>(
        int height, int width, int channel, int h, int w, int c) {
    const int plane_start = channel * h * w;
    const int row_start = height * w;
    return plane_start + row_start + width;
}

//! NHWC: rows of w * c elements, pixels of c elements (h is unused).
template <>
__device__ inline int get_offset<param_enumv::Remap::Format::NHWC>(
        int height, int width, int channel, int h, int w, int c) {
    const int row_start = height * w * c;
    const int pixel_start = width * c;
    return row_start + pixel_start + channel;
}
/*!
 * \brief Fetches one source element, resolving out-of-range coordinates
 *        according to the border mode.
 *
 * Generic case: both coordinates are remapped by megcv::border_interpolate
 * before the element is read; the trailing float argument is ignored.
 */
template <typename ctype, const uint32_t format, ::BorderMode bmode>
struct GetSrcData {
    __device__ static inline ctype get(const ctype* src, int height, int width,
                                       int channel, int h, int w, int c,
                                       float) {
        const int src_row = megcv::border_interpolate<bmode>(height, h);
        const int src_col = megcv::border_interpolate<bmode>(width, w);
        const int offset =
                get_offset<format>(src_row, src_col, channel, h, w, c);
        return src[offset];
    }
};

/*!
 * \brief BORDER_CONSTANT case: positions outside [0, h) x [0, w) yield
 *        \p scalar converted to ctype by RoundingConverter.
 */
template <typename ctype, const uint32_t format>
struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> {
    __device__ static inline ctype get(const ctype* src, int height, int width,
                                       int channel, int h, int w, int c,
                                       float scalar) {
        RoundingConverter<ctype> round_converter;
        const bool row_inside = height >= 0 && height < h;
        if (row_inside && width >= 0 && width < w) {
            return src[get_offset<format>(height, width, channel, h, w, c)];
        }
        return round_converter(scalar);
    }
};
/*!
 * \brief Bilinear remap kernel for NCHW data.
 *
 * Launch layout: x covers the output width, y the output height, and
 * blockIdx.z selects the image within the current batch chunk. Each thread
 * produces all C channels of one output pixel.
 *
 * The source image is located with the batch stride \p S_IN; \p dst and
 * \p map_xy are addressed as densely packed (C * OH * OW and 2 * OH * OW
 * elements per image). \p S_IC, \p S_IH and \p S_IW are accepted but not
 * used by this kernel.
 */
template <typename ctype, typename SrcVisitor, ::BorderMode bmode>
__global__ void kern_general(SrcVisitor src, const float* map_xy,
                             ctype* __restrict dst, int C, int IH, int IW,
                             int OH, int OW, int S_IN, int S_IC, int S_IH,
                             int S_IW, float scalar) {
    using Getter = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>;

    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;
    map_xy += blockIdx.z * 2 * OH * OW;

    // Threads of the partial tiles at the right / bottom edge have no pixel.
    if (ow >= OW || oh >= OH) {
        return;
    }

    // Each output pixel owns two consecutive floats: column, then row.
    const float* xy = map_xy + (oh * OW * 2 + ow * 2);
    const float index_col = xy[0];
    const float index_row = xy[1];
    const int col = (int)floor(index_col);
    const int row = (int)floor(index_row);
    // Fractional parts are the interpolation weights.
    const float frac_x = index_col - col;
    const float frac_y = index_row - row;

    RoundingConverter<ctype> round_converter;
    for (int c = 0; c < C; ++c) {
        ctype a00 = Getter::get(sptr, row + 0, col + 0, c, IH, IW, C, scalar);
        ctype a01 = Getter::get(sptr, row + 0, col + 1, c, IH, IW, C, scalar);
        ctype a10 = Getter::get(sptr, row + 1, col + 0, c, IH, IW, C, scalar);
        ctype a11 = Getter::get(sptr, row + 1, col + 1, c, IH, IW, C, scalar);
        const int out_idx = get_offset<param_enumv::Remap::Format::NCHW>(
                oh, ow, c, OH, OW, C);
        dst[out_idx] = round_converter(a00 * (1.f - frac_y) * (1.f - frac_x) +
                                       a01 * (1.f - frac_y) * frac_x +
                                       a10 * (1.f - frac_x) * frac_y +
                                       a11 * frac_y * frac_x);
    }
}
/*!
 * \brief Bilinear remap kernel for NHWC data.
 *
 * Launch layout: x covers the output width, y the output height, and
 * blockIdx.z selects the image within the current batch chunk. Each thread
 * produces all C channels of one output pixel.
 *
 * \p src, \p dst and \p map_xy are all addressed as densely packed
 * (C * IH * IW, C * OH * OW and 2 * OH * OW elements per image).
 */
template <typename ctype, typename SrcVisitor, ::BorderMode bmode>
__global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy,
                                  ctype* __restrict dst, int C, int IH, int IW,
                                  int OH, int OW, float scalar) {
    using Getter = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>;

    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW);
    dst += blockIdx.z * C * OH * OW;
    map_xy += blockIdx.z * 2 * OH * OW;

    // Threads of the partial tiles at the right / bottom edge have no pixel.
    if (ow >= OW || oh >= OH) {
        return;
    }

    // Each output pixel owns two consecutive floats: column, then row.
    const float* xy = map_xy + (oh * OW * 2 + ow * 2);
    const float index_col = xy[0];
    const float index_row = xy[1];
    const int col = (int)floor(index_col);
    const int row = (int)floor(index_row);
    // Fractional parts are the interpolation weights.
    const float frac_x = index_col - col;
    const float frac_y = index_row - row;

    RoundingConverter<ctype> round_converter;
    for (int c = 0; c < C; ++c) {
        ctype a00 = Getter::get(sptr, row + 0, col + 0, c, IH, IW, C, scalar);
        ctype a01 = Getter::get(sptr, row + 0, col + 1, c, IH, IW, C, scalar);
        ctype a10 = Getter::get(sptr, row + 1, col + 0, c, IH, IW, C, scalar);
        ctype a11 = Getter::get(sptr, row + 1, col + 1, c, IH, IW, C, scalar);
        const int out_idx = get_offset<param_enumv::Remap::Format::NHWC>(
                oh, ow, c, OH, OW, C);
        dst[out_idx] = round_converter(a00 * (1.f - frac_y) * (1.f - frac_x) +
                                       a01 * (1.f - frac_y) * frac_x +
                                       a10 * (1.f - frac_x) * frac_y +
                                       a11 * frac_y * frac_x);
    }
}
/*!
 * \brief Launches the remap kernel for \p format over the whole batch.
 *
 * The batch index is mapped to the z dimension of the grid, which is limited
 * to 65535 blocks, so larger batches are processed in several launches. After
 * each launch the source visitor, \p map_xy and \p dst are all advanced past
 * the images that were just handled.
 *
 * \param src     source visitor positioned at the first image
 * \param map_xy  sampling map, 2 * OH * OW floats per image
 * \param dst     output, C * OH * OW elements per image
 * \param N       batch size; nothing is launched when N <= 0
 * \param C, IH, IW, OH, OW  channel count and input / output spatial size
 * \param scalar  fill value for BORDER_CONSTANT
 * \param S_IN, S_IC, S_IH, S_IW  source strides, forwarded to the NCHW kernel
 * \param stream  CUDA stream to launch on
 */
template <typename ctype, typename SrcVisitor, const uint32_t format,
          ::BorderMode bmode>
void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst,
                           int N, int C, int IH, int IW, int OH, int OW,
                           float scalar, int S_IN, int S_IC, int S_IH, int S_IW,
                           cudaStream_t stream) {
    const int BX = 32, BY = 16;
    // Upper bound on the z dimension of a grid.
    const int max_batch_size = 65535;
    // The NCHW kernel locates an image with S_IN, the NHWC kernel with
    // C * IH * IW; advance the source by the same amount between launches.
    const int src_batch_stride =
            (format == param_enumv::Remap::Format::NCHW) ? S_IN : C * IH * IW;

    while (N > 0) {
        size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
        dim3 threads(BX, BY);
        dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size);

        if (format == param_enumv::Remap::Format::NCHW) {
            kern_general<ctype, SrcVisitor, bmode>
                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
                                                     IW, OH, OW, S_IN, S_IC,
                                                     S_IH, S_IW, scalar);
        } else if (format == param_enumv::Remap::Format::NHWC) {
            kern_general_nhwc<ctype, SrcVisitor, bmode>
                    <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH,
                                                     IW, OH, OW, scalar);
        }

        N -= curr_batch_size;
        src.move_batch(curr_batch_size, src_batch_stride);
        // The kernels index map_xy and dst relative to the chunk start, so
        // both must move forward together with the source.
        map_xy += curr_batch_size * 2 * OH * OW;
        dst += curr_batch_size * C * OH * OW;
    }
}
} // anonymous namespace | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace remap { | |||||
/*!
 * \brief Wraps \p src in a DirectSrcVisitor, hands all arguments to
 *        dispatch_with_visitor, then runs after_kernel_launch().
 */
template <typename ctype, const uint32_t format, ::BorderMode bmode>
void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N,
                   int C, int IH, int IW, int OH, int OW, float scalar,
                   int S_IN, int S_IC, int S_IH, int S_IW,
                   cudaStream_t stream) {
    using SrcVisitor = DirectSrcVisitor<ctype>;

    SrcVisitor direct_src;
    direct_src.ptr = src;

    dispatch_with_visitor<ctype, SrcVisitor, format, bmode>(
            direct_src, map_xy, dst, N, C, IH, IW, OH, OW, scalar,
            S_IN, S_IC, S_IH, S_IW, stream);
    after_kernel_launch();
}
// Explicit instantiation of forward_proxy for one
// (element type, format, border mode) combination.
#define INST(ctype, format, bmode)                                          \
    template void forward_proxy<ctype, param_enumv::Remap::Format::format,  \
                                ::BorderMode::bmode>(                       \
            const ctype* src, const float*, ctype*, int, int, int, int, int, \
            int, float, int, int, int, int, cudaStream_t);

// Instantiates every supported (format, border mode) pair for one type.
#define FOR_FORMAT_BMODE(ctype)           \
    INST(ctype, NCHW, BORDER_CONSTANT)    \
    INST(ctype, NCHW, BORDER_REPLICATE)   \
    INST(ctype, NCHW, BORDER_REFLECT)     \
    INST(ctype, NCHW, BORDER_REFLECT_101) \
    INST(ctype, NCHW, BORDER_WRAP)        \
    INST(ctype, NHWC, BORDER_CONSTANT)    \
    INST(ctype, NHWC, BORDER_REPLICATE)   \
    INST(ctype, NHWC, BORDER_REFLECT)     \
    INST(ctype, NHWC, BORDER_REFLECT_101) \
    INST(ctype, NHWC, BORDER_WRAP)

FOR_FORMAT_BMODE(float)
MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16))
FOR_FORMAT_BMODE(int8_t)
FOR_FORMAT_BMODE(uint8_t)

// Undefine exactly the helper macros defined above.
#undef FOR_FORMAT_BMODE
#undef INST
} // namespace remap | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -1,26 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/opencl/cuda/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/remap/opr_impl.h" | |||||
#include "megdnn/config/config.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace cuda; | |||||
/*!
 * \brief Placeholder CUDA remap: checks the layouts and then always throws,
 *        because no CUDA kernel is available in this version.
 *
 * The in/out qualifiers follow the class declaration: \p src and \p map_xy
 * are inputs, \p dst is the output.
 */
void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy,
                     _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
    megdnn_throw("megdnn currently do not support remap in cuda");
}
// vim: syntax=cpp.doxygen |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/opencl/cuda/opr_impl.h | |||||
* \file dnn/src/cuda/remap/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -16,13 +16,16 @@ namespace megdnn { | |||||
namespace cuda { | namespace cuda { | ||||
class RemapImpl final : public Remap { | class RemapImpl final : public Remap { | ||||
using Remap::Remap; | using Remap::Remap; | ||||
void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, | void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, | ||||
_megdnn_workspace) override; | _megdnn_workspace) override; | ||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | ||||
const TensorLayout&) override { | const TensorLayout&) override { | ||||
return 0; | return 0; | ||||
} | } | ||||
}; | }; | ||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -37,42 +37,29 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width, | |||||
} | } | ||||
template <typename DataType, param::Remap::Format format, | template <typename DataType, param::Remap::Format format, | ||||
param::Remap::BorderMode bodertype> | |||||
param::Remap::BorderMode bordertype> | |||||
struct GetSrcData { | struct GetSrcData { | ||||
static inline DataType get(const DataType*, int, int, int, int, int, int, | |||||
int); | |||||
static inline DataType get(const DataType* src, int height, int width, | |||||
int channel, int h, int w, int c, float, | |||||
std::function<DataType(float)>) { | |||||
height = megcv::border_interpolate<bordertype>(height, h); | |||||
width = megcv::border_interpolate<bordertype>(width, w); | |||||
return src[get_offset<format>(height, width, channel, h, w, c)]; | |||||
} | |||||
}; | }; | ||||
template <typename DataType, param::Remap::Format format> | template <typename DataType, param::Remap::Format format> | ||||
struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> { | struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> { | ||||
static inline DataType get(const DataType* src, int height, int width, | static inline DataType get(const DataType* src, int height, int width, | ||||
int channel, int h, int w, int c, float scalar) { | |||||
int channel, int h, int w, int c, float scalar, | |||||
std::function<DataType(float)> round) { | |||||
return (height >= 0 && height < h && width >= 0 && width < w) | return (height >= 0 && height < h && width >= 0 && width < w) | ||||
? src[get_offset<format>(height, width, channel, h, w, | ? src[get_offset<format>(height, width, channel, h, w, | ||||
c)] | c)] | ||||
: static_cast<DataType>(std::round(scalar)); | |||||
: static_cast<DataType>(round(scalar)); | |||||
} | } | ||||
}; | }; | ||||
#define cb(bmode) \ | |||||
template <typename DataType, param::Remap::Format format> \ | |||||
struct GetSrcData<DataType, format, param::Remap::BorderMode::bmode> { \ | |||||
static inline DataType get(const DataType* src, int height, int width, \ | |||||
int channel, int h, int w, int c, float) { \ | |||||
height = megcv::border_interpolate< \ | |||||
param::Remap::BorderMode::bmode>(height, h); \ | |||||
width = megcv::border_interpolate< \ | |||||
param::Remap::BorderMode::bmode>(width, w); \ | |||||
return src[get_offset<format>(height, width, channel, h, w, c)]; \ | |||||
} \ | |||||
}; | |||||
cb(REPLICATE); | |||||
cb(REFLECT); | |||||
cb(REFLECT_101); | |||||
cb(WRAP); | |||||
#undef cb | |||||
template <typename DataType, param::Remap::Format format, | template <typename DataType, param::Remap::Format format, | ||||
param::Remap::BorderMode bordertype> | param::Remap::BorderMode bordertype> | ||||
void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | ||||
@@ -92,20 +79,20 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||||
for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
DataType a00 = | DataType a00 = | ||||
GetSrcData<DataType, format, bordertype>::get( | GetSrcData<DataType, format, bordertype>::get( | ||||
src, row + 0, col + 0, c, IH, IW, C, | |||||
scalar); | |||||
src, row + 0, col + 0, c, IH, IW, C, scalar, | |||||
round); | |||||
DataType a01 = | DataType a01 = | ||||
GetSrcData<DataType, format, bordertype>::get( | GetSrcData<DataType, format, bordertype>::get( | ||||
src, row + 0, col + 1, c, IH, IW, C, | |||||
scalar); | |||||
src, row + 0, col + 1, c, IH, IW, C, scalar, | |||||
round); | |||||
DataType a10 = | DataType a10 = | ||||
GetSrcData<DataType, format, bordertype>::get( | GetSrcData<DataType, format, bordertype>::get( | ||||
src, row + 1, col + 0, c, IH, IW, C, | |||||
scalar); | |||||
src, row + 1, col + 0, c, IH, IW, C, scalar, | |||||
round); | |||||
DataType a11 = | DataType a11 = | ||||
GetSrcData<DataType, format, bordertype>::get( | GetSrcData<DataType, format, bordertype>::get( | ||||
src, row + 1, col + 1, c, IH, IW, C, | |||||
scalar); | |||||
src, row + 1, col + 1, c, IH, IW, C, scalar, | |||||
round); | |||||
dst[get_offset<format>(h, w, c, OH, OW, C)] = | dst[get_offset<format>(h, w, c, OH, OW, C)] = | ||||
static_cast<DataType>( | static_cast<DataType>( | ||||
@@ -139,11 +126,13 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
C = src.layout.shape[1]; | C = src.layout.shape[1]; | ||||
IH = src.layout.shape[2]; | IH = src.layout.shape[2]; | ||||
IW = src.layout.shape[3]; | IW = src.layout.shape[3]; | ||||
} else { | |||||
} else if (param().format == param::Remap::Format::NHWC) { | |||||
N = src.layout.shape[0]; | N = src.layout.shape[0]; | ||||
C = src.layout.shape[3]; | C = src.layout.shape[3]; | ||||
IH = src.layout.shape[1]; | IH = src.layout.shape[1]; | ||||
IW = src.layout.shape[2]; | IW = src.layout.shape[2]; | ||||
} else { | |||||
megdnn_throw("unsupported format"); | |||||
} | } | ||||
OH = map_xy.layout.shape[1]; | OH = map_xy.layout.shape[1]; | ||||
OW = map_xy.layout.shape[2]; | OW = map_xy.layout.shape[2]; | ||||
@@ -0,0 +1,203 @@ | |||||
/** | |||||
* \file dnn/test/cuda/remap.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/common/remap.h" | |||||
#include "test/common/benchmarker.h" | |||||
#include "test/common/checker.h" | |||||
#include "test/common/rng.h" | |||||
#include "test/cuda/benchmark.h" | |||||
#include "test/cuda/fixture.h" | |||||
namespace megdnn { | |||||
namespace test { | |||||
namespace remap { | |||||
// Checks CUDA remap on the NCHW test cases for Float32 and Float16 data.
TEST_F(CUDA, REMAP_NCHW_FLOAT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nchw_args();
    UniformFloatRNG float_rng(0, 255);

    // Runs every test case with the given data type.
    auto check_all = [&](DType data_type, UniformFloatRNG& data_rng) {
        for (auto arg : args) {
            // Sample positions reach 2 pixels past either side of the map
            // extent, so the border handling is exercised as well.
            UniformFloatRNG map_rng(
                    -2,
                    std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2);
            checker.set_dtype(0, data_type);
            checker.set_dtype(1, dtype::Float32());
            checker.set_dtype(2, data_type);
            checker.set_rng(0, &data_rng);
            checker.set_rng(1, &map_rng);
            checker.set_rng(2, &data_rng);
            checker.set_param(arg.param);
            checker.execs({arg.src, arg.map_xy, arg.dst});
        }
    };

    check_all(dtype::Float32(), float_rng);
    check_all(dtype::Float16(), float_rng);
}
// Checks CUDA remap on the NCHW test cases for Int8 and Uint8 data, with the
// checker epsilon set to 1.
TEST_F(CUDA, REMAP_NCHW_INT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nchw_args();
    UniformIntRNG uint8_rng(0, 255);
    UniformIntRNG int8_rng(-128, 127);

    // Runs every test case with the given data type and value generator.
    auto check_all = [&](DType data_type, UniformIntRNG& data_rng) {
        for (auto arg : args) {
            // Sample positions reach 2 pixels past either side of the map
            // extent, so the border handling is exercised as well.
            UniformFloatRNG map_rng(
                    -2,
                    std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2);
            checker.set_dtype(0, data_type);
            checker.set_dtype(1, dtype::Float32());
            checker.set_dtype(2, data_type);
            checker.set_rng(0, &data_rng);
            checker.set_rng(1, &map_rng);
            checker.set_rng(2, &data_rng);
            checker.set_epsilon(1);
            checker.set_param(arg.param);
            checker.execs({arg.src, arg.map_xy, arg.dst});
        }
    };

    check_all(dtype::Int8(), int8_rng);
    check_all(dtype::Uint8(), uint8_rng);
}
// Checks CUDA remap on the NHWC test cases for Float32 and Float16 data.
TEST_F(CUDA, REMAP_NHWC_FLOAT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nhwc_args();
    UniformFloatRNG float_rng(0, 255);

    // Runs every test case with the given data type.
    auto check_all = [&](DType data_type, UniformFloatRNG& data_rng) {
        for (auto arg : args) {
            // Sample positions reach 2 pixels past either side of the map
            // extent, so the border handling is exercised as well.
            UniformFloatRNG map_rng(
                    -2,
                    std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2);
            checker.set_dtype(0, data_type);
            checker.set_dtype(1, dtype::Float32());
            checker.set_dtype(2, data_type);
            checker.set_rng(0, &data_rng);
            checker.set_rng(1, &map_rng);
            checker.set_rng(2, &data_rng);
            checker.set_param(arg.param);
            checker.execs({arg.src, arg.map_xy, arg.dst});
        }
    };

    check_all(dtype::Float32(), float_rng);
    check_all(dtype::Float16(), float_rng);
}
// Checks CUDA remap on the NHWC test cases for Int8 and Uint8 data, with the
// checker epsilon set to 1.
TEST_F(CUDA, REMAP_NHWC_INT) {
    Checker<Remap> checker(handle_cuda());
    std::vector<TestArg> args = get_nhwc_args();
    UniformIntRNG uint8_rng(0, 255);
    UniformIntRNG int8_rng(-128, 127);

    // Runs every test case with the given data type and value generator.
    auto check_all = [&](DType data_type, UniformIntRNG& data_rng) {
        for (auto arg : args) {
            // Sample positions reach 2 pixels past either side of the map
            // extent, so the border handling is exercised as well.
            UniformFloatRNG map_rng(
                    -2,
                    std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2);
            checker.set_dtype(0, data_type);
            checker.set_dtype(1, dtype::Float32());
            checker.set_dtype(2, data_type);
            checker.set_rng(0, &data_rng);
            checker.set_rng(1, &map_rng);
            checker.set_rng(2, &data_rng);
            checker.set_epsilon(1);
            checker.set_param(arg.param);
            checker.execs({arg.src, arg.map_xy, arg.dst});
        }
    };

    check_all(dtype::Int8(), int8_rng);
    check_all(dtype::Uint8(), uint8_rng);
}
#if MEGDNN_WITH_BENCHMARK
// Times the CPU handle against the CUDA implementation for NHWC and NCHW
// inputs, CONSTANT and REPLICATE borders, and four data types.
TEST_F(CUDA, BENCHMARK_REMAP) {
    using Param = param::Remap;

    // Benchmarks one (shapes, param, dtype) configuration and prints timings.
    auto run = [&](const TensorShapeArray& shapes, Param param,
                   DType data_type) {
        auto handle_cpu = create_cpu_handle(2);
        Benchmarker<Remap> benchmarker_naive(handle_cpu.get());
        CUBenchmarker<Remap> benchmarker_cuda(handle_cuda());
        UniformIntRNG rng(0, 0xff);
        UniformFloatRNG map_rng(
                -2, std::max(shapes[1].shape[1], shapes[1].shape[2]) + 2);

        benchmarker_naive.set_rng(0, &rng);
        benchmarker_cuda.set_rng(0, &rng);
        benchmarker_naive.set_rng(1, &map_rng);
        benchmarker_cuda.set_rng(1, &map_rng);
        benchmarker_naive.set_rng(2, &rng);
        benchmarker_cuda.set_rng(2, &rng);

        benchmarker_naive.set_dtype(1, dtype::Float32());
        benchmarker_cuda.set_dtype(1, dtype::Float32());
        benchmarker_naive.set_dtype(0, data_type).set_dtype(2, data_type);
        benchmarker_cuda.set_dtype(0, data_type).set_dtype(2, data_type);

        size_t RUN = 10;
        auto t1 = benchmarker_naive.set_display(false)
                          .set_times(RUN)
                          .set_param(param)
                          .execs(shapes);
        auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs(
                shapes);

        const TensorShape dst_layout = shapes[2];
        float calc_amount = dst_layout.total_nr_elems();
        printf("naive={%.3fms, %.3fMflops}, "
               "cuda={%.3fms, %.3fMflops}\n",
               t1 / RUN, calc_amount / (t1 / RUN * 1000), t2,
               calc_amount / (t2 * 1000));
    };

    const DType dtypes[] = {dtype::Float32{}, dtype::Float16{},
                            dtype::Uint8{}, dtype::Int8{}};
    const Param::BorderMode border_modes[] = {Param::BorderMode::CONSTANT,
                                              Param::BorderMode::REPLICATE};
    const TensorShapeArray nhwc_shapes = {
            {4, 200, 300, 10}, {4, 200, 300, 2}, {4, 200, 300, 10}};
    const TensorShapeArray nchw_shapes = {
            {4, 10, 200, 300}, {4, 200, 300, 2}, {4, 10, 200, 300}};

    Param param;
    param.imode = Param::InterpolationMode::LINEAR;

    // For one format: every border mode, and within it every data type.
    auto run_format = [&](Param::Format format,
                          const TensorShapeArray& shapes) {
        param.format = format;
        for (const auto& border_mode : border_modes) {
            param.border_type = border_mode;
            for (const auto& data_type : dtypes) {
                run(shapes, param, data_type);
            }
        }
    };

    run_format(Param::Format::NHWC, nhwc_shapes);
    run_format(Param::Format::NCHW, nchw_shapes);
}
#endif
} // namespace remap | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |