GitOrigin-RevId: bb04d2a801
HuaHua404-patch-1
@@ -67,8 +67,12 @@ void ResizeBackward::check_exec( | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
megdnn_assert( | |||
param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(), | |||
"Backward resize only supports Float32 and NCHW."); | |||
(param().format == Param::Format::NCHW || | |||
param().format == Param::Format::NHWC) && | |||
(grad.dtype == dtype::Float32() DNN_INC_FLOAT16( | |||
|| grad.dtype == dtype::Float16())), | |||
"Backward resize only supports NCHW and NHWC, the dtype only supports " | |||
"Float32 and Float16."); | |||
} | |||
std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) { | |||
@@ -11,26 +11,56 @@ void ResizeBackwardImpl::exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
check_exec(diff.layout, grad.layout, workspace.size); | |||
auto stream = cuda_stream(this->handle()); | |||
auto N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], | |||
IW = grad.layout.shape[3], OH = diff.layout.shape[2], | |||
OW = diff.layout.shape[3]; | |||
bool is_nhwc = param().format == param::Resize::Format::NHWC; | |||
size_t N, C, IH, IW, OH, OW; | |||
if (is_nhwc) { | |||
if (param().imode != Param::InterpolationMode::LINEAR && | |||
is_nhwc_contig_wc(grad.layout)) { | |||
megdnn_assert( | |||
0, | |||
"unsupport mode in resizeBackward, only support param().imode = " | |||
"LINEAR"); | |||
} | |||
N = grad.layout.shape[0]; | |||
C = grad.layout.shape[3]; | |||
IH = grad.layout.shape[1]; | |||
IW = grad.layout.shape[2]; | |||
OH = diff.layout.shape[1]; | |||
OW = diff.layout.shape[2]; | |||
} else { | |||
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], | |||
IW = grad.layout.shape[3], OH = diff.layout.shape[2], OW = diff.layout.shape[3]; | |||
} | |||
size_t max_batch_x_channel = max_batch_x_channel_size(); | |||
dt_float32* diff_ptr = diff.ptr<dt_float32>(); | |||
dt_float32* grad_ptr = grad.ptr<dt_float32>(); | |||
size_t max_batch_size = max_batch_x_channel / C; | |||
while (N > 0) { | |||
size_t curr_batch_size = N > max_batch_size ? max_batch_size : N; | |||
resize::backward_data_proxy( | |||
resize::get_imode(param().imode), diff_ptr, grad_ptr, curr_batch_size, | |||
C, IH, IW, OH, OW, stream); | |||
if (N <= max_batch_size) { | |||
break; | |||
} else { | |||
N -= max_batch_size; | |||
diff_ptr += curr_batch_size * diff.layout.stride[0]; | |||
grad_ptr += curr_batch_size * grad.layout.stride[0]; | |||
switch (grad.layout.dtype.enumv()) { | |||
#define cb(_t) \ | |||
case DTypeTrait<_t>::enumv: { \ | |||
typedef DTypeTrait<_t>::ctype ct; \ | |||
ct* diff_ptr = diff.ptr<ct>(); \ | |||
ct* grad_ptr = grad.ptr<ct>(); \ | |||
resize::backward_data_proxy( \ | |||
is_nhwc, resize::get_imode(param().imode), diff_ptr, grad_ptr, \ | |||
curr_batch_size, C, IH, IW, OH, OW, stream); \ | |||
if (N <= max_batch_size) { \ | |||
return; \ | |||
} else { \ | |||
N -= max_batch_size; \ | |||
diff_ptr += curr_batch_size * diff.layout.stride[0]; \ | |||
grad_ptr += curr_batch_size * grad.layout.stride[0]; \ | |||
} \ | |||
break; \ | |||
} | |||
cb(megdnn::dtype::Float32); | |||
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16)); | |||
default: | |||
megdnn_throw(ssprintf( | |||
"unsupported dtype: %s in resize backward", | |||
grad.layout.dtype.name())); | |||
} | |||
#undef cb | |||
} | |||
} | |||
@@ -1,3 +1,4 @@ | |||
#include "src/common/rounding_converter.cuh" | |||
#include "src/cuda/resize/common.cuh" | |||
#include "src/cuda/resize/common.h" | |||
@@ -11,9 +12,52 @@ namespace megdnn { | |||
namespace cuda { | |||
namespace resize { | |||
template <typename ctype, typename OutputConverter> | |||
__global__ void resize_bwd_nhwc_kernel( | |||
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
float scale_h, float scale_w) { | |||
OutputConverter output_converter; | |||
int n = blockIdx.z; | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
hidden += n * C * OH * OW; | |||
dst += n * C * IH * IW; | |||
if (ow < OW && oh < OH) { | |||
float alphah, alphaw; | |||
int ih0, iw0; | |||
get_origin_coord(scale_h, IH, oh, alphah, ih0); | |||
get_origin_coord(scale_w, IW, ow, alphaw, iw0); | |||
int ih1 = ih0 + 1; | |||
int iw1 = iw0 + 1; | |||
float nalphaw = 1.0f - alphaw; | |||
float nalphah = 1.0f - alphah; | |||
for (int c = 0; c < C; ++c) { | |||
atomic_add( | |||
dst + (ih0 * IW + iw0) * C + c, | |||
output_converter( | |||
hidden[(oh * OW + ow) * C + c] * nalphaw * nalphah)); | |||
atomic_add( | |||
dst + (ih0 * IW + iw1) * C + c, | |||
output_converter( | |||
hidden[(oh * OW + ow) * C + c] * alphaw * nalphah)); | |||
atomic_add( | |||
dst + (ih1 * IW + iw0) * C + c, | |||
output_converter( | |||
hidden[(oh * OW + ow) * C + c] * nalphaw * alphah)); | |||
atomic_add( | |||
dst + (ih1 * IW + iw1) * C + c, | |||
output_converter(hidden[(oh * OW + ow) * C + c] * alphaw * alphah)); | |||
} | |||
} | |||
} | |||
template <typename ctype, typename OutputConverter> | |||
__global__ void resize_bwd_linear_kernel( | |||
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
float scale_h, float scale_w) { | |||
OutputConverter output_converter; | |||
int n = blockIdx.z; | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
@@ -31,19 +75,29 @@ __global__ void resize_bwd_linear_kernel( | |||
float nalphaw = 1.0f - alphaw; | |||
float nalphah = 1.0f - alphah; | |||
for (int c = 0; c < C; ++c) { | |||
atomicAdd(dst + ih0 * IW + iw0, hidden[oh * OW + ow] * nalphaw * nalphah); | |||
atomicAdd(dst + ih0 * IW + iw1, hidden[oh * OW + ow] * alphaw * nalphah); | |||
atomicAdd(dst + ih1 * IW + iw0, hidden[oh * OW + ow] * nalphaw * alphah); | |||
atomicAdd(dst + ih1 * IW + iw1, hidden[oh * OW + ow] * alphaw * alphah); | |||
atomic_add( | |||
dst + ih0 * IW + iw0, | |||
output_converter(hidden[oh * OW + ow] * nalphaw * nalphah)); | |||
atomic_add( | |||
dst + ih0 * IW + iw1, | |||
output_converter(hidden[oh * OW + ow] * alphaw * nalphah)); | |||
atomic_add( | |||
dst + ih1 * IW + iw0, | |||
output_converter(hidden[oh * OW + ow] * nalphaw * alphah)); | |||
atomic_add( | |||
dst + ih1 * IW + iw1, | |||
output_converter(hidden[oh * OW + ow] * alphaw * alphah)); | |||
hidden += OH * OW; | |||
dst += IH * IW; | |||
} | |||
} | |||
} | |||
template <typename ctype, typename OutputConverter> | |||
__global__ void resize_bwd_nearest_kernel( | |||
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
float scale_h, float scale_w) { | |||
OutputConverter output_converter; | |||
int n = blockIdx.z; | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
@@ -54,16 +108,18 @@ __global__ void resize_bwd_nearest_kernel( | |||
int iw = get_nearest_src(scale_w, IW, ow); | |||
for (int c = 0; c < C; ++c) { | |||
atomicAdd(dst + ih * IW + iw, hidden[oh * OW + ow]); | |||
atomic_add(dst + ih * IW + iw, output_converter(hidden[oh * OW + ow])); | |||
hidden += OH * OW; | |||
dst += IH * IW; | |||
} | |||
} | |||
} | |||
template <typename ctype, typename OutputConverter> | |||
__global__ void resize_bwd_cubic_kernel( | |||
const float* hidden, float* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
const ctype* hidden, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
float scale_h, float scale_w) { | |||
OutputConverter output_converter; | |||
int n = blockIdx.z; | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
@@ -85,9 +141,10 @@ __global__ void resize_bwd_cubic_kernel( | |||
int ih = saturate(ih0 + kh, 0, IH - 1); | |||
for (int kw = 0; kw < ksize; kw++) { | |||
int iw = saturate(iw0 + kw, 0, IW - 1); | |||
atomicAdd( | |||
atomic_add( | |||
dst + ih * IW + iw, | |||
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]); | |||
output_converter( | |||
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw])); | |||
} | |||
} | |||
@@ -97,41 +154,59 @@ __global__ void resize_bwd_cubic_kernel( | |||
} | |||
} | |||
template <typename ctype> | |||
void backward_data_proxy( | |||
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, | |||
int IW, int OH, int OW, cudaStream_t stream) { | |||
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, cudaStream_t stream) { | |||
const int BY = 16, BX = 32; | |||
{ | |||
dim3 threads(BX, BY); | |||
dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, N); | |||
cuda_check(cudaMemsetAsync(grad, 0, sizeof(float) * N * C * IH * IW, stream)); | |||
cuda_check(cudaMemsetAsync(grad, 0, sizeof(ctype) * N * C * IH * IW, stream)); | |||
float scale_h = static_cast<float>(OH) / IH; | |||
float scale_w = static_cast<float>(OW) / IW; | |||
switch (imode) { | |||
case InterpolationMode::INTER_LINEAR: { | |||
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_NEAREST: { | |||
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_CUBIC: { | |||
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
default: { | |||
megdnn_throw("unsupported interpolation mode"); | |||
break; | |||
if (is_nhwc) { | |||
resize_bwd_nhwc_kernel<ctype, rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
} else { | |||
switch (imode) { | |||
case InterpolationMode::INTER_LINEAR: { | |||
resize_bwd_linear_kernel<ctype, rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_NEAREST: { | |||
resize_bwd_nearest_kernel<ctype, rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_CUBIC: { | |||
resize_bwd_cubic_kernel<ctype, rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
default: { | |||
megdnn_throw("unsupported interpolation mode"); | |||
break; | |||
} | |||
} | |||
} | |||
} | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype) \ | |||
template void backward_data_proxy( \ | |||
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \ | |||
int, cudaStream_t); | |||
INST(dt_float32); | |||
DNN_INC_FLOAT16(INST(dt_float16)); | |||
#undef INST | |||
} // namespace resize | |||
} // namespace cuda | |||
} // namespace megdnn | |||
@@ -20,9 +20,10 @@ void forward_proxy_nchw4( | |||
const ctype* src, ctype* dst, int N, int C, int IH, int IW, int OH, int OW, | |||
cudaStream_t stream); | |||
template <typename ctype> | |||
void backward_data_proxy( | |||
InterpolationMode imode, const float* diff, float* grad, int N, int C, int IH, | |||
int IW, int OH, int OW, cudaStream_t stream); | |||
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, | |||
int C, int IH, int IW, int OH, int OW, cudaStream_t stream); | |||
} // namespace resize | |||
} // namespace cuda | |||
@@ -148,6 +148,11 @@ void ResizeImpl::exec( | |||
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(), | |||
dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, stream); | |||
} else if (src.layout.dtype == dtype::Float16{}) { | |||
resize::forward_proxy( | |||
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float16>(), | |||
dst.ptr<dt_float16>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, stream); | |||
} else if (src.layout.dtype == dtype::Uint8()) { | |||
resize::forward_proxy( | |||
is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(), | |||
@@ -298,6 +298,7 @@ void forward_proxy_nchw4( | |||
INST(float) | |||
INST(uint8_t) | |||
INST(int8_t) | |||
DNN_INC_FLOAT16(INST(dt_float16)) | |||
#undef INST | |||
#define INST(ctype) \ | |||
@@ -387,40 +387,53 @@ void ResizeImpl::exec( | |||
} | |||
} | |||
void ResizeBackwardImpl::exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
check_exec(diff.layout, grad.layout, workspace.size); | |||
megdnn_assert( | |||
param().format == param::Resize::Format::NCHW, "invalid resize format"); | |||
const int N = grad.layout.shape[0], C = grad.layout.shape[1], | |||
IH = grad.layout.shape[2], IW = grad.layout.shape[3]; | |||
const int OH = diff.layout.shape[2], OW = diff.layout.shape[3]; | |||
const float* hptr_ = diff.ptr<dt_float32>(); | |||
float* sptr_ = grad.ptr<dt_float32>(); | |||
// ***************************Backward*************************** // | |||
template <typename ctype> | |||
void ResizeBackwardImpl::kern_naive( | |||
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, int N, | |||
int C, int IH, int IW, int OH, int OW) { | |||
float scale_h = static_cast<float>(OH) / IH; | |||
float scale_w = static_cast<float>(OW) / IW; | |||
rounding::RoundingConverter<ctype> output_converter; | |||
auto kern = [=]() { | |||
auto hptr = hptr_; | |||
auto sptr = sptr_; | |||
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); | |||
auto hptr = diff; | |||
auto sptr = grad; | |||
std::memset(sptr, 0, sizeof(ctype) * N * C * IH * IW); | |||
rep(n, N) { | |||
rep(oh, OH) rep(ow, OW) { | |||
switch (param().imode) { | |||
switch (imode) { | |||
case InterpolationMode::INTER_LINEAR: { | |||
int ih0, ih1, iw0, iw1; | |||
float ah0, ah1, aw0, aw1; | |||
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord( | |||
param().imode, scale_h, IH, oh); | |||
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord( | |||
param().imode, scale_w, IW, ow); | |||
rep(c, C) { | |||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||
sptr[c * IH * IW + ih0 * IW + iw0] += ah0 * aw0 * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw0] += ah1 * aw0 * hidden; | |||
sptr[c * IH * IW + ih0 * IW + iw1] += ah0 * aw1 * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw1] += ah1 * aw1 * hidden; | |||
std::tie(ah0, ih0, ah1, ih1) = | |||
get_nearest_linear_coord(imode, scale_h, IH, oh); | |||
std::tie(aw0, iw0, aw1, iw1) = | |||
get_nearest_linear_coord(imode, scale_w, IW, ow); | |||
if (is_nhwc) { | |||
rep(c, C) { | |||
sptr[(ih0 * IW + iw0) * C + c] += output_converter( | |||
hptr[(oh * OW + ow) * C + c] * ah0 * aw0); | |||
sptr[(ih0 * IW + iw1) * C + c] += output_converter( | |||
hptr[(oh * OW + ow) * C + c] * ah0 * aw1); | |||
sptr[(ih1 * IW + iw0) * C + c] += output_converter( | |||
hptr[(oh * OW + ow) * C + c] * ah1 * aw0); | |||
sptr[(ih1 * IW + iw1) * C + c] += output_converter( | |||
hptr[(oh * OW + ow) * C + c] * ah1 * aw1); | |||
} | |||
} else { | |||
rep(c, C) { | |||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||
sptr[c * IH * IW + ih0 * IW + iw0] += | |||
output_converter(ah0 * aw0 * hidden); | |||
sptr[c * IH * IW + ih1 * IW + iw0] += | |||
output_converter(ah1 * aw0 * hidden); | |||
sptr[c * IH * IW + ih0 * IW + iw1] += | |||
output_converter(ah0 * aw1 * hidden); | |||
sptr[c * IH * IW + ih1 * IW + iw1] += | |||
output_converter(ah1 * aw1 * hidden); | |||
} | |||
} | |||
break; | |||
} | |||
@@ -429,7 +442,7 @@ void ResizeBackwardImpl::exec( | |||
auto iw = get_nearest_src(scale_w, IW, ow); | |||
rep(c, static_cast<int>(C)) { | |||
sptr[c * IH * IW + ih * IW + iw] += | |||
hptr[c * OH * OW + oh * OW + ow]; | |||
output_converter(hptr[c * OH * OW + oh * OW + ow]); | |||
} | |||
break; | |||
} | |||
@@ -452,9 +465,9 @@ void ResizeBackwardImpl::exec( | |||
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); | |||
rep(kw, ksize) { | |||
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); | |||
sptr[c * IH * IW + h * IW + w] += | |||
sptr[c * IH * IW + h * IW + w] += output_converter( | |||
hptr[c * OH * OW + oh * OW + ow] * | |||
h_coeff[kh] * w_coeff[kw]; | |||
h_coeff[kh] * w_coeff[kw]); | |||
} | |||
} | |||
} | |||
@@ -473,4 +486,59 @@ void ResizeBackwardImpl::exec( | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern()); | |||
} | |||
#define INST(ctype) \ | |||
template void ResizeBackwardImpl::kern_naive( \ | |||
bool, InterpolationMode, const ctype*, ctype*, int, int, int, int, int, \ | |||
int); | |||
INST(dt_float32); | |||
DNN_INC_FLOAT16(INST(dt_float16)); | |||
#undef INST | |||
void ResizeBackwardImpl::exec( | |||
_megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
check_exec(diff.layout, grad.layout, workspace.size); | |||
megdnn_assert( | |||
param().format == param::Resize::Format::NCHW || | |||
param().format == param::Resize::Format::NHWC, | |||
"invalid resize format"); | |||
size_t N, C, IH, IW, OH, OW; | |||
bool is_nhwc = param().format == param::Resize::Format::NHWC; | |||
if (is_nhwc) { | |||
if (param().imode != Param::InterpolationMode::LINEAR && | |||
is_nhwc_contig_wc(grad.layout)) { | |||
megdnn_assert( | |||
0, | |||
"unsupport mode in resizeBackward, only support param().imode = " | |||
"LINEAR"); | |||
} | |||
N = grad.layout.shape[0]; | |||
C = grad.layout.shape[3]; | |||
IH = grad.layout.shape[1]; | |||
IW = grad.layout.shape[2]; | |||
OH = diff.layout.shape[1]; | |||
OW = diff.layout.shape[2]; | |||
} else { | |||
N = grad.layout.shape[0], C = grad.layout.shape[1], IH = grad.layout.shape[2], | |||
IW = grad.layout.shape[3]; | |||
OH = diff.layout.shape[2], OW = diff.layout.shape[3]; | |||
} | |||
switch (grad.layout.dtype.enumv()) { | |||
#define cb(_t) \ | |||
case DTypeTrait<_t>::enumv: { \ | |||
typedef DTypeTrait<_t>::ctype ct; \ | |||
ct* diff_ptr = diff.ptr<ct>(); \ | |||
ct* grad_ptr = grad.ptr<ct>(); \ | |||
ResizeBackwardImpl::kern_naive( \ | |||
is_nhwc, param().imode, diff_ptr, grad_ptr, N, C, IH, IW, OH, OW); \ | |||
break; \ | |||
} | |||
cb(megdnn::dtype::Float32); | |||
DNN_INC_FLOAT16(cb(megdnn::dtype::Float16)); | |||
default: | |||
megdnn_throw(ssprintf( | |||
"unsupported dtype: %s in resize backward", | |||
grad.layout.dtype.name())); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -75,6 +75,12 @@ public: | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { | |||
return 0; | |||
} | |||
private: | |||
template <typename ctype> | |||
void kern_naive( | |||
bool is_nhwc, InterpolationMode imode, const ctype* diff, ctype* grad, | |||
int N, int C, int IH, int IW, int OH, int OW); | |||
}; | |||
} // namespace naive | |||
@@ -61,13 +61,67 @@ TEST_F(CUDA, RESIZE_FORWARD) { | |||
.set_epsilon(1) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param) | |||
.set_dtype(0, dtype::Float16()) | |||
.set_dtype(1, dtype::Float16()) | |||
.set_epsilon(1e-3) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
} | |||
} | |||
TEST_F(CUDA, RESIZE_NHWC) { | |||
using namespace resize; | |||
std::vector<TestArg> args; | |||
param::Resize param; | |||
param.format = param::Resize::Format::NHWC; | |||
param.imode = param::Resize::InterpolationMode::LINEAR; | |||
args.emplace_back(param, TensorShape{1, 1, 4, 5}, TensorShape{1, 1, 8, 5}); | |||
args.emplace_back(param, TensorShape{2, 6, 4, 5}, TensorShape{2, 3, 8, 5}); | |||
args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); | |||
Checker<ResizeBackward> checkerBackward(handle_cuda()); | |||
for (auto&& arg : args) { | |||
checkerBackward.set_param(arg.param) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_epsilon(1e-3) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
for (auto&& arg : args) { | |||
checkerBackward.set_param(arg.param) | |||
.set_dtype(0, dtype::Float16()) | |||
.set_dtype(1, dtype::Float16()) | |||
.set_epsilon(1e-3) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
Checker<ResizeForward> checkerForward(handle_cuda()); | |||
for (auto&& arg : args) { | |||
checkerForward.set_param(arg.param) | |||
.set_dtype(0, dtype::Float16()) | |||
.set_dtype(1, dtype::Float16()) | |||
.set_epsilon(1e-3) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
for (auto&& arg : args) { | |||
checkerForward.set_param(arg.param) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_epsilon(1e-3) | |||
.execs({arg.src, arg.dst}); | |||
} | |||
} | |||
TEST_F(CUDA, RESIZE_NCHW4) { | |||
using namespace resize; | |||
Checker<Resize> checker(handle_cuda()); | |||
auto args = get_nchw4_args(); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param) | |||
@@ -113,6 +167,24 @@ TEST_F(CUDA, RESIZE_BACKWARD) { | |||
param.format = param::Resize::Format::NCHW; | |||
param.imode = imode; | |||
checker.set_param(param); | |||
checker.set_dtype(0, dtype::Float16()); | |||
checker.set_dtype(1, dtype::Float16()); | |||
checker.set_epsilon(1 + 1e-3); | |||
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); | |||
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); | |||
checker.execs({{2, 5, 8, 5}, {2, 5, 4, 9}}); | |||
checker.execs({{2, 5, 4, 9}, {2, 5, 8, 5}}); | |||
} | |||
for (auto imode : modes) { | |||
Checker<ResizeBackward> checker(handle_cuda()); | |||
param::Resize param; | |||
param.format = param::Resize::Format::NCHW; | |||
param.imode = imode; | |||
checker.set_param(param); | |||
checker.set_dtype(0, dtype::Float32()); | |||
checker.set_dtype(1, dtype::Float32()); | |||
checker.execs({{2, 3, 4, 5}, {2, 3, 8, 9}}); | |||
checker.execs({{2, 5, 8, 9}, {2, 5, 4, 5}}); | |||