GitOrigin-RevId: f7023a3fd3
tags/v1.6.0-rc1
@@ -197,7 +197,7 @@ public: | |||
protected: | |||
//! get origin coord | |||
std::pair<float, int> get_origin_coord(float scale, int size, int idx); | |||
std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false); | |||
//! get nearest index in src | |||
int get_nearest_src(float scale, int size, int idx); | |||
@@ -11,6 +11,7 @@ | |||
*/ | |||
#include "megdnn/handle.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
@@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, | |||
if (param().format == Param::Format::NCHW) { | |||
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); | |||
auto imode = param().imode; | |||
megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR || | |||
imode == param::Resize::InterpolationMode::NEAREST); | |||
using IMode = param::Resize::InterpolationMode; | |||
megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST || | |||
imode == IMode::INTER_CUBIC); | |||
} else if (param().format == Param::Format::NHWC) { | |||
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); | |||
} else if (param().format == Param::Format::NCHW4) { | |||
@@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff, | |||
} | |||
/*!
 * \brief Map a destination index back to a source coordinate.
 *
 * The source position is computed as (idx + 0.5) / scale - 0.5 and split into
 * an integer base index and the fractional remainder.
 *
 * \param scale scaling factor; the callers visible in this file pass
 *      destination extent / source extent (e.g. OH / IH)
 * \param size source extent along the axis
 * \param idx destination index along the axis
 * \param cubic when true the border clamp below is skipped and the raw
 *      (possibly negative or >= size - 1) base index is returned; the cubic
 *      callers clamp every tap individually
 * \return pair {alpha, origin_idx}: fractional offset and base source index
 */
std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
                                                   int idx, bool cubic) {
    //! copy from resize_cv.cpp
    float alpha = (idx + 0.5f) / scale - 0.5f;
    int origin_idx = static_cast<int>(floor(alpha));
    alpha -= origin_idx;
    if (!cubic) {
        // Two-tap (linear) consumers read origin_idx and origin_idx + 1, so
        // the pair is pinned onto the first / last valid pair of samples.
        // NOTE(review): for size == 1 the second branch yields origin_idx == -1
        // -- confirm callers never reach this with a single-sample axis.
        if (origin_idx < 0) {
            origin_idx = 0;
            alpha = 0;
        } else if (origin_idx + 1 >= size) {
            origin_idx = size - 2;
            alpha = 1;
        }
    }
    return {alpha, origin_idx};
}
@@ -0,0 +1,39 @@ | |||
/** | |||
* \file dnn/src/common/resize.cuh | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#if MEGDNN_CC_HOST && !defined(__host__) | |||
#if __GNUC__ || __has_attribute(always_inline) | |||
#define __forceinline__ inline __attribute__((always_inline)) | |||
#else | |||
#define __forceinline__ inline | |||
#endif | |||
#endif | |||
namespace megdnn { | |||
namespace resize { | |||
/*!
 * \brief Compute the four tap weights of the cubic interpolation kernel
 *      (kernel parameter A = -0.75).
 *
 * \param x fractional offset fed into the kernel polynomials
 * \param[out] coeffs receives four weights; the last one is derived so that
 *      all four sum to 1
 */
MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic(
        float x, float* coeffs) {
    const float A = -0.75f;
    const float outer = x + 1;  // argument of the polynomial for tap 0
    const float mirror = 1 - x; // argument of the polynomial for tap 2

    const float w0 = ((A * outer - 5 * A) * outer + 8 * A) * outer - 4 * A;
    const float w1 = ((A + 2) * x - (A + 3)) * x * x + 1;
    const float w2 = ((A + 2) * mirror - (A + 3)) * mirror * mirror + 1;

    coeffs[0] = w0;
    coeffs[1] = w1;
    coeffs[2] = w2;
    // Remaining weight is whatever is left of unity.
    coeffs[3] = 1.f - w0 - w1 - w2;
}
} // namespace resize | |||
} // namespace megdnn | |||
/* vim: set ft=cpp: */ |
@@ -71,7 +71,10 @@ struct RoundingConverter<uint8_t> { | |||
//! Convert a float to uint8_t: clamp into [0, 255], then round.
__host__ __device__ __forceinline__ uint8_t operator()(float x) const {
#if MEGDNN_CC_HOST
    // Host builds take the math helpers from namespace std.
    using std::max;
    using std::min;
    using std::round;
#endif
    //! FIXME!!! check other places
    const float clamped = min(255.0f, max(0.0f, x));
    return static_cast<uint8_t>(round(clamped));
}
}; | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include "src/common/cv/enums.h" | |||
#include "src/common/resize.cuh" | |||
#include "megdnn/basic_types.h" | |||
@@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) { | |||
coeffs[1] = x; | |||
} | |||
/*!
 * \brief Cubic interpolation tap weights.
 *
 * Thin wrapper kept for existing callers: the polynomial evaluation lives in
 * megdnn::resize::interpolate_cubic (src/common/resize.cuh), so this function
 * only forwards to it instead of carrying a second copy of the formula.
 *
 * \param x fractional offset fed into the kernel polynomials
 * \param[out] coeffs receives four weights
 */
__host__ __device__ inline void interpolate_cubic_coefs(float x,
                                                        float* coeffs) {
    megdnn::resize::interpolate_cubic(x, coeffs);
}
__device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) { | |||
const float s45 = 0.70710678118654752440084436210485; | |||
const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | |||
@@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs<INTER_LINEAR>(float x, float* coeffs) { | |||
} | |||
//! INTER_CUBIC specialization: fills the four tap weights through the shared
//! megdnn::resize::interpolate_cubic helper (computed once, not twice).
template <>
__device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) {
    megdnn::resize::interpolate_cubic(x, coeffs);
}
template <> | |||
__device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x, | |||
@@ -12,6 +12,10 @@ | |||
#include "src/cuda/resize/common.h" | |||
#include "src/cuda/utils.cuh" | |||
#include "src/cuda/cv/kernel_common.cuh" | |||
using megdnn::resize::interpolate_cubic; | |||
using megdnn::megcv::saturate; | |||
namespace megdnn { | |||
namespace cuda { | |||
@@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst, | |||
} | |||
} | |||
} | |||
/*!
 * \brief Backward pass of cubic resize.
 *
 * Thread layout: x indexes the output column (ow), y the output row (oh) and
 * blockIdx.z the batch item. Each in-range thread walks over all C channels
 * and scatters its output-gradient element onto the 4x4 source taps, weighted
 * by the cubic coefficients. Taps are clamped into the source extent, and the
 * accumulation uses atomicAdd because several threads can target the same
 * source element.
 *
 * \param hidden output gradient, indexed as [n][c][oh][ow]
 * \param dst input gradient to accumulate into, indexed as [n][c][ih][iw]
 * \param N not read by the kernel body
 */
__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N,
                                        int C, int IH, int IW, int OH, int OW,
                                        float scale_h, float scale_w) {
    const int batch = blockIdx.z;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    hidden += batch * C * OH * OW;
    dst += batch * C * IH * IW;
    if (ow >= OW || oh >= OH) {
        return;
    }

    float frac_h, frac_w;
    int top, left;
    get_origin_coord(scale_h, IH, oh, frac_h, top, true);
    get_origin_coord(scale_w, IW, ow, frac_w, left, true);
    // The 4-tap window starts one sample before the base index.
    top -= 1;
    left -= 1;

    float h_coeff[4], w_coeff[4];
    interpolate_cubic(frac_h, h_coeff);
    interpolate_cubic(frac_w, w_coeff);

    constexpr int ksize = 4;
    const int out_off = oh * OW + ow;
    for (int c = 0; c < C; ++c, hidden += OH * OW, dst += IH * IW) {
        for (int kh = 0; kh < ksize; ++kh) {
            const int ih = saturate(top + kh, 0, IH - 1);
            for (int kw = 0; kw < ksize; ++kw) {
                const int iw = saturate(left + kw, 0, IW - 1);
                atomicAdd(dst + ih * IW + iw,
                          hidden[out_off] * h_coeff[kh] * w_coeff[kw]);
            }
        }
    }
}
void backward_data_proxy(InterpolationMode imode, const float* diff, | |||
float* grad, int N, int C, int IH, int IW, int OH, | |||
int OW, cudaStream_t stream) { | |||
@@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff, | |||
stream)); | |||
float scale_h = static_cast<float>(OH) / IH; | |||
float scale_w = static_cast<float>(OW) / IW; | |||
if(imode == InterpolationMode::INTER_LINEAR) { | |||
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
} | |||
else if (imode == InterpolationMode::INTER_NEAREST) { | |||
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
switch (imode) { | |||
case InterpolationMode::INTER_LINEAR: { | |||
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_NEAREST: { | |||
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
case InterpolationMode::INTER_CUBIC: { | |||
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>( | |||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||
break; | |||
} | |||
default: { | |||
megdnn_throw("unsupported interpolation mode"); | |||
break; | |||
} | |||
} | |||
} | |||
after_kernel_launch(); | |||
@@ -15,16 +15,19 @@ namespace cuda { | |||
namespace resize { | |||
/*!
 * \brief Map a destination index back to a source coordinate (device side).
 *
 * The source position is computed as (idx + 0.5) / scale - 0.5 and split into
 * an integer base index and the fractional remainder.
 *
 * \param scale scaling factor applied to the destination index
 * \param size source extent along the axis
 * \param idx destination index along the axis
 * \param[out] alpha fractional offset relative to origin_idx
 * \param[out] origin_idx base source index
 * \param cubic when true the border clamp is skipped and the raw base index
 *      is returned; the cubic kernels clamp each tap themselves
 */
__device__ inline void get_origin_coord(float scale, int size, int idx,
                                        float& alpha, int& origin_idx,
                                        bool cubic = false) {
    alpha = (idx + 0.5f) / scale - 0.5f;
    origin_idx = static_cast<int>(floor(alpha));
    alpha -= origin_idx;
    if (!cubic) {
        // Two-tap consumers read origin_idx and origin_idx + 1, so the pair
        // is pinned onto the first / last valid pair of samples.
        if (origin_idx < 0) {
            origin_idx = 0;
            alpha = 0;
        } else if (origin_idx + 1 >= size) {
            origin_idx = size - 2;
            alpha = 1;
        }
    }
}
@@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||
C, IH, IW, OH, OW, stream); | |||
return; | |||
} | |||
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR || | |||
param().imode == Param::InterpolationMode::NEAREST, | |||
"unsupported interpolation mode for NCHW format"); | |||
megdnn_assert( | |||
param().imode == Param::InterpolationMode::LINEAR || | |||
param().imode == Param::InterpolationMode::NEAREST || | |||
param().imode == Param::InterpolationMode::INTER_CUBIC, | |||
"unsupported interpolation mode for NCHW format"); | |||
if (src.layout.dtype == dtype::Float32{}) { | |||
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)), | |||
@@ -8,15 +8,20 @@ | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/rounding_converter.cuh" | |||
#include "src/common/utils.cuh" | |||
#include "src/cuda/resize/common.cuh" | |||
#include "src/cuda/resize/common.h" | |||
#include "src/common/rounding_converter.cuh" | |||
#include "src/cuda/resize/resize_cv.cuh" | |||
#include "src/cuda/utils.cuh" | |||
#include "src/cuda/cv/kernel_common.cuh" | |||
#include "src/common/resize.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace resize; | |||
using namespace megdnn::cuda::resize; | |||
using megdnn::resize::interpolate_cubic; | |||
using megdnn::megcv::saturate; | |||
namespace { | |||
@@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, | |||
int iw = get_nearest_src(scale_w, IW, ow); | |||
for (int c = 0; c < C; ++c) { | |||
dst[oh * OW + ow] = output_converter( | |||
sptr[ih * S_IH + iw * S_IW]); | |||
dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]); | |||
sptr += S_IC; | |||
dst += OH * OW; | |||
@@ -91,6 +95,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, | |||
} | |||
/*!
 * \brief Forward cubic resize with explicit source strides.
 *
 * Thread layout: x indexes the output column (ow), y the output row (oh) and
 * blockIdx.z the batch item. Each in-range thread loops over all C channels,
 * accumulates the 4x4 weighted taps in float and stores the value through
 * OutputConverter. Tap coordinates are clamped into [0, IH - 1] x [0, IW - 1].
 *
 * \param src visitor whose get(batch, S_IN) yields the source pointer
 * \param dst output written as [n][c][oh][ow] with dense OH * OW planes
 * \param S_IN, S_IC, S_IH, S_IW source strides for batch, channel, row, column
 */
template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C,
                                   int IH, int IW, int OH, int OW, int S_IN,
                                   int S_IC, int S_IH, int S_IW, float scale_h,
                                   float scale_w) {
    OutputConverter output_converter;
    const int ow = blockIdx.x * blockDim.x + threadIdx.x;
    const int oh = blockIdx.y * blockDim.y + threadIdx.y;
    const ctype* __restrict sptr = src.get(blockIdx.z, S_IN);
    dst += blockIdx.z * C * OH * OW;
    if (ow >= OW || oh >= OH) {
        return;
    }

    float frac_h, frac_w;
    int top, left;
    get_origin_coord(scale_h, IH, oh, frac_h, top, true);
    get_origin_coord(scale_w, IW, ow, frac_w, left, true);
    // The 4-tap window starts one sample before the base index.
    top -= 1;
    left -= 1;

    float h_coeff[4], w_coeff[4];
    interpolate_cubic(frac_h, h_coeff);
    interpolate_cubic(frac_w, w_coeff);

    constexpr int ksize = 4;
    const int out_off = oh * OW + ow;
    for (int c = 0; c < C; ++c) {
        float acc = 0;
        for (int kh = 0; kh < ksize; ++kh) {
            const int ih = saturate(top + kh, 0, IH - 1);
            for (int kw = 0; kw < ksize; ++kw) {
                const int iw = saturate(left + kw, 0, IW - 1);
                acc += sptr[ih * S_IH + iw * S_IW] * h_coeff[kh] *
                       w_coeff[kw];
            }
        }
        dst[out_off] = output_converter(acc);

        // Advance to the next channel plane.
        sptr += S_IC;
        dst += OH * OW;
    }
}
template <typename ctype, typename SrcVisitor, typename OutputConverter> | |||
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, | |||
int IH, int IW, int OH, int OW, float scale_h, | |||
float scale_w) { | |||
@@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, | |||
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, | |||
OW, scale_h, scale_w); | |||
} else { | |||
if (imode == InterpolationMode::INTER_LINEAR) { | |||
kern_general_linear<ctype, SrcVisitor, | |||
rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, | |||
S_IW, scale_h, scale_w); | |||
} else if (imode == InterpolationMode::INTER_NEAREST) { | |||
kern_general_nearest<ctype, SrcVisitor, | |||
rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, | |||
S_IW, scale_h, scale_w); | |||
switch (imode) { | |||
case InterpolationMode::INTER_LINEAR: | |||
kern_general_linear<ctype, SrcVisitor, | |||
rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, scale_h, scale_w); | |||
break; | |||
case InterpolationMode::INTER_NEAREST: | |||
kern_general_nearest<ctype, SrcVisitor, | |||
rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, scale_h, scale_w); | |||
break; | |||
case InterpolationMode::INTER_CUBIC: | |||
kern_general_cubic<ctype, SrcVisitor, | |||
rounding::RoundingConverter<ctype>> | |||
<<<blocks, threads, 0, stream>>>( | |||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||
S_IH, S_IW, scale_h, scale_w); | |||
break; | |||
default: | |||
megdnn_throw("unsupported interpolation mode"); | |||
break; | |||
} | |||
} | |||
N -= curr_batch_size; | |||
@@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, | |||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | |||
__global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, | |||
int IH, int IW, int OH, int OW, float scale_h, | |||
float scale_w) { | |||
int IH, int IW, int OH, int OW, | |||
float scale_h, float scale_w) { | |||
OutputConverter output_converter; | |||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
@@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, | |||
#pragma unroll | |||
for (int c1 = 0; c1 < 4; ++c1) { | |||
dst[o_coor + c1] = output_converter( | |||
sptr[i_coor00 + c1] * (1.0f - alphaw) * (1.0f - alphah) + | |||
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + | |||
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + | |||
sptr[i_coor11 + c1] * alphaw * alphah); | |||
sptr[i_coor00 + c1] * (1.0f - alphaw) * | |||
(1.0f - alphah) + | |||
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + | |||
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + | |||
sptr[i_coor11 + c1] * alphaw * alphah); | |||
} | |||
dst += OH * OW * 4; | |||
sptr += IH * IW * 4; | |||
@@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, | |||
after_kernel_launch(); | |||
} | |||
#define INST(ctype) \ | |||
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \ | |||
int, int, int, int, int, int, int, \ | |||
cudaStream_t); | |||
#define INST(ctype) \ | |||
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \ | |||
int, int, int, int, int, int, int, int, int, \ | |||
int, cudaStream_t); | |||
INST(float) | |||
INST(uint8_t) | |||
INST(int8_t) | |||
#undef INST | |||
#define INST(ctype) \ | |||
#define INST(ctype) \ | |||
template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \ | |||
int, int, int, cudaStream_t) | |||
int, int, int, cudaStream_t) | |||
INST(int8_t); | |||
#undef INST | |||
@@ -59,12 +59,14 @@ | |||
* --------------------------------------------------------------------------- | |||
*/ | |||
#include "src/cuda/cv/kernel_common.cuh" | |||
#include "src/common/resize.cuh" | |||
#include "src/cuda/resize/resize_cv.cuh" | |||
#include "src/cuda/utils.cuh" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace megcv; | |||
using megdnn::resize::interpolate_cubic; | |||
namespace { | |||
@@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale, | |||
fr -= sr[tid]; | |||
float coef[4]; | |||
interpolate_cubic_coefs(fr, coef); | |||
interpolate_cubic(fr, coef); | |||
#pragma unroll | |||
for (int j = 0, index = 0; j < 4; j++, index += size) { | |||
dst[tid + index] = coef[j]; | |||
@@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) { | |||
fr -= sr[tid]; | |||
float coef[4]; | |||
interpolate_cubic_coefs(fr, coef); | |||
interpolate_cubic(fr, coef); | |||
#pragma unroll | |||
for (int j = 0, index = 0; j < 4; j++, index += size) { | |||
dst[tid + index] = (short)(coef[j] * ONE); | |||
@@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector( | |||
int sc = floor(fc); | |||
fc -= sc; | |||
float coef_col[4]; | |||
interpolate_cubic_coefs(fc, coef_col); | |||
interpolate_cubic(fc, coef_col); | |||
for (int i = 0; i < ELEMENTS_PER_THREADS; i++) { | |||
if (dr >= dst_rows) | |||
@@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector( | |||
int sr = floor(fr); | |||
fr -= sr; | |||
float coef_row[4]; | |||
interpolate_cubic_coefs(fr, coef_row); | |||
interpolate_cubic(fr, coef_row); | |||
float dst_data[CH] = {0}; | |||
#pragma unroll | |||
for (int offset_r = 0; offset_r < 4; ++offset_r) { | |||
@@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector( | |||
short icoef_col[4] = {0}; | |||
float coef_col[4]; | |||
interpolate_cubic_coefs(fc, coef_col); | |||
interpolate_cubic(fc, coef_col); | |||
#pragma unroll | |||
for (int i = 0; i < 4; i++) { | |||
icoef_col[i] = (short)(coef_col[i] * ONE); | |||
@@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector( | |||
fr -= sr; | |||
short icoef_row[4]; | |||
float coef_row[4]; | |||
interpolate_cubic_coefs(fr, coef_row); | |||
interpolate_cubic(fr, coef_row); | |||
#pragma unroll | |||
for (int i = 0; i < 4; i++) { | |||
icoef_row[i] = (short)(coef_row[i] * ONE); | |||
@@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
if (param().format == param::Resize::Format::NCHW4 || | |||
(param().format == param::Resize::Format::NCHW && | |||
param().imode == param::Resize::InterpolationMode::NEAREST)) { | |||
param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) { | |||
naive::ResizeImpl::exec(src, dst, workspace); | |||
return; | |||
} | |||
@@ -9,18 +9,21 @@ | |||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/naive/resize/opr_impl.h" | |||
#include "midout.h" | |||
#include "src/common/cv/enums.h" | |||
#include "src/common/resize.cuh" | |||
#include "src/common/rounding_converter.cuh" | |||
#include "src/common/utils.cuh" | |||
#include "src/naive/handle.h" | |||
#include "src/naive/resize/opr_impl.h" | |||
#include "src/naive/resize/resize_cv.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_naive_resize_layout) | |||
MIDOUT_DECL(megdnn_naive_resize_layout_nearest) | |||
MIDOUT_DECL(megdnn_naive_resize_nchw) | |||
using namespace megdnn; | |||
using namespace naive; | |||
using namespace resize; | |||
template <typename ctype> | |||
ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors( | |||
@@ -90,20 +93,84 @@ INST(dt_quint8); | |||
#undef INST | |||
template <typename ctype> | |||
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) { | |||
void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param, | |||
InterpolationMode imode) { | |||
megdnn_assert(kern_param.format == Format::NCHW); | |||
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); | |||
float scale_h = static_cast<float>(OH) / IH; | |||
float scale_w = static_cast<float>(OW) / IW; | |||
rounding::RoundingConverter<ctype> output_converter; | |||
rep(n, N) { | |||
rep(oh, OH) rep(ow, OW) { | |||
auto ih = get_nearest_src(scale_h, IH, oh); | |||
auto iw = get_nearest_src(scale_w, IW, ow); | |||
switch (imode) { | |||
case InterpolationMode::NEAREST: { | |||
auto ih = get_nearest_src(scale_h, IH, oh); | |||
auto iw = get_nearest_src(scale_w, IW, ow); | |||
rep(c, static_cast<int>(C)) { | |||
dptr[c * OH * OW + oh * OW + ow] = | |||
sptr[c * S_IC + ih * S_IH + iw * S_IW]; | |||
} | |||
break; | |||
} | |||
case InterpolationMode::INTER_LINEAR: { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
int ih0 = coord_h.second; | |||
int ih1 = ih0 + 1; | |||
int iw0 = coord_w.second; | |||
int iw1 = iw0 + 1; | |||
rep(c, static_cast<int>(C)) { | |||
dptr[c * OH * OW + oh * OW + ow] = output_converter( | |||
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * | |||
(1.0f - alphaw) * (1.0f - alphah) + | |||
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * | |||
alphaw * (1.0f - alphah) + | |||
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * | |||
(1.0f - alphaw) * alphah + | |||
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * | |||
alphaw * alphah); | |||
} | |||
break; | |||
} | |||
case InterpolationMode::INTER_CUBIC: { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh, true); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow, true); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
rep(c, static_cast<int>(C)) { | |||
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW]; | |||
int ih0 = coord_h.second - 1; | |||
int iw0 = coord_w.second - 1; | |||
float h_coeff[4], w_coeff[4]; | |||
interpolate_cubic(alphah, h_coeff); | |||
interpolate_cubic(alphaw, w_coeff); | |||
rep(c, static_cast<int>(C)) { | |||
constexpr int ksize = 4; | |||
float ret = 0; | |||
rep(kh, ksize) { | |||
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); | |||
rep(kw, ksize) { | |||
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); | |||
ret += sptr[c * S_IC + h * S_IH + w * S_IW] * | |||
h_coeff[kh] * w_coeff[kw]; | |||
} | |||
} | |||
dptr[c * OH * OW + oh * OW + ow] = | |||
output_converter(ret); | |||
} | |||
break; | |||
} | |||
default: | |||
megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||
break; | |||
} | |||
} | |||
sptr += S_IN; | |||
@@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) { | |||
MIDOUT_END(); | |||
return; | |||
} | |||
megdnn_assert(kern_param.format == Format::NCHW); | |||
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); | |||
rounding::RoundingConverter<ctype> output_converter; | |||
float scale_h = static_cast<float>(OH) / IH; | |||
float scale_w = static_cast<float>(OW) / IW; | |||
rep(n, N) { | |||
rep(oh, OH) rep(ow, OW) { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
int ih0 = coord_h.second; | |||
int ih1 = ih0 + 1; | |||
int iw0 = coord_w.second; | |||
int iw1 = iw0 + 1; | |||
rep(c, static_cast<int>(C)) { | |||
dptr[c * OH * OW + oh * OW + ow] = output_converter( | |||
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * | |||
(1.0f - alphaw) * (1.0f - alphah) + | |||
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * alphaw * | |||
(1.0f - alphah) + | |||
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * | |||
(1.0f - alphaw) * alphah + | |||
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * alphaw * | |||
alphah); | |||
} | |||
} | |||
sptr += S_IN; | |||
dptr += C * OH * OW; | |||
} | |||
} | |||
template <typename ctype> | |||
@@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) { | |||
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||
_megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
if (param().format == param::Resize::Format::NCHW && | |||
param().imode == param::Resize::InterpolationMode::NEAREST) { | |||
#define cb(dt, ct, _midout_iv) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \ | |||
midout_iv(_midout_iv)) { \ | |||
auto kparam = KernParam<ct>::from_tensors(param().format, src, \ | |||
dst, workspace); \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return; \ | |||
if (param().format == param::Resize::Format::NCHW) { | |||
#define cb(dt, ct, _midout_iv) \ | |||
case DTypeTrait<dt>::enumv: { \ | |||
MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \ | |||
auto kparam = KernParam<ct>::from_tensors(param().format, src, \ | |||
dst, workspace); \ | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \ | |||
} \ | |||
MIDOUT_END(); \ | |||
return; \ | |||
} | |||
switch (src.layout.dtype.enumv()) { | |||
@@ -320,11 +351,9 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||
} | |||
#undef cb | |||
#undef cb | |||
} | |||
if ((param().format == param::Resize::Format::NCHW || | |||
(src.layout[3] != 1 && src.layout[3] != 3) || | |||
if (((src.layout[3] != 1 && src.layout[3] != 3) || | |||
!is_nhwc_contig_wc(src.layout)) || | |||
(param().imode == param::Resize::InterpolationMode::LINEAR)) { | |||
#define cb(dt, ct, _midout_iv) \ | |||
@@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); | |||
rep(n, N) { | |||
rep(oh, OH) rep(ow, OW) { | |||
if(param().imode == InterpolationMode::INTER_LINEAR) { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
int ih0 = coord_h.second; | |||
int ih1 = ih0 + 1; | |||
int iw0 = coord_w.second; | |||
int iw1 = iw0 + 1; | |||
rep(c, C) { | |||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||
sptr[c * IH * IW + ih0 * IW + iw0] += | |||
(1.0f - alphaw) * (1.0f - alphah) * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw0] += | |||
(1.0f - alphaw) * alphah * hidden; | |||
sptr[c * IH * IW + ih0 * IW + iw1] += | |||
alphaw * (1.0f - alphah) * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw1] += | |||
alphaw * alphah * hidden; | |||
switch (param().imode) { | |||
case InterpolationMode::INTER_LINEAR: { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
int ih0 = coord_h.second; | |||
int ih1 = ih0 + 1; | |||
int iw0 = coord_w.second; | |||
int iw1 = iw0 + 1; | |||
rep(c, C) { | |||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||
sptr[c * IH * IW + ih0 * IW + iw0] += | |||
(1.0f - alphaw) * (1.0f - alphah) * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw0] += | |||
(1.0f - alphaw) * alphah * hidden; | |||
sptr[c * IH * IW + ih0 * IW + iw1] += | |||
alphaw * (1.0f - alphah) * hidden; | |||
sptr[c * IH * IW + ih1 * IW + iw1] += | |||
alphaw * alphah * hidden; | |||
} | |||
break; | |||
} | |||
} else if (param().imode == InterpolationMode::NEAREST) { | |||
auto ih = get_nearest_src(scale_h, IH, oh); | |||
auto iw = get_nearest_src(scale_w, IW, ow); | |||
rep(c, static_cast<int>(C)) { | |||
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow]; | |||
case InterpolationMode::NEAREST: { | |||
auto ih = get_nearest_src(scale_h, IH, oh); | |||
auto iw = get_nearest_src(scale_w, IW, ow); | |||
rep(c, static_cast<int>(C)) { | |||
sptr[c * IH * IW + ih * IW + iw] += | |||
hptr[c * OH * OW + oh * OW + ow]; | |||
} | |||
break; | |||
} | |||
case InterpolationMode::INTER_CUBIC: { | |||
auto coord_h = get_origin_coord(scale_h, IH, oh, true); | |||
auto coord_w = get_origin_coord(scale_w, IW, ow, true); | |||
float alphah = coord_h.first; | |||
float alphaw = coord_w.first; | |||
int ih0 = coord_h.second - 1; | |||
int iw0 = coord_w.second - 1; | |||
float h_coeff[4], w_coeff[4]; | |||
interpolate_cubic(alphah, h_coeff); | |||
interpolate_cubic(alphaw, w_coeff); | |||
rep(c, static_cast<int>(C)) { | |||
constexpr int ksize = 4; | |||
rep(kh, ksize) { | |||
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); | |||
rep(kw, ksize) { | |||
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); | |||
sptr[c * IH * IW + h * IW + w] += | |||
hptr[c * OH * OW + oh * OW + ow] * | |||
h_coeff[kh] * w_coeff[kw]; | |||
} | |||
} | |||
} | |||
break; | |||
} | |||
default: { | |||
megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||
break; | |||
} | |||
} | |||
else megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||
} | |||
sptr += C * IH * IW; | |||
hptr += C * OH * OW; | |||
@@ -47,7 +47,7 @@ private: | |||
void kern_naive(const KernParam<ctype>& kern_param); | |||
template <typename ctype> | |||
void kern_nchw_nearest(const KernParam<ctype>& kern_param); | |||
void kern_nchw(const KernParam<ctype>& kern_param, InterpolationMode imode); | |||
template <typename ctype> | |||
void kern_naive_nhwc(const KernParam<ctype>& kern_param); | |||
@@ -68,6 +68,7 @@ | |||
#include "src/common/cv/helper.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/handle.h" | |||
#include "src/common/resize.cuh" | |||
MIDOUT_DECL(megdnn_naive_resizecv_imode) | |||
MIDOUT_DECL(megdnn_naive_resizecv_dtype) | |||
@@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype) | |||
using namespace megdnn; | |||
using namespace naive; | |||
using namespace megcv; | |||
using namespace megdnn::resize; | |||
namespace { | |||
@@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat<T>& src, Mat<T>& dst, | |||
const DecimateAlpha* ytab, int ytab_size, | |||
const int* yofs); | |||
//! Compute the four cubic-interpolation tap weights (kernel parameter
//! A = -0.75) for fractional offset \p x; the last weight is derived so that
//! the four weights sum to 1.
static inline void interpolate_cubic(float x, float* coeffs) {
    const float A = -0.75f;
    const float outer = x + 1;  // argument of the polynomial for tap 0
    const float mirror = 1 - x; // argument of the polynomial for tap 2

    const float w0 = ((A * outer - 5 * A) * outer + 8 * A) * outer - 4 * A;
    const float w1 = ((A + 2) * x - (A + 3)) * x * x + 1;
    const float w2 = ((A + 2) * mirror - (A + 3)) * mirror * mirror + 1;

    coeffs[0] = w0;
    coeffs[1] = w1;
    coeffs[2] = w2;
    coeffs[3] = 1.f - w0 - w1 - w2;
}
static inline void interpolate_lanczos4(float x, float* coeffs) { | |||
static const double s45 = 0.70710678118654752440084436210485; | |||
static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | |||
@@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) { | |||
TEST_F(CUDA, RESIZE_FORWARD) { | |||
using namespace resize; | |||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||
for (auto imode : modes) { | |||
std::vector<TestArg> args = get_args(imode); | |||
Checker<Resize> checker(handle_cuda()); | |||
@@ -88,7 +88,7 @@ TEST_F(CUDA, RESIZE_NCHW4) { | |||
} | |||
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { | |||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||
for (auto imode : modes) { | |||
param::Resize param; | |||
param.format = param::Resize::Format::NCHW; | |||
@@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { | |||
} | |||
TEST_F(CUDA, RESIZE_BACKWARD) { | |||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||
for (auto imode : modes) { | |||
Checker<ResizeBackward> checker(handle_cuda()); | |||
param::Resize param; | |||
@@ -574,19 +574,25 @@ def interpolate( | |||
raise ValueError("under linear mode, size can only be single value") | |||
dsize = size | |||
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]: | |||
if not align_corners: | |||
# fastpath for interpolate | |||
op = builtin.Resize( | |||
imode="linear" if mode == "bilinear" else "nearest", format="NCHW" | |||
) | |||
mode_map = { | |||
"linear": "linear", | |||
"bilinear": "linear", | |||
"nearest": "nearest", | |||
"bicubic": "cubic", | |||
} | |||
op = builtin.Resize(imode=mode_map[mode], format="NCHW") | |||
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) | |||
(result,) = apply(op, inp, shape) | |||
return result | |||
oh, ow = dsize[0], dsize[1] | |||
ih, iw = inp.shape[2], inp.shape[3] | |||
if align_corners: | |||
(ret,) = apply(op, inp, shape) | |||
else: | |||
assert mode in [ | |||
"linear", | |||
"bilinear", | |||
], "align_corners only support linear or bilinear mode" | |||
oh, ow = dsize[0], dsize[1] | |||
ih, iw = inp.shape[2], inp.shape[3] | |||
hscale = (ih - 1.0) / (oh - 1.0) | |||
wscale = 1.0 * iw / ow | |||
if mode != "linear": | |||
@@ -607,34 +613,11 @@ def interpolate( | |||
axis=0, | |||
).reshape(1, 3, 3) | |||
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | |||
else: | |||
hscale = 1.0 * ih / oh | |||
wscale = 1.0 * iw / ow | |||
row0 = concat( | |||
[wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5], | |||
axis=0, | |||
).reshape(1, 3) | |||
row1 = concat( | |||
[Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5], | |||
axis=0, | |||
).reshape(1, 3) | |||
weight = concat( | |||
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], | |||
axis=0, | |||
).reshape(1, 3, 3) | |||
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | |||
weight = weight.astype("float32") | |||
if mode in ["linear", "bilinear"]: | |||
ret = warp_perspective(inp, weight, dsize, interp_mode="linear") | |||
if mode == "linear": | |||
ret = reshape(ret, ret.shape[0:3]) | |||
else: | |||
# only NHWC format support "cubic" mode | |||
assert mode == "bicubic" | |||
inp = transpose(inp, (0, 2, 3, 1)) | |||
ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",) | |||
ret = transpose(ret, (0, 3, 1, 2)) | |||
if mode == "linear": | |||
ret = reshape(ret, ret.shape[0:3]) | |||
return ret | |||