GitOrigin-RevId: bb04d2a801
HuaHua404-patch-1
@@ -67,8 +67,12 @@ void ResizeBackward::check_exec(
     auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
     megdnn_assert(
-            param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(),
-            "Backward resize only supports Float32 and NCHW.");
+            (param().format == Param::Format::NCHW ||
+             param().format == Param::Format::NHWC) &&
+                    (grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
+                            || grad.dtype == dtype::Float16())),
+            "Backward resize only supports NCHW and NHWC, the dtype only supports "
+            "Float32 and Float16.");
 }
 
 std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
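Reviewer note: the relaxed precondition can be read as a plain predicate. The snippet below is only an illustration of the logic the new assertion encodes; `Format` and `DType` are hypothetical stand-ins for the megdnn enums, and in the real code the Float16 arm is compiled in only when `DNN_INC_FLOAT16` expands its argument (i.e. fp16 support is enabled in the build).

```cpp
// Illustrative sketch of the predicate behind the updated megdnn_assert.
enum class Format { NCHW, NHWC };
enum class DType { Float32, Float16 };

inline bool resize_backward_supported(Format fmt, DType dt) {
    bool format_ok = fmt == Format::NCHW || fmt == Format::NHWC;
    bool dtype_ok = dt == DType::Float32 || dt == DType::Float16;
    return format_ok && dtype_ok;
}
```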
@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec(
         _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
     check_exec(diff.layout, grad.layout, workspace.size);
     auto stream = cuda_stream(this->handle());
-    auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
-         IW = grad.layout.shape[3], OH = diff.layout.shape[2],
-         OW = diff.layout.shape[3];
+    bool is_nhwc = param().format == param::Resize::Format::NHWC;
+    size_t N, C, IH, IW, OH, OW;
+    if (is_nhwc) {
+        if (param().imode != Param::InterpolationMode::LINEAR &&
+            is_nhwc_contig_wc(grad.layout)) {
+            megdnn_assert(
+                    0,
+                    "unsupport mode in resizeBackward, only support param().imode = "
+                    "LINEAR");
+        }
+        N = grad.layout.shape[0];
+        C = grad.layout.shape[3];
+        IH = grad.layout.shape[1];
+        IW = grad.layout.shape[2];
+        OH = diff.layout.shape[1];
+        OW = diff.layout.shape[2];
+    } else {
+        N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
+        IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3];
+    }
     size_t max_batch_x_channel = max_batch_x_channel_size();
-    dt_float32* diff_ptr = diff.ptr<dt_float32>();
-    dt_float32* grad_ptr = grad.ptr<dt_float32>();
     size_t max_batch_size = max_batch_x_channel / C;
     while (N > 0) {
         size_t curr_batch_size = N > max_batch_size ? max_batch_size : N;
-        resize::backward_data_proxy(
-                resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size,
-                C, IH, IW, OH, OW, stream);
-        if (N <= max_batch_size) {
-            break;
-        } else {
-            N -= max_batch_size;
-            diff_ptr += curr_batch_size * diff.layout.stride[0];
-            grad_ptr += curr_batch_size * grad.layout.stride[0];
+        switch (grad.layout.dtype.enumv()) {
+#define cb(_t)                                                                  \
+    case DTypeTrait<_t>::enumv: {                                               \
+        typedef DTypeTrait<_t>::ctype ct;                                       \
+        ct* diff_ptr = diff.ptr<ct>();                                          \
+        ct* grad_ptr = grad.ptr<ct>();                                          \
+        resize::backward_data_proxy(                                            \
+                is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr,  \
+                curr_batch_size, C, IH, IW, OH, OW, stream);                    \
+        if (N <= max_batch_size) {                                              \
+            return;                                                             \
+        } else {                                                                \
+            N -= max_batch_size;                                                \
+            diff_ptr += curr_batch_size * diff.layout.stride[0];                \
+            grad_ptr += curr_batch_size * grad.layout.stride[0];                \
+        }                                                                       \
+        break;                                                                  \
+    }
+            cb(megdnn::dtype::Float32);
+            DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
+            default:
+                megdnn_throw(ssprintf(
+                        "unsupported dtype: %s in resize backward",
+                        grad.layout.dtype.name()));
         }
+#undef cb
     }
 }
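Reviewer note: the new exec body combines two patterns: the `cb(_t)` macro turns the runtime dtype enum into a compile-time `ctype` (one `case` per supported dtype), and the surrounding `while` splits `N` into chunks so that `curr_batch_size * C` never exceeds the grid-z limit that `max_batch_x_channel_size()` encodes. A minimal stand-alone sketch of the same two patterns, with made-up names, for readers unfamiliar with the idiom:

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical stand-ins: a runtime dtype tag and a typed worker playing the
// role of resize::backward_data_proxy.
enum class DTypeEnum { Float32, Float16 };

template <typename T>
void backward_chunk(const T* /*diff*/, T* /*grad*/, std::size_t /*n*/, std::size_t /*c*/) {
    // launch the typed kernel for one chunk of the batch
}

// Process N images in chunks so that chunk_size * C stays below max_nc.
template <typename T>
void run_in_chunks(const T* diff, T* grad, std::size_t N, std::size_t C,
                   std::size_t max_nc, std::size_t diff_batch_stride,
                   std::size_t grad_batch_stride) {
    std::size_t max_batch = max_nc / C;
    while (N > 0) {
        std::size_t cur = N > max_batch ? max_batch : N;
        backward_chunk(diff, grad, cur, C);
        if (N <= max_batch)
            return;
        N -= max_batch;
        diff += cur * diff_batch_stride;  // step to the next chunk of batches
        grad += cur * grad_batch_stride;
    }
}

// Runtime dtype -> compile-time type, the job the cb(...) macro does above.
void dispatch(DTypeEnum dt, const void* diff, void* grad, std::size_t N,
              std::size_t C, std::size_t max_nc, std::size_t ds, std::size_t gs) {
    switch (dt) {
        case DTypeEnum::Float32:
            run_in_chunks(static_cast<const float*>(diff),
                          static_cast<float*>(grad), N, C, max_nc, ds, gs);
            break;
        default:
            // a Float16 case would sit here, guarded the way DNN_INC_FLOAT16
            // guards it in the patch
            throw std::runtime_error("unsupported dtype in resize backward");
    }
}
```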
@@ -1,3 +1,4 @@
+#include "src/common/rounding_converter.cuh"
 #include "src/cuda/resize/common.cuh"
 #include "src/cuda/resize/common.h"
@@ -11,9 +12,52 @@ namespace megdnn {
 namespace cuda {
 namespace resize {
 
+template <typename ctype, typename OutputConverter>
+__global__ void resize_bwd_nhwc_kernel(
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
+        float scale_h, float scale_w) {
+    OutputConverter output_converter;
+    int n = blockIdx.z;
+    int ow = blockIdx.x * blockDim.x + threadIdx.x;
+    int oh = blockIdx.y * blockDim.y + threadIdx.y;
+    hidden += n * C * OH * OW;
+    dst += n * C * IH * IW;
+    if (ow < OW && oh < OH) {
+        float alphah, alphaw;
+        int ih0, iw0;
+        get_origin_coord(scale_h, IH, oh, alphah, ih0);
+        get_origin_coord(scale_w, IW, ow, alphaw, iw0);
+        int ih1 = ih0 + 1;
+        int iw1 = iw0 + 1;
+        float nalphaw = 1.0f - alphaw;
+        float nalphah = 1.0f - alphah;
+        for (int c = 0; c < C; ++c) {
+            atomic_add(
+                    dst + (ih0 * IW + iw0) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah));
+            atomic_add(
+                    dst + (ih0 * IW + iw1) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * alphaw * nalphah));
+            atomic_add(
+                    dst + (ih1 * IW + iw0) * C + c,
+                    output_converter(
+                            hidden[(oh * OW + ow) * C + c] * nalphaw * alphah));
+            atomic_add(
+                    dst + (ih1 * IW + iw1) * C + c,
+                    output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah));
+        }
+    }
+}
+
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_linear_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
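Reviewer note: `resize_bwd_nhwc_kernel` is the transpose of bilinear forward interpolation: each output-gradient pixel is scattered into its four neighbouring input positions with the bilinear weights, and because the layout is NHWC the channel is the innermost index, so element (n, c, ih, iw) lives at `(ih * IW + iw) * C + c`. A single-threaded CPU reference of the same accumulation is sketched below; `source_coord` is a common half-pixel mapping and only approximates megdnn's `get_origin_coord` (edge handling may differ, and IH, IW >= 2 is assumed), the indexing and the weights are the point being illustrated.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Approximate source-coordinate mapping: returns the lower neighbour i0 and
// the fractional weight alpha toward i0 + 1.
static void source_coord(float scale, int isize, int o, float& alpha, int& i0) {
    float f = (o + 0.5f) / scale - 0.5f;
    i0 = static_cast<int>(std::floor(f));
    alpha = f - i0;
    if (i0 < 0) { i0 = 0; alpha = 0.f; }
    if (i0 + 1 >= isize) { i0 = isize - 2; alpha = 1.f; }
}

// CPU reference for the NHWC bilinear backward scatter (diff: N*OH*OW*C,
// grad: N*IH*IW*C).
void resize_bwd_nhwc_cpu(const float* diff, float* grad, int N, int C,
                         int IH, int IW, int OH, int OW) {
    std::fill(grad, grad + static_cast<std::size_t>(N) * IH * IW * C, 0.f);
    float sh = static_cast<float>(OH) / IH, sw = static_cast<float>(OW) / IW;
    for (int n = 0; n < N; ++n) {
        const float* d = diff + static_cast<std::size_t>(n) * OH * OW * C;
        float* g = grad + static_cast<std::size_t>(n) * IH * IW * C;
        for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow) {
                float ah, aw;
                int ih0, iw0;
                source_coord(sh, IH, oh, ah, ih0);
                source_coord(sw, IW, ow, aw, iw0);
                int ih1 = ih0 + 1, iw1 = iw0 + 1;
                for (int c = 0; c < C; ++c) {
                    float v = d[(oh * OW + ow) * C + c];
                    // scatter into the four neighbours with bilinear weights
                    g[(ih0 * IW + iw0) * C + c] += v * (1 - ah) * (1 - aw);
                    g[(ih0 * IW + iw1) * C + c] += v * (1 - ah) * aw;
                    g[(ih1 * IW + iw0) * C + c] += v * ah * (1 - aw);
                    g[(ih1 * IW + iw1) * C + c] += v * ah * aw;
                }
            }
    }
}
```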
@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel(
         float nalphaw = 1.0f - alphaw;
         float nalphah = 1.0f - alphah;
         for (int c = 0; c < C; ++c) {
-            atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah);
-            atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah);
-            atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah);
-            atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah);
+            atomic_add(
+                    dst + ih0 * IW + iw0,
+                    output_converter(hidden[oh * OW + ow] * nalphaw * nalphah));
+            atomic_add(
+                    dst + ih0 * IW + iw1,
+                    output_converter(hidden[oh * OW + ow] * alphaw * nalphah));
+            atomic_add(
+                    dst + ih1 * IW + iw0,
+                    output_converter(hidden[oh * OW + ow] * nalphaw * alphah));
+            atomic_add(
+                    dst + ih1 * IW + iw1,
+                    output_converter(hidden[oh * OW + ow] * alphaw * alphah));
             hidden += OH * OW;
             dst += IH * IW;
         }
     }
 }
 
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_nearest_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel(
         int iw = get_nearest_src(scale_w, IW, ow);
 
         for (int c = 0; c < C; ++c) {
-            atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]);
+            atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow]));
             hidden += OH * OW;
             dst += IH * IW;
         }
     }
 }
 
+template <typename ctype, typename OutputConverter>
 __global__ void resize_bwd_cubic_kernel(
-        const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW,
+        const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         float scale_h, float scale_w) {
+    OutputConverter output_converter;
     int n = blockIdx.z;
     int ow = blockIdx.x * blockDim.x + threadIdx.x;
     int oh = blockIdx.y * blockDim.y + threadIdx.y;
@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel(
                 int ih = saturate(ih0 + kh, 0, IH - 1);
                 for (int kw = 0; kw < ksize; kw++) {
                     int iw = saturate(iw0 + kw, 0, IW - 1);
-                    atomicAdd(
+                    atomic_add(
                             dst + ih * IW + iw,
-                            hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]);
+                            output_converter(
+                                    hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]));
                 }
             }
 
@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel(
     }
 }
 
+template <typename ctype>
 void backward_data_proxy(
-        InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
-        int IW, int OH, int OW, cudaStream_t stream) {
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW, cudaStream_t stream) {
     const int BY = 16, BX = 32;
     {
         dim3 threads(BX, BY);
         dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N);
-        cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream));
+        cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream));
         float scale_h = static_cast<float>(OH) / IH;
         float scale_w = static_cast<float>(OW) / IW;
-        switch (imode) {
-            case InterpolationMode::INTER_LINEAR: {
-                resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            case InterpolationMode::INTER_NEAREST: {
-                resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            case InterpolationMode::INTER_CUBIC: {
-                resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>(
-                        diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
-                break;
-            }
-            default: {
-                megdnn_throw("unsupported interpolation mode");
-                break;
+        if (is_nhwc) {
+            resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>>
+                    <<<blocks, threads, 0, stream>>>(
+                            diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+        } else {
+            switch (imode) {
+                case InterpolationMode::INTER_LINEAR: {
+                    resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                case InterpolationMode::INTER_NEAREST: {
+                    resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                case InterpolationMode::INTER_CUBIC: {
+                    resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>>
+                            <<<blocks, threads, 0, stream>>>(
+                                    diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w);
+                    break;
+                }
+                default: {
+                    megdnn_throw("unsupported interpolation mode");
+                    break;
+                }
             }
         }
     }
     after_kernel_launch();
 }
 
+#define INST(ctype)                                                                 \
+    template void backward_data_proxy(                                              \
+            bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
+            int, cudaStream_t);
+
+INST(dt_float32);
+DNN_INC_FLOAT16(INST(dt_float16));
+#undef INST
+
 } // namespace resize
 } // namespace cuda
 } // namespace megdnn
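Reviewer note: every kernel now funnels its contribution through `atomic_add` plus a `rounding::RoundingConverter<ctype>` instead of a bare `atomicAdd` on `float`. The bilinear/cubic weights are still applied in `float`; the converter narrows the product back to the storage type (an identity for `dt_float32`, a round-to-nearest narrowing for `dt_float16`), and the atomic is needed because several output pixels can scatter into the same input pixel. A rough CPU analogue of that accumulate step, with stand-in names:

```cpp
// RoundingConverterSketch is a stand-in for rounding::RoundingConverter; for
// float it is an identity, for a half type it would round-to-nearest.
template <typename StorageT>
struct RoundingConverterSketch {
    StorageT operator()(float x) const { return static_cast<StorageT>(x); }
};

// CPU analogue of atomic_add(dst, output_converter(contribution)): the product
// is computed in float, rounded to the storage type, then accumulated. On the
// GPU the add is atomic so racing scatter writes do not lose updates.
template <typename StorageT>
void accumulate(StorageT* dst, float contribution) {
    RoundingConverterSketch<StorageT> cvt;
    *dst += cvt(contribution);
}

// usage sketch: accumulate<float>(&grad_pixel, diff_pixel * nalphah * nalphaw);
```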
@@ -20,9 +20,10 @@ void forward_proxy_nchw4(
         const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW,
         cudaStream_t stream);
 
+template <typename ctype>
 void backward_data_proxy(
-        InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH,
-        int IW, int OH, int OW, cudaStream_t stream);
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW, cudaStream_t stream);
 
 } // namespace resize
 } // namespace cuda
@@ -148,6 +148,11 @@ void ResizeImpl::exec(
                 is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(),
                 dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                 S_IH, S_IW, stream);
+    } else if (src.layout.dtype == dtype::Float16{}) {
+        resize::forward_proxy(
+                is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float16>(),
+                dst.ptr<dt_float16>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
+                S_IH, S_IW, stream);
     } else if (src.layout.dtype == dtype::Uint8()) {
         resize::forward_proxy(
                 is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(),
@@ -298,6 +298,7 @@ void forward_proxy_nchw4(
 INST(float)
 INST(uint8_t)
 INST(int8_t)
+DNN_INC_FLOAT16(INST(dt_float16))
 #undef INST
 
 #define INST(ctype) \
@@ -387,40 +387,53 @@ void ResizeImpl::exec(
     }
 }
 
-void ResizeBackwardImpl::exec(
-        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
-    check_exec(diff.layout, grad.layout, workspace.size);
-    megdnn_assert(
-            param().format == param::Resize::Format::NCHW, "invalid resize format");
-    const int N = grad.layout.shape[0], C = grad.layout.shape[1],
-              IH = grad.layout.shape[2], IW = grad.layout.shape[3];
-    const int OH = diff.layout.shape[2], OW = diff.layout.shape[3];
-    const float* hptr_ = diff.ptr<dt_float32>();
-    float* sptr_ = grad.ptr<dt_float32>();
+// ***************************Backward*************************** //
+template <typename ctype>
+void ResizeBackwardImpl::kern_naive(
+        bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N,
+        int C, int IH, int IW, int OH, int OW) {
     float scale_h = static_cast<float>(OH) / IH;
     float scale_w = static_cast<float>(OW) / IW;
+    rounding::RoundingConverter<ctype> output_converter;
     auto kern = [=]() {
-        auto hptr = hptr_;
-        auto sptr = sptr_;
-        std::memset(sptr, 0, sizeof(float) * N * C * IH * IW);
+        auto hptr = diff;
+        auto sptr = grad;
+        std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW);
         rep(n, N) {
             rep(oh, OH) rep(ow, OW) {
-                switch (param().imode) {
+                switch (imode) {
                     case InterpolationMode::INTER_LINEAR: {
                         int ih0, ih1, iw0, iw1;
                         float ah0, ah1, aw0, aw1;
-                        std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
-                                param().imode, scale_h, IH, oh);
-                        std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
-                                param().imode, scale_w, IW, ow);
-                        rep(c, C) {
-                            float hidden = hptr[c * OH * OW + oh * OW + ow];
-                            sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden;
-                            sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden;
-                            sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden;
-                            sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden;
+                        std::tie(ah0, ih0, ah1, ih1) =
+                                get_nearest_linear_coord(imode, scale_h, IH, oh);
+                        std::tie(aw0, iw0, aw1, iw1) =
+                                get_nearest_linear_coord(imode, scale_w, IW, ow);
+                        if (is_nhwc) {
+                            rep(c, C) {
+                                sptr[(ih0 * IW + iw0) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah0 * aw0);
+                                sptr[(ih0 * IW + iw1) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah0 * aw1);
+                                sptr[(ih1 * IW + iw0) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah1 * aw0);
+                                sptr[(ih1 * IW + iw1) * C + c] += output_converter(
+                                        hptr[(oh * OW + ow) * C + c] * ah1 * aw1);
+                            }
+                        } else {
+                            rep(c, C) {
+                                float hidden = hptr[c * OH * OW + oh * OW + ow];
+                                sptr[c * IH * IW + ih0 * IW + iw0] +=
+                                        output_converter(ah0 * aw0 * hidden);
+                                sptr[c * IH * IW + ih1 * IW + iw0] +=
+                                        output_converter(ah1 * aw0 * hidden);
+                                sptr[c * IH * IW + ih0 * IW + iw1] +=
+                                        output_converter(ah0 * aw1 * hidden);
+                                sptr[c * IH * IW + ih1 * IW + iw1] +=
+                                        output_converter(ah1 * aw1 * hidden);
+                            }
                         }
                         break;
                     }
@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec(
                         auto iw = get_nearest_src(scale_w, IW, ow);
                         rep(c, static_cast<int>(C)) {
                             sptr[c * IH * IW + ih * IW + iw] +=
-                                    hptr[c * OH * OW + oh * OW + ow];
+                                    output_converter(hptr[c * OH * OW + oh * OW + ow]);
                         }
                         break;
                     }
@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec(
                             int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
                             rep(kw, ksize) {
                                 int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
-                                sptr[c * IH * IW + h * IW + w] +=
+                                sptr[c * IH * IW + h * IW + w] += output_converter(
                                         hptr[c * OH * OW + oh * OW + ow] *
-                                        h_coeff[kh] * w_coeff[kw];
+                                        h_coeff[kh] * w_coeff[kw]);
                             }
                         }
                     }
@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec(
     MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
 }
 
+#define INST(ctype)                                                                 \
+    template void ResizeBackwardImpl::kern_naive(                                   \
+            bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \
+            int);
+
+INST(dt_float32);
+DNN_INC_FLOAT16(INST(dt_float16));
+#undef INST
+
+void ResizeBackwardImpl::exec(
+        _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) {
+    check_exec(diff.layout, grad.layout, workspace.size);
+    megdnn_assert(
+            param().format == param::Resize::Format::NCHW ||
+                    param().format == param::Resize::Format::NHWC,
+            "invalid resize format");
+    size_t N, C, IH, IW, OH, OW;
+    bool is_nhwc = param().format == param::Resize::Format::NHWC;
+    if (is_nhwc) {
+        if (param().imode != Param::InterpolationMode::LINEAR &&
+            is_nhwc_contig_wc(grad.layout)) {
+            megdnn_assert(
+                    0,
+                    "unsupport mode in resizeBackward, only support param().imode = "
+                    "LINEAR");
+        }
+        N = grad.layout.shape[0];
+        C = grad.layout.shape[3];
+        IH = grad.layout.shape[1];
+        IW = grad.layout.shape[2];
+        OH = diff.layout.shape[1];
+        OW = diff.layout.shape[2];
+    } else {
+        N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2],
+        IW = grad.layout.shape[3];
+        OH = diff.layout.shape[2], OW = diff.layout.shape[3];
+    }
+    switch (grad.layout.dtype.enumv()) {
+#define cb(_t)                                                                     \
+    case DTypeTrait<_t>::enumv: {                                                  \
+        typedef DTypeTrait<_t>::ctype ct;                                          \
+        ct* diff_ptr = diff.ptr<ct>();                                             \
+        ct* grad_ptr = grad.ptr<ct>();                                             \
+        ResizeBackwardImpl::kern_naive(                                            \
+                is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \
+        break;                                                                     \
+    }
+        cb(megdnn::dtype::Float32);
+        DNN_INC_FLOAT16(cb(megdnn::dtype::Float16));
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype: %s in resize backward",
+                    grad.layout.dtype.name()));
+    }
+}
+
 // vim: syntax=cpp.doxygen
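Reviewer note: the NHWC and NCHW branches of the LINEAR case differ only in how an element is addressed: NCHW keeps the channel as the slowest of the per-image indices, NHWC makes it the fastest. Written out as hypothetical helpers:

```cpp
#include <cstddef>

// Offset helpers matching the indexing in kern_naive above (per image; the
// batch offset n * C * H * W is the same for both layouts).
inline std::size_t offset_nchw(int c, int h, int w, int /*C*/, int H, int W) {
    return static_cast<std::size_t>(c) * H * W + static_cast<std::size_t>(h) * W + w;
}

inline std::size_t offset_nhwc(int c, int h, int w, int C, int /*H*/, int W) {
    return (static_cast<std::size_t>(h) * W + w) * C + c;
}
```

So `sptr[c * IH * IW + ih * IW + iw]` and `sptr[(ih * IW + iw) * C + c]` address the same logical element (n, c, ih, iw) in the two layouts.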
@@ -75,6 +75,12 @@ public:
     size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
         return 0;
     }
+
+private:
+    template <typename ctype>
+    void kern_naive(
+            bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad,
+            int N, int C, int IH, int IW, int OH, int OW);
 };
 
 } // namespace naive
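Reviewer note: `kern_naive` is declared here as a private member template but defined only in opr_impl.cpp; that works because the .cpp explicitly instantiates it for every dtype the dispatcher can reach (the `INST(dt_float32)` / `DNN_INC_FLOAT16(INST(dt_float16))` lines), so no other translation unit ever needs the template body. A minimal sketch of the arrangement, with made-up names:

```cpp
// widget.h -- only the declaration of the member template lives in the header
class Widget {
public:
    void run(const float* in, float* out, int n);

private:
    template <typename T>
    void kern(const T* in, T* out, int n);
};

// widget.cpp -- definition plus an explicit instantiation per supported type,
// mirroring the INST(...) macro in opr_impl.cpp
template <typename T>
void Widget::kern(const T* in, T* out, int n) {
    for (int i = 0; i < n; ++i)
        out[i] = in[i];  // placeholder body
}
template void Widget::kern<float>(const float*, float*, int);

void Widget::run(const float* in, float* out, int n) {
    kern(in, out, n);  // calls the instantiation that exists in this TU
}
```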
@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) {
                 .set_epsilon(1)
                 .execs({arg.src, arg.dst});
         }
+
+        for (auto&& arg : args) {
+            checker.set_param(arg.param)
+                    .set_dtype(0, dtype::Float16())
+                    .set_dtype(1, dtype::Float16())
+                    .set_epsilon(1e-3)
+                    .execs({arg.src, arg.dst});
+        }
+    }
+}
+
+TEST_F(CUDA, RESIZE_NHWC) {
+    using namespace resize;
+    std::vector<TestArg> args;
+
+    param::Resize param;
+    param.format = param::Resize::Format::NHWC;
+    param.imode = param::Resize::InterpolationMode::LINEAR;
+    args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5});
+    args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5});
+    args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2});
+
+    Checker<ResizeBackward> checkerBackward(handle_cuda());
+    for (auto&& arg : args) {
+        checkerBackward.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+    for (auto&& arg : args) {
+        checkerBackward.set_param(arg.param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+
+    Checker<ResizeForward> checkerForward(handle_cuda());
+    for (auto&& arg : args) {
+        checkerForward.set_param(arg.param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
+    }
+    for (auto&& arg : args) {
+        checkerForward.set_param(arg.param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_epsilon(1e-3)
+                .execs({arg.src, arg.dst});
     }
 }
 
 TEST_F(CUDA, RESIZE_NCHW4) {
     using namespace resize;
     Checker<Resize> checker(handle_cuda());
     auto args = get_nchw4_args();
 
     for (auto&& arg : args) {
         checker.set_param(arg.param)
@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) {
         param.format = param::Resize::Format::NCHW;
         param.imode = imode;
         checker.set_param(param);
+        checker.set_dtype(0, dtype::Float16());
+        checker.set_dtype(1, dtype::Float16());
+        checker.set_epsilon(1 + 1e-3);
+        checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
+        checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});
+        checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}});
+        checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}});
+    }
+
+    for (auto imode : modes) {
+        Checker<ResizeBackward> checker(handle_cuda());
+        param::Resize param;
+        param.format = param::Resize::Format::NCHW;
+        param.imode = imode;
+        checker.set_param(param);
+        checker.set_dtype(0, dtype::Float32());
+        checker.set_dtype(1, dtype::Float32());
         checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}});
         checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}});