GitOrigin-RevId: 31e7b72a78
tags/v1.9.0
@@ -18,21 +18,22 @@ namespace megdnn { | |||
void RemapBase::deduce_layout_fwd( | |||
const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { | |||
dst.dtype = src.dtype; | |||
dst.ndim = src.ndim; | |||
dst.shape[0] = src.shape[0]; | |||
size_t height_index, channel_index; | |||
size_t n = src.shape[0]; | |||
size_t c, oh, ow; | |||
oh = map_xy.shape[1]; | |||
ow = map_xy.shape[2]; | |||
if (param().format == param::Remap::Format::NHWC) { | |||
height_index = 1; | |||
channel_index = 3; | |||
c = src.shape[3]; | |||
dst = TensorLayout(TensorShape({n, oh, ow, c}), src.dtype); | |||
} else if (param().format == param::Remap::Format::NCHW) { | |||
c = src.shape[1]; | |||
dst = TensorLayout(TensorShape{n, c, oh, ow}, src.dtype, src.format); | |||
} else if (param().format == param::Remap::Format::NHWCD4) { | |||
c = src.shape[2]; | |||
dst = TensorLayout{{n, oh, c, ow, 4}, src.dtype, src.format}; | |||
} else { | |||
megdnn_assert(param().format == param::Remap::Format::NCHW); | |||
height_index = 2; | |||
channel_index = 1; | |||
megdnn_throw("unsupport format"); | |||
} | |||
dst.shape[height_index] = map_xy.shape[1]; | |||
dst.shape[height_index + 1] = map_xy.shape[2]; | |||
dst.shape[channel_index] = src.shape[channel_index]; | |||
} | |||
void RemapBase::check_layout_fwd( | |||
@@ -42,7 +43,7 @@ void RemapBase::check_layout_fwd( | |||
megdnn_layout_msg(dst); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4); | |||
megdnn_assert(src.ndim == dst.ndim); | |||
megdnn_assert(dst.dtype == src.dtype); | |||
megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | |||
megdnn_assert(map_xy.shape[3] == 2); | |||
@@ -64,10 +65,13 @@ void RemapBase::check_layout_fwd( | |||
megdnn_assert( | |||
dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2], | |||
"%s", errmsg().c_str()); | |||
} else if (param().format == param::Remap::Format::NHWCD4) { | |||
megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); | |||
megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); | |||
megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); | |||
megdnn_assert(param().format == Param::Format::NHWCD4); | |||
} else { | |||
megdnn_throw( | |||
"currently do not support other param.format except NHWC and " | |||
"NCHW"); | |||
megdnn_throw("unsupport format"); | |||
} | |||
} | |||
@@ -22,8 +22,9 @@ void RemapBackwardDataImpl::exec( | |||
_megdnn_workspace workspace) { | |||
check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | |||
megdnn_assert( | |||
param().imode == param::Remap::InterpolationMode::LINEAR, | |||
"only support LINEAR interpolationMode"); | |||
(param().imode == param::Remap::InterpolationMode::NEAREST) || | |||
(param().imode == param::Remap::InterpolationMode::LINEAR), | |||
"only support NEAREST and LINEAR interpolationMode"); | |||
megdnn_assert( | |||
param().format == param::Remap::Format::NCHW, | |||
"only support NCHW format for remap backward"); | |||
@@ -36,13 +37,15 @@ void RemapBackwardDataImpl::exec( | |||
OH = map_xy.layout.shape[1]; | |||
OW = map_xy.layout.shape[2]; | |||
#define cb(dt, _format, bmode) \ | |||
#define cb(dt, _format, bmode, inter_mode) \ | |||
if (param().format == param::Remap::Format::_format && \ | |||
param().border_type == param::Remap::BorderMode::bmode) { \ | |||
param().border_type == param::Remap::BorderMode::bmode && \ | |||
param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
remap::backwarddata_proxy< \ | |||
ctype, param_enumv::Remap::Format::_format, \ | |||
::BorderMode::BORDER_##bmode>( \ | |||
::BorderMode::BORDER_##bmode, \ | |||
::InterpolationMode::INTER_##inter_mode>( \ | |||
grad.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | |||
diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \ | |||
break; \ | |||
@@ -50,11 +53,16 @@ void RemapBackwardDataImpl::exec( | |||
#define support_dtype(dt) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
cb(dt, NCHW, CONSTANT); \ | |||
cb(dt, NCHW, REPLICATE); \ | |||
cb(dt, NCHW, REFLECT); \ | |||
cb(dt, NCHW, REFLECT_101); \ | |||
cb(dt, NCHW, WRAP); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
cb(dt, NCHW, CONSTANT, LINEAR); \ | |||
cb(dt, NCHW, REPLICATE, LINEAR); \ | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
megdnn_throw("unsupported border type in remap cuda"); \ | |||
} | |||
@@ -52,8 +52,49 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
} | |||
}; | |||
__device__ inline float round_half_to_even(float f) { | |||
const float round_away_from_zero = round(f); | |||
const float diff = round_away_from_zero - f; | |||
if ((diff != 0.5f) && (diff != -0.5f)) { | |||
return round_away_from_zero; | |||
} | |||
if (fmod(round_away_from_zero, 2.0f) == 0.0f) { | |||
return round_away_from_zero; | |||
} | |||
return f - diff; | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
__global__ void kern_general_nearest( | |||
ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, | |||
int IW, int OH, int OW) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
grad += blockIdx.z * C * IH * IW; | |||
diff += blockIdx.z * C * OH * OW; | |||
map_xy += blockIdx.z * 2 * OH * OW; | |||
if (ow < OW && oh < OH) { | |||
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
int col = static_cast<int>(round_half_to_even(index_col)); | |||
int row = static_cast<int>(round_half_to_even(index_row)); | |||
for (int c = 0; c < C; ++c) { | |||
ctype hidden = diff[get_offset<format>(oh, ow, c, OH, OW, C)]; | |||
int idx = | |||
GetSrcData<ctype, format, bmode>::get_index(row, col, c, IH, IW, C); | |||
if (idx != -1) { | |||
atomic_add(grad + idx, hidden); | |||
} | |||
} | |||
} | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
__global__ void kern_general( | |||
__global__ void kern_general_linear( | |||
ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, | |||
int IW, int OH, int OW) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
@@ -93,8 +134,8 @@ __global__ void kern_general( | |||
atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); | |||
} | |||
int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>:: | |||
get_index(row + 1, col + 1, c, IH, IW, C); | |||
int a11 = GetSrcData<ctype, format, bmode>::get_index( | |||
row + 1, col + 1, c, IH, IW, C); | |||
if (a11 != -1) { | |||
atomic_add(grad + a11, round_converter(u * v * hidden)); | |||
} | |||
@@ -102,7 +143,9 @@ __global__ void kern_general( | |||
} | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void dispatch_backwarddata( | |||
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | |||
int IW, int OH, int OW, cudaStream_t stream) { | |||
@@ -115,8 +158,13 @@ void dispatch_backwarddata( | |||
cuda_check(cudaMemsetAsync( | |||
grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream)); | |||
kern_general<ctype, format, bmode> | |||
<<<blocks, threads, 0, stream>>>(grad, map_xy, diff, C, IH, IW, OH, OW); | |||
if (imode == ::InterpolationMode::INTER_NEAREST) { | |||
kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
grad, map_xy, diff, C, IH, IW, OH, OW); | |||
} else if (imode == ::InterpolationMode::INTER_LINEAR) { | |||
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
grad, map_xy, diff, C, IH, IW, OH, OW); | |||
} | |||
N -= curr_batch_size; | |||
grad += curr_batch_size * C * IH * IW; | |||
@@ -131,27 +179,35 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace remap { | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void backwarddata_proxy( | |||
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | |||
int IW, int OH, int OW, cudaStream_t stream) { | |||
dispatch_backwarddata<ctype, format, bmode>( | |||
dispatch_backwarddata<ctype, format, bmode, imode>( | |||
grad, map_xy, diff, N, C, IH, IW, OH, OW, stream); | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype, format, bmode) \ | |||
#define INST(ctype, format, bmode, imode) \ | |||
template void backwarddata_proxy< \ | |||
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||
::InterpolationMode::imode>( \ | |||
ctype*, const float*, const ctype*, int, int, int, int, int, int, \ | |||
cudaStream_t); | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
INST(ctype, NCHW, BORDER_REFLECT) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
INST(ctype, NCHW, BORDER_WRAP) | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) | |||
FOR_FORMAT_BMODE(float) | |||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||
@@ -22,8 +22,9 @@ void RemapBackwardMatImpl::exec( | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size); | |||
megdnn_assert( | |||
param().imode == param::Remap::InterpolationMode::LINEAR, | |||
"only support LINEAR interpolationMode"); | |||
(param().imode == param::Remap::InterpolationMode::NEAREST) || | |||
(param().imode == param::Remap::InterpolationMode::LINEAR), | |||
"only support NEAREST and LINEAR interpolationMode"); | |||
megdnn_assert( | |||
param().format == param::Remap::Format::NCHW, | |||
"only support NCHW format for remap backward"); | |||
@@ -36,13 +37,15 @@ void RemapBackwardMatImpl::exec( | |||
OH = map_xy.layout.shape[1]; | |||
OW = map_xy.layout.shape[2]; | |||
#define cb(dt, _format, bmode) \ | |||
#define cb(dt, _format, bmode, inter_mode) \ | |||
if (param().format == param::Remap::Format::_format && \ | |||
param().border_type == param::Remap::BorderMode::bmode) { \ | |||
param().border_type == param::Remap::BorderMode::bmode && \ | |||
param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
remap::backwardmat_proxy< \ | |||
ctype, param_enumv::Remap::Format::_format, \ | |||
::BorderMode::BORDER_##bmode>( \ | |||
::BorderMode::BORDER_##bmode, \ | |||
::InterpolationMode::INTER_##inter_mode>( \ | |||
src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | |||
diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \ | |||
IH, IW, OH, OW, param().scalar, stream); \ | |||
@@ -51,11 +54,16 @@ void RemapBackwardMatImpl::exec( | |||
#define support_dtype(dt) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
cb(dt, NCHW, CONSTANT); \ | |||
cb(dt, NCHW, REPLICATE); \ | |||
cb(dt, NCHW, REFLECT); \ | |||
cb(dt, NCHW, REFLECT_101); \ | |||
cb(dt, NCHW, WRAP); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
cb(dt, NCHW, CONSTANT, LINEAR); \ | |||
cb(dt, NCHW, REPLICATE, LINEAR); \ | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
megdnn_throw("unsupported border type in remap cuda"); \ | |||
} | |||
@@ -53,7 +53,7 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
}; | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
__global__ void kern_general( | |||
__global__ void kern_general_linear( | |||
const ctype* src, const float* map_xy, const ctype* diff, | |||
float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
@@ -62,7 +62,6 @@ __global__ void kern_general( | |||
diff += blockIdx.z * C * OH * OW; | |||
map_xy += blockIdx.z * 2 * OH * OW; | |||
grad += blockIdx.z * 2 * OH * OW; | |||
RoundingConverter<ctype> round_converter; | |||
if (ow < OW && oh < OH) { | |||
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
@@ -86,23 +85,25 @@ __global__ void kern_general( | |||
int a11 = GetSrcData<ctype, format, bmode>::get_index( | |||
row + 1, col + 1, c, IH, IW, C); | |||
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||
dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||
dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||
dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - u); | |||
dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * (one - u); | |||
dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u; | |||
dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u; | |||
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||
du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||
du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||
du += ((a11 != -1) ? src[a11] : scalar) * v; | |||
du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - v); | |||
du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v; | |||
du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * (one - v); | |||
du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v; | |||
grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv); | |||
grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du); | |||
grad[oh * OW * 2 + ow * 2 + 0] += hidden * dv; | |||
grad[oh * OW * 2 + ow * 2 + 1] += hidden * du; | |||
} | |||
} | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void dispatch_backwardmat( | |||
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | |||
@@ -115,8 +116,11 @@ void dispatch_backwardmat( | |||
cuda_check(cudaMemsetAsync( | |||
grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream)); | |||
kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||
if (imode == ::InterpolationMode::INTER_LINEAR) { | |||
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||
} | |||
N -= curr_batch_size; | |||
src += curr_batch_size * C * IH * IW; | |||
@@ -132,27 +136,35 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace remap { | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void backwardmat_proxy( | |||
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | |||
dispatch_backwardmat<ctype, format, bmode>( | |||
dispatch_backwardmat<ctype, format, bmode, imode>( | |||
src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream); | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype, format, bmode) \ | |||
template void \ | |||
backwardmat_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||
#define INST(ctype, format, bmode, imode) \ | |||
template void backwardmat_proxy< \ | |||
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||
::InterpolationMode::imode>( \ | |||
const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \ | |||
int, float, cudaStream_t); | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
INST(ctype, NCHW, BORDER_REFLECT) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
INST(ctype, NCHW, BORDER_WRAP) | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) | |||
FOR_FORMAT_BMODE(float) | |||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||
@@ -21,17 +21,23 @@ namespace remap { | |||
// all these kernels use LINEAR interpolation | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void forward_proxy( | |||
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||
int OH, int OW, float scalar, cudaStream_t stream); | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void backwarddata_proxy( | |||
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | |||
int IW, int OH, int OW, cudaStream_t stream); | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void backwardmat_proxy( | |||
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream); | |||
@@ -30,8 +30,9 @@ void RemapImpl::exec( | |||
OW = map_xy.layout.shape[2]; | |||
megdnn_assert( | |||
param().imode == param::Remap::InterpolationMode::LINEAR, | |||
"only support LINEAR interpolationMode"); | |||
(param().imode == param::Remap::InterpolationMode::NEAREST) || | |||
(param().imode == param::Remap::InterpolationMode::LINEAR), | |||
"only support NEAREST and LINEAR interpolationMode"); | |||
if (param().format == param::Remap::Format::NCHW) { | |||
N = src.layout.shape[0]; | |||
@@ -47,13 +48,15 @@ void RemapImpl::exec( | |||
megdnn_throw("unsupported format, cuda remap"); | |||
} | |||
#define cb(dt, _format, bmode) \ | |||
#define cb(dt, _format, bmode, inter_mode) \ | |||
if (param().format == param::Remap::Format::_format && \ | |||
param().border_type == param::Remap::BorderMode::bmode) { \ | |||
param().border_type == param::Remap::BorderMode::bmode && \ | |||
param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||
using ctype = DTypeTrait<dt>::ctype; \ | |||
remap::forward_proxy< \ | |||
ctype, param_enumv::Remap::Format::_format, \ | |||
::BorderMode::BORDER_##bmode>( \ | |||
::BorderMode::BORDER_##bmode, \ | |||
::InterpolationMode::INTER_##inter_mode>( \ | |||
src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | |||
dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, param().scalar, \ | |||
stream); \ | |||
@@ -62,16 +65,26 @@ void RemapImpl::exec( | |||
#define support_dtype(dt) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
cb(dt, NCHW, CONSTANT); \ | |||
cb(dt, NCHW, REPLICATE); \ | |||
cb(dt, NCHW, REFLECT); \ | |||
cb(dt, NCHW, REFLECT_101); \ | |||
cb(dt, NCHW, WRAP); \ | |||
cb(dt, NHWC, CONSTANT); \ | |||
cb(dt, NHWC, REPLICATE); \ | |||
cb(dt, NHWC, REFLECT); \ | |||
cb(dt, NHWC, REFLECT_101); \ | |||
cb(dt, NHWC, WRAP); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
cb(dt, NHWC, CONSTANT, NEAREST); \ | |||
cb(dt, NHWC, REPLICATE, NEAREST); \ | |||
cb(dt, NHWC, REFLECT, NEAREST); \ | |||
cb(dt, NHWC, REFLECT_101, NEAREST); \ | |||
cb(dt, NHWC, WRAP, NEAREST); \ | |||
cb(dt, NCHW, CONSTANT, LINEAR); \ | |||
cb(dt, NCHW, REPLICATE, LINEAR); \ | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
cb(dt, NHWC, CONSTANT, LINEAR); \ | |||
cb(dt, NHWC, REPLICATE, LINEAR); \ | |||
cb(dt, NHWC, REFLECT, LINEAR); \ | |||
cb(dt, NHWC, REFLECT_101, LINEAR); \ | |||
cb(dt, NHWC, WRAP, LINEAR); \ | |||
megdnn_throw("unsupported border type in remap cuda"); \ | |||
} | |||
@@ -62,8 +62,23 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
} | |||
}; | |||
template <typename ctype, ::BorderMode bmode> | |||
__global__ void kern_general( | |||
__device__ inline float round_half_to_even(float f) { | |||
const float round_away_from_zero = round(f); | |||
const float diff = round_away_from_zero - f; | |||
if ((diff != 0.5f) && (diff != -0.5f)) { | |||
return round_away_from_zero; | |||
} | |||
if (fmod(round_away_from_zero, 2.0f) == 0.0f) { | |||
return round_away_from_zero; | |||
} | |||
return f - diff; | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
__global__ void kern_general_nearest( | |||
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | |||
int IH, int IW, int OH, int OW, float scalar) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
@@ -71,37 +86,22 @@ __global__ void kern_general( | |||
sptr += blockIdx.z * C * IH * IW; | |||
dst += blockIdx.z * C * OH * OW; | |||
map_xy += blockIdx.z * 2 * OH * OW; | |||
RoundingConverter<ctype> round_converter; | |||
if (ow < OW && oh < OH) { | |||
float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
int col = static_cast<int>(floor(index_col)); | |||
int row = static_cast<int>(floor(index_row)); | |||
float v = index_col - col; | |||
float u = index_row - row; | |||
int col = static_cast<int>(round_half_to_even(index_col)); | |||
int row = static_cast<int>(round_half_to_even(index_row)); | |||
for (int c = 0; c < C; ++c) { | |||
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||
sptr, row + 0, col + 0, c, IH, IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||
sptr, row + 0, col + 1, c, IH, IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||
sptr, row + 1, col + 0, c, IH, IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||
sptr, row + 1, col + 1, c, IH, IW, C, scalar); | |||
/* in remap, we use float as the type of intermediate result */ | |||
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||
static_cast<float>(a01) * (1.f - u) * v + | |||
static_cast<float>(a10) * (1.f - v) * u + | |||
static_cast<float>(a11) * u * v; | |||
dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, C)] = | |||
round_converter(result); | |||
dst[get_offset<format>(oh, ow, c, OH, OW, C)] = | |||
GetSrcData<ctype, format, bmode>::get( | |||
sptr, row, col, c, IH, IW, C, scalar); | |||
} | |||
} | |||
} | |||
template <typename ctype, ::BorderMode bmode> | |||
__global__ void kern_general_nhwc( | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
__global__ void kern_general_linear( | |||
const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | |||
int IH, int IW, int OH, int OW, float scalar) { | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
@@ -119,26 +119,27 @@ __global__ void kern_general_nhwc( | |||
float v = index_col - col; | |||
float u = index_row - row; | |||
for (int c = 0; c < C; ++c) { | |||
ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||
ctype a00 = GetSrcData<ctype, format, bmode>::get( | |||
sptr, row + 0, col + 0, c, IH, IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||
ctype a01 = GetSrcData<ctype, format, bmode>::get( | |||
sptr, row + 0, col + 1, c, IH, IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||
ctype a10 = GetSrcData<ctype, format, bmode>::get( | |||
sptr, row + 1, col + 0, c, IH, IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||
ctype a11 = GetSrcData<ctype, format, bmode>::get( | |||
sptr, row + 1, col + 1, c, IH, IW, C, scalar); | |||
/* in remap, we use float as the type of intermediate result */ | |||
float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||
static_cast<float>(a01) * (1.f - u) * v + | |||
static_cast<float>(a10) * (1.f - v) * u + | |||
static_cast<float>(a11) * u * v; | |||
dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, C)] = | |||
round_converter(result); | |||
dst[get_offset<format>(oh, ow, c, OH, OW, C)] = round_converter(result); | |||
} | |||
} | |||
} | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void dispatch_forward( | |||
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||
int OH, int OW, float scalar, cudaStream_t stream) { | |||
@@ -150,11 +151,11 @@ void dispatch_forward( | |||
dim3 threads(BX, BY); | |||
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||
if (format == param_enumv::Remap::Format::NCHW) { | |||
kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||
if (imode == ::InterpolationMode::INTER_NEAREST) { | |||
kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||
} else if (format == param_enumv::Remap::Format::NHWC) { | |||
kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||
} else if (imode == ::InterpolationMode::INTER_LINEAR) { | |||
kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||
} | |||
@@ -171,32 +172,45 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace remap { | |||
template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
template < | |||
typename ctype, const uint32_t format, ::BorderMode bmode, | |||
::InterpolationMode imode> | |||
void forward_proxy( | |||
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||
int OH, int OW, float scalar, cudaStream_t stream) { | |||
dispatch_forward<ctype, format, bmode>( | |||
dispatch_forward<ctype, format, bmode, imode>( | |||
src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream); | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype, format, bmode) \ | |||
template void \ | |||
forward_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||
#define INST(ctype, format, bmode, imode) \ | |||
template void forward_proxy< \ | |||
ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||
::InterpolationMode::imode>( \ | |||
const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \ | |||
cudaStream_t); | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
INST(ctype, NCHW, BORDER_REFLECT) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
INST(ctype, NCHW, BORDER_WRAP) \ | |||
INST(ctype, NHWC, BORDER_CONSTANT) \ | |||
INST(ctype, NHWC, BORDER_REPLICATE) \ | |||
INST(ctype, NHWC, BORDER_REFLECT) \ | |||
INST(ctype, NHWC, BORDER_REFLECT_101) \ | |||
INST(ctype, NHWC, BORDER_WRAP) | |||
#define FOR_FORMAT_BMODE(ctype) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||
INST(ctype, NHWC, BORDER_CONSTANT, INTER_NEAREST) \ | |||
INST(ctype, NHWC, BORDER_REPLICATE, INTER_NEAREST) \ | |||
INST(ctype, NHWC, BORDER_REFLECT, INTER_NEAREST) \ | |||
INST(ctype, NHWC, BORDER_REFLECT_101, INTER_NEAREST) \ | |||
INST(ctype, NHWC, BORDER_WRAP, INTER_NEAREST) \ | |||
INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||
INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) \ | |||
INST(ctype, NHWC, BORDER_CONSTANT, INTER_LINEAR) \ | |||
INST(ctype, NHWC, BORDER_REPLICATE, INTER_LINEAR) \ | |||
INST(ctype, NHWC, BORDER_REFLECT, INTER_LINEAR) \ | |||
INST(ctype, NHWC, BORDER_REFLECT_101, INTER_LINEAR) \ | |||
INST(ctype, NHWC, BORDER_WRAP, INTER_LINEAR) | |||
FOR_FORMAT_BMODE(float) | |||
DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | |||
@@ -36,6 +36,12 @@ inline int get_offset<param::Remap::Format::NHWC>( | |||
return height * w * c + width * c + channel; | |||
} | |||
template <> | |||
inline int get_offset<param::Remap::Format::NHWCD4>( | |||
int height, int width, int channel, int, int w, int c) { | |||
return ((height * c + channel) * w + width) * 4; | |||
} | |||
template < | |||
typename ctype, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
@@ -80,8 +86,12 @@ void remap_LINEAR( | |||
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||
int OH, int OW, float scalar) { | |||
RoundingConverter<ctype> round_converter; | |||
for (int n = 0; n < N; | |||
++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { | |||
size_t c_scale = 1; | |||
if (format == param::Remap::Format::NHWCD4) { | |||
c_scale = 4; | |||
} | |||
for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, | |||
dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { | |||
for (int h = 0; h < OH; ++h) { | |||
for (int w = 0; w < OW; ++w) { | |||
float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||
@@ -92,18 +102,102 @@ void remap_LINEAR( | |||
float u = index_row - row; // alphah | |||
const float one = 1.f; | |||
for (int c = 0; c < C; ++c) { | |||
ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 0, col + 0, c, IH, IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 0, col + 1, c, IH, IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 1, col + 0, c, IH, IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 1, col + 1, c, IH, IW, C, scalar); | |||
dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter( | |||
a00 * (one - v) * (one - u) + a01 * (one - u) * v + | |||
a10 * (one - v) * u + a11 * u * v); | |||
if (format == param::Remap::Format::NHWCD4) { | |||
int idx00 = GetSrcData<ctype, format, bordertype>::get_index( | |||
row + 0, col + 0, c, IH, IW, C); | |||
int idx01 = GetSrcData<ctype, format, bordertype>::get_index( | |||
row + 0, col + 1, c, IH, IW, C); | |||
int idx10 = GetSrcData<ctype, format, bordertype>::get_index( | |||
row + 1, col + 0, c, IH, IW, C); | |||
int idx11 = GetSrcData<ctype, format, bordertype>::get_index( | |||
row + 1, col + 1, c, IH, IW, C); | |||
for (int c_inner = 0; c_inner < 4; ++c_inner) { | |||
ctype a00 = (idx00 != -1) ? src[idx00 + c_inner] | |||
: round_converter(scalar); | |||
ctype a01 = (idx01 != -1) ? src[idx01 + c_inner] | |||
: round_converter(scalar); | |||
ctype a10 = (idx10 != -1) ? src[idx10 + c_inner] | |||
: round_converter(scalar); | |||
ctype a11 = (idx11 != -1) ? src[idx11 + c_inner] | |||
: round_converter(scalar); | |||
dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] = | |||
round_converter( | |||
a00 * (one - v) * (one - u) + | |||
a01 * (one - u) * v + a10 * (one - v) * u + | |||
a11 * u * v); | |||
} | |||
} else { | |||
ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 0, col + 0, c, IH, IW, C, scalar); | |||
ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 0, col + 1, c, IH, IW, C, scalar); | |||
ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 1, col + 0, c, IH, IW, C, scalar); | |||
ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||
src, row + 1, col + 1, c, IH, IW, C, scalar); | |||
dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter( | |||
a00 * (one - v) * (one - u) + a01 * (one - u) * v + | |||
a10 * (one - v) * u + a11 * u * v); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
namespace { | |||
inline float round_half_to_even(float f) { | |||
const float round_away_from_zero = std::round(f); | |||
const float diff = round_away_from_zero - f; | |||
if ((diff != 0.5f) && (diff != -0.5f)) { | |||
return round_away_from_zero; | |||
} | |||
if (std::fmod(round_away_from_zero, 2.0f) == 0.0f) { | |||
return round_away_from_zero; | |||
} | |||
return f - diff; | |||
} | |||
} // anonymous namespace | |||
template < | |||
typename ctype, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
void remap_NEAREST( | |||
const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||
int OH, int OW, float scalar) { | |||
RoundingConverter<ctype> round_converter; | |||
size_t c_scale = 1; | |||
if (format == param::Remap::Format::NHWCD4) { | |||
c_scale = 4; | |||
} | |||
for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, | |||
dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { | |||
for (int h = 0; h < OH; ++h) { | |||
for (int w = 0; w < OW; ++w) { | |||
float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||
float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||
int col = static_cast<int>(round_half_to_even(index_col)); | |||
int row = static_cast<int>(round_half_to_even(index_row)); | |||
for (int c = 0; c < C; ++c) { | |||
if (format == param::Remap::Format::NHWCD4) { | |||
int idx = GetSrcData<ctype, format, bordertype>::get_index( | |||
row, col, c, IH, IW, C); | |||
for (int c_inner = 0; c_inner < 4; ++c_inner) { | |||
dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] = | |||
(idx != -1) ? (src[idx + c_inner]) | |||
: round_converter(scalar); | |||
} | |||
} else { | |||
dst[get_offset<format>(h, w, c, OH, OW, C)] = | |||
GetSrcData<ctype, format, bordertype>::get( | |||
src, row, col, c, IH, IW, C, scalar); | |||
} | |||
} | |||
} | |||
} | |||
@@ -164,10 +258,37 @@ void remap_LINEAR_backwarddata( | |||
template < | |||
typename ctype, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
void remap_NEAREST_backwarddata( | |||
ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | |||
int IW, int OH, int OW) { | |||
std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW); | |||
for (int n = 0; n < N; | |||
++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) { | |||
for (int h = 0; h < OH; ++h) { | |||
for (int w = 0; w < OW; ++w) { | |||
float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||
float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||
int col = static_cast<int>(round_half_to_even(index_col)); | |||
int row = static_cast<int>(round_half_to_even(index_row)); | |||
for (int c = 0; c < C; ++c) { | |||
ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)]; | |||
int idx = GetSrcData<ctype, format, bordertype>::get_index( | |||
row, col, c, IH, IW, C); | |||
if (idx != -1) { | |||
grad[idx] += hidden; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template < | |||
typename ctype, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
void remap_LINEAR_backwardmat( | |||
const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, float scalar) { | |||
RoundingConverter<ctype> round_converter; | |||
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | |||
for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, | |||
map_xy += OH * OW * 2, grad += OH * OW * 2) { | |||
@@ -194,24 +315,38 @@ void remap_LINEAR_backwardmat( | |||
int a11 = GetSrcData<ctype, format, bordertype>::get_index( | |||
row + 1, col + 1, c, IH, IW, C); | |||
dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||
dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||
dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||
dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||
du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||
du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||
du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||
du += ((a11 != -1) ? src[a11] : scalar) * v; | |||
grad[h * OW * 2 + w * 2 + 0] += round_converter(hidden * dv); | |||
grad[h * OW * 2 + w * 2 + 1] += round_converter(hidden * du); | |||
dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * | |||
(one - u); | |||
dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * | |||
(one - u); | |||
dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u; | |||
dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u; | |||
du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * | |||
(one - v); | |||
du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v; | |||
du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * | |||
(one - v); | |||
du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v; | |||
grad[h * OW * 2 + w * 2 + 0] += hidden * dv; | |||
grad[h * OW * 2 + w * 2 + 1] += hidden * du; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template < | |||
typename ctype, param::Remap::Format format, | |||
param::Remap::BorderMode bordertype> | |||
void remap_NEAREST_backwardmat( | |||
const ctype*, const float*, const ctype*, float* grad, int N, int, int, int, | |||
int OH, int OW, float) { | |||
std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | |||
return; | |||
} | |||
} // namespace | |||
void RemapImpl::exec( | |||
@@ -229,6 +364,11 @@ void RemapImpl::exec( | |||
C = src.layout.shape[3]; | |||
IH = src.layout.shape[1]; | |||
IW = src.layout.shape[2]; | |||
} else if (param().format == param::Remap::Format::NHWCD4) { | |||
N = src.layout.shape[0]; | |||
C = src.layout.shape[2]; | |||
IH = src.layout.shape[1]; | |||
IW = src.layout.shape[3]; | |||
} else { | |||
megdnn_throw("unsupported format"); | |||
} | |||
@@ -255,11 +395,31 @@ void RemapImpl::exec( | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
cb(dt, NHWCD4, CONSTANT, LINEAR); \ | |||
cb(dt, NHWCD4, REPLICATE, LINEAR); \ | |||
cb(dt, NHWCD4, REFLECT, LINEAR); \ | |||
cb(dt, NHWCD4, REFLECT_101, LINEAR); \ | |||
cb(dt, NHWCD4, WRAP, LINEAR); \ | |||
cb(dt, NHWC, CONSTANT, LINEAR); \ | |||
cb(dt, NHWC, REPLICATE, LINEAR); \ | |||
cb(dt, NHWC, REFLECT, LINEAR); \ | |||
cb(dt, NHWC, REFLECT_101, LINEAR); \ | |||
cb(dt, NHWC, WRAP, LINEAR); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
cb(dt, NHWCD4, CONSTANT, NEAREST); \ | |||
cb(dt, NHWCD4, REPLICATE, NEAREST); \ | |||
cb(dt, NHWCD4, REFLECT, NEAREST); \ | |||
cb(dt, NHWCD4, REFLECT_101, NEAREST); \ | |||
cb(dt, NHWCD4, WRAP, NEAREST); \ | |||
cb(dt, NHWC, CONSTANT, NEAREST); \ | |||
cb(dt, NHWC, REPLICATE, NEAREST); \ | |||
cb(dt, NHWC, REFLECT, NEAREST); \ | |||
cb(dt, NHWC, REFLECT_101, NEAREST); \ | |||
cb(dt, NHWC, WRAP, NEAREST); \ | |||
megdnn_throw( \ | |||
"format, border type or imode is incorrect in remap navie " \ | |||
"with dtype = " #dt); \ | |||
@@ -313,6 +473,11 @@ void RemapBackwardDataImpl::exec( | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
megdnn_throw( \ | |||
"format, border type or imode is incorrect in remap navie " \ | |||
"with dtype = " #dt); \ | |||
@@ -365,6 +530,11 @@ void RemapBackwardMatImpl::exec( | |||
cb(dt, NCHW, REFLECT, LINEAR); \ | |||
cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
cb(dt, NCHW, WRAP, LINEAR); \ | |||
cb(dt, NCHW, CONSTANT, NEAREST); \ | |||
cb(dt, NCHW, REPLICATE, NEAREST); \ | |||
cb(dt, NCHW, REFLECT, NEAREST); \ | |||
cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||
cb(dt, NCHW, WRAP, NEAREST); \ | |||
megdnn_throw( \ | |||
"format, border type or imode is incorrect in remap navie " \ | |||
"with dtype = " #dt); \ | |||
@@ -34,53 +34,91 @@ static inline std::vector<TestArg> get_nchw_args() { | |||
param::Remap param; | |||
std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW}; | |||
std::vector<param::Remap::InterpolationMode> interp_mode_vec = { | |||
param::Remap::InterpolationMode::NEAREST, | |||
param::Remap::InterpolationMode::LINEAR}; | |||
std::vector<param::Remap::BorderMode> border_mode_vec = { | |||
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | |||
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | |||
param::Remap::BorderMode::REPLICATE}; | |||
// current do not test this. | |||
std::vector<float> scalar; | |||
for (auto fmt : format_vec) { | |||
for (auto border_type : border_mode_vec) { | |||
param.format = fmt; | |||
param.border_type = border_type; | |||
args.emplace_back( | |||
param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, | |||
TensorShape{70000, 1, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 1, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 3, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 10, 100, 100}, TensorShape{1, 100, 100, 2}, | |||
TensorShape{1, 10, 100, 100}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, | |||
TensorShape{2, 4, 100, 200}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 4, 20, 30}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 4, 20, 30}); | |||
for (auto interp_mode : interp_mode_vec) { | |||
for (auto border_type : border_mode_vec) { | |||
param.format = fmt; | |||
param.imode = interp_mode; | |||
param.border_type = border_type; | |||
args.emplace_back( | |||
param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, | |||
TensorShape{70000, 1, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 1, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 3, 2, 2}); | |||
args.emplace_back( | |||
param, TensorShape{1, 10, 100, 100}, | |||
TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, | |||
TensorShape{2, 4, 100, 200}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 4, 20, 30}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 4, 20, 30}); | |||
} | |||
} | |||
} | |||
return args; | |||
} | |||
static inline std::vector<TestArg> get_nhwcd4_args() { | |||
std::vector<TestArg> args; | |||
param::Remap param; | |||
param.format = param::Remap::Format::NHWCD4; | |||
param.imode = param::Remap::InterpolationMode::LINEAR; | |||
param.border_type = param::Remap::BorderMode::CONSTANT; | |||
// FIXME: when fractional part of bval is not zero, naive and opencl bankend may | |||
// have different rounding result | |||
param.scalar = 77; | |||
args.emplace_back( | |||
param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, | |||
TensorShape{2, 4, 1, 6, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, | |||
TensorShape{2, 2, 1, 3, 4}); | |||
param.imode = param::Remap::InterpolationMode::NEAREST; | |||
args.emplace_back( | |||
param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, | |||
TensorShape{2, 4, 1, 6, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, | |||
TensorShape{2, 2, 1, 3, 4}); | |||
return args; | |||
} | |||
static inline std::vector<TestArg> get_nhwc_args() { | |||
std::vector<TestArg> args; | |||
param::Remap param; | |||
std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC}; | |||
std::vector<param::Remap::InterpolationMode> interp_mode_vec = { | |||
param::Remap::InterpolationMode::NEAREST, | |||
param::Remap::InterpolationMode::LINEAR}; | |||
std::vector<param::Remap::BorderMode> border_mode_vec = { | |||
param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | |||
param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | |||
@@ -88,41 +126,44 @@ static inline std::vector<TestArg> get_nhwc_args() { | |||
// current do not test this. | |||
std::vector<float> scalar; | |||
for (auto fmt : format_vec) { | |||
for (auto border_type : border_mode_vec) { | |||
param.format = fmt; | |||
param.border_type = border_type; | |||
param.scalar = 12.f; | |||
args.emplace_back( | |||
param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, | |||
TensorShape{70000, 2, 2, 1}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 1}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 3}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 66}); | |||
args.emplace_back( | |||
param, TensorShape{1, 100, 100, 10}, TensorShape{1, 100, 100, 2}, | |||
TensorShape{1, 100, 100, 10}); | |||
args.emplace_back( | |||
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, | |||
TensorShape{2, 100, 200, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 20, 30, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 20, 30, 4}); | |||
for (auto interp_mode : interp_mode_vec) { | |||
for (auto border_type : border_mode_vec) { | |||
param.format = fmt; | |||
param.imode = interp_mode; | |||
param.border_type = border_type; | |||
param.scalar = 12.f; | |||
args.emplace_back( | |||
param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, | |||
TensorShape{70000, 2, 2, 1}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 1}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 3}); | |||
args.emplace_back( | |||
param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, | |||
TensorShape{1, 2, 2, 66}); | |||
args.emplace_back( | |||
param, TensorShape{1, 100, 100, 10}, | |||
TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10}); | |||
args.emplace_back( | |||
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, | |||
TensorShape{2, 100, 200, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 20, 30, 4}); | |||
args.emplace_back( | |||
param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, | |||
TensorShape{2, 20, 30, 4}); | |||
} | |||
} | |||
} | |||
return args; | |||
@@ -58,6 +58,11 @@ static void set_nchw_args(std::vector<TestArg>& args) { | |||
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); | |||
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); | |||
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); | |||
param.imode = param::Resize::InterpolationMode::NEAREST; | |||
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); | |||
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); | |||
args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); | |||
} | |||
static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) { | |||
@@ -75,6 +80,25 @@ static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) { | |||
return args; | |||
} | |||
static inline std::vector<TestArg> get_nhwc_args() { | |||
std::vector<TestArg> args; | |||
param::Resize param; | |||
param.format = param::Resize::Format::NHWC; | |||
param.imode = param::Resize::InterpolationMode::LINEAR; | |||
args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); | |||
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); | |||
args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); | |||
param.imode = param::Resize::InterpolationMode::NEAREST; | |||
args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); | |||
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); | |||
args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); | |||
return args; | |||
} | |||
static inline std::vector<TestArg> get_nhwcd4_args() { | |||
std::vector<TestArg> args; | |||
@@ -83,6 +107,9 @@ static inline std::vector<TestArg> get_nhwcd4_args() { | |||
param.imode = param::Resize::InterpolationMode::LINEAR; | |||
args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); | |||
args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); | |||
param.imode = param::Resize::InterpolationMode::NEAREST; | |||
args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); | |||
args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); | |||
return args; | |||
} | |||
@@ -351,7 +351,7 @@ def remap( | |||
"reflect_101", "wrap". | |||
scalar: value used in case of a constant border. Default: 0 | |||
interp_mode: interpolation methods. | |||
Default: "linear". Currently only support "linear" mode. | |||
Default: "linear". Currently also support "nearest" mode. | |||
Returns: | |||
output tensor. | |||