GitOrigin-RevId: f7023a3fd3
tags/v1.6.0-rc1
@@ -197,7 +197,7 @@ public: | |||||
protected: | protected: | ||||
//! get origin coord | //! get origin coord | ||||
std::pair<float, int> get_origin_coord(float scale, int size, int idx); | |||||
std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false); | |||||
//! get nearest index in src | //! get nearest index in src | ||||
int get_nearest_src(float scale, int size, int idx); | int get_nearest_src(float scale, int size, int idx); | ||||
@@ -11,6 +11,7 @@ | |||||
*/ | */ | ||||
#include "megdnn/handle.h" | #include "megdnn/handle.h" | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -29,8 +30,9 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, | |||||
if (param().format == Param::Format::NCHW) { | if (param().format == Param::Format::NCHW) { | ||||
megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); | megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); | ||||
auto imode = param().imode; | auto imode = param().imode; | ||||
megdnn_assert(imode == param::Resize::InterpolationMode::INTER_LINEAR || | |||||
imode == param::Resize::InterpolationMode::NEAREST); | |||||
using IMode = param::Resize::InterpolationMode; | |||||
megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST || | |||||
imode == IMode::INTER_CUBIC); | |||||
} else if (param().format == Param::Format::NHWC) { | } else if (param().format == Param::Format::NHWC) { | ||||
megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); | megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); | ||||
} else if (param().format == Param::Format::NCHW4) { | } else if (param().format == Param::Format::NCHW4) { | ||||
@@ -66,19 +68,20 @@ void ResizeBackward::check_exec(const TensorLayout& diff, | |||||
} | } | ||||
std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size, | std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size, | ||||
int idx) { | |||||
int idx, bool cubic) { | |||||
//! copy from resize_cv.cpp | //! copy from resize_cv.cpp | ||||
float alpha = (idx + 0.5f) / scale - 0.5f; | float alpha = (idx + 0.5f) / scale - 0.5f; | ||||
int origin_idx = static_cast<int>(floor(alpha)); | int origin_idx = static_cast<int>(floor(alpha)); | ||||
alpha -= origin_idx; | alpha -= origin_idx; | ||||
if (origin_idx < 0) { | |||||
origin_idx = 0; | |||||
alpha = 0; | |||||
} else if (origin_idx + 1 >= size) { | |||||
origin_idx = size - 2; | |||||
alpha = 1; | |||||
if (!cubic) { | |||||
if (origin_idx < 0) { | |||||
origin_idx = 0; | |||||
alpha = 0; | |||||
} else if (origin_idx + 1 >= size) { | |||||
origin_idx = size - 2; | |||||
alpha = 1; | |||||
} | |||||
} | } | ||||
return {alpha, origin_idx}; | return {alpha, origin_idx}; | ||||
} | } | ||||
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file dnn/src/common/resize.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/arch.h" | |||||
#if MEGDNN_CC_HOST && !defined(__host__) | |||||
#if __GNUC__ || __has_attribute(always_inline) | |||||
#define __forceinline__ inline __attribute__((always_inline)) | |||||
#else | |||||
#define __forceinline__ inline | |||||
#endif | |||||
#endif | |||||
namespace megdnn { | |||||
namespace resize { | |||||
MEGDNN_HOST MEGDNN_DEVICE __forceinline__ void interpolate_cubic( | |||||
float x, float* coeffs) { | |||||
const float A = -0.75f; | |||||
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; | |||||
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; | |||||
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; | |||||
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; | |||||
} | |||||
} // namespace resize | |||||
} // namespace megdnn | |||||
/* vim: set ft=cpp: */ |
@@ -71,7 +71,10 @@ struct RoundingConverter<uint8_t> { | |||||
__host__ __device__ __forceinline__ uint8_t operator()(float x) const { | __host__ __device__ __forceinline__ uint8_t operator()(float x) const { | ||||
#if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
using std::round; | using std::round; | ||||
using std::max; | |||||
using std::min; | |||||
#endif | #endif | ||||
x = min(255.0f, max(0.0f, x)); //! FIXME!!! check other places | |||||
return static_cast<uint8_t>(round(x)); | return static_cast<uint8_t>(round(x)); | ||||
} | } | ||||
}; | }; | ||||
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/common/cv/enums.h" | #include "src/common/cv/enums.h" | ||||
#include "src/common/resize.cuh" | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
@@ -49,15 +50,6 @@ __device__ inline void interpolate_linear_coefs(float x, float* coeffs) { | |||||
coeffs[1] = x; | coeffs[1] = x; | ||||
} | } | ||||
__host__ __device__ inline void interpolate_cubic_coefs(float x, | |||||
float* coeffs) { | |||||
const float A = -0.75f; | |||||
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; | |||||
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; | |||||
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; | |||||
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; | |||||
} | |||||
__device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) { | __device__ inline void interpolate_lanczos4_coefs(float x, float* coeffs) { | ||||
const float s45 = 0.70710678118654752440084436210485; | const float s45 = 0.70710678118654752440084436210485; | ||||
const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | const float cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | ||||
@@ -197,7 +189,7 @@ __device__ inline void interpolate_coefs<INTER_LINEAR>(float x, float* coeffs) { | |||||
} | } | ||||
template <> | template <> | ||||
__device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) { | __device__ inline void interpolate_coefs<INTER_CUBIC>(float x, float* coeffs) { | ||||
interpolate_cubic_coefs(x, coeffs); | |||||
megdnn::resize::interpolate_cubic(x, coeffs); | |||||
} | } | ||||
template <> | template <> | ||||
__device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x, | __device__ inline void interpolate_coefs<INTER_LANCZOS4>(float x, | ||||
@@ -12,6 +12,10 @@ | |||||
#include "src/cuda/resize/common.h" | #include "src/cuda/resize/common.h" | ||||
#include "src/cuda/utils.cuh" | #include "src/cuda/utils.cuh" | ||||
#include "src/cuda/cv/kernel_common.cuh" | |||||
using megdnn::resize::interpolate_cubic; | |||||
using megdnn::megcv::saturate; | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace cuda { | namespace cuda { | ||||
@@ -72,6 +76,42 @@ __global__ void resize_bwd_nearest_kernel(const float* hidden, float* dst, | |||||
} | } | ||||
} | } | ||||
} | } | ||||
__global__ void resize_bwd_cubic_kernel(const float* hidden, float* dst, int N, | |||||
int C, int IH, int IW, int OH, int OW, | |||||
float scale_h, float scale_w) { | |||||
int n = blockIdx.z; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
hidden += n * C * OH * OW; | |||||
dst += n * C * IH * IW; | |||||
if (ow < OW && oh < OH) { | |||||
float alphah, alphaw; | |||||
int ih0, iw0; | |||||
get_origin_coord(scale_h, IH, oh, alphah, ih0, true); | |||||
get_origin_coord(scale_w, IW, ow, alphaw, iw0, true); | |||||
ih0--; | |||||
iw0--; | |||||
float h_coeff[4], w_coeff[4]; | |||||
interpolate_cubic(alphah, h_coeff); | |||||
interpolate_cubic(alphaw, w_coeff); | |||||
for (int c = 0; c < C; ++c) { | |||||
constexpr int ksize = 4; | |||||
for (int kh = 0; kh < ksize; kh++) { | |||||
int ih = saturate(ih0 + kh, 0, IH - 1); | |||||
for (int kw = 0; kw < ksize; kw++) { | |||||
int iw = saturate(iw0 + kw, 0, IW - 1); | |||||
atomicAdd(dst + ih * IW + iw, | |||||
hidden[oh * OW + ow] * h_coeff[kh] * w_coeff[kw]); | |||||
} | |||||
} | |||||
hidden += OH * OW; | |||||
dst += IH * IW; | |||||
} | |||||
} | |||||
} | |||||
void backward_data_proxy(InterpolationMode imode, const float* diff, | void backward_data_proxy(InterpolationMode imode, const float* diff, | ||||
float* grad, int N, int C, int IH, int IW, int OH, | float* grad, int N, int C, int IH, int IW, int OH, | ||||
int OW, cudaStream_t stream) { | int OW, cudaStream_t stream) { | ||||
@@ -83,13 +123,26 @@ void backward_data_proxy(InterpolationMode imode, const float* diff, | |||||
stream)); | stream)); | ||||
float scale_h = static_cast<float>(OH) / IH; | float scale_h = static_cast<float>(OH) / IH; | ||||
float scale_w = static_cast<float>(OW) / IW; | float scale_w = static_cast<float>(OW) / IW; | ||||
if(imode == InterpolationMode::INTER_LINEAR) { | |||||
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( | |||||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||||
} | |||||
else if (imode == InterpolationMode::INTER_NEAREST) { | |||||
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( | |||||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||||
switch (imode) { | |||||
case InterpolationMode::INTER_LINEAR: { | |||||
resize_bwd_linear_kernel<<<blocks, threads, 0, stream>>>( | |||||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||||
break; | |||||
} | |||||
case InterpolationMode::INTER_NEAREST: { | |||||
resize_bwd_nearest_kernel<<<blocks, threads, 0, stream>>>( | |||||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||||
break; | |||||
} | |||||
case InterpolationMode::INTER_CUBIC: { | |||||
resize_bwd_cubic_kernel<<<blocks, threads, 0, stream>>>( | |||||
diff, grad, N, C, IH, IW, OH, OW, scale_h, scale_w); | |||||
break; | |||||
} | |||||
default: { | |||||
megdnn_throw("unsupported interpolation mode"); | |||||
break; | |||||
} | |||||
} | } | ||||
} | } | ||||
after_kernel_launch(); | after_kernel_launch(); | ||||
@@ -15,16 +15,19 @@ namespace cuda { | |||||
namespace resize { | namespace resize { | ||||
__device__ inline void get_origin_coord(float scale, int size, int idx, | __device__ inline void get_origin_coord(float scale, int size, int idx, | ||||
float& alpha, int& origin_idx) { | |||||
float& alpha, int& origin_idx, | |||||
bool cubic = false) { | |||||
alpha = (idx + 0.5f) / scale - 0.5f; | alpha = (idx + 0.5f) / scale - 0.5f; | ||||
origin_idx = static_cast<int>(floor(alpha)); | origin_idx = static_cast<int>(floor(alpha)); | ||||
alpha -= origin_idx; | alpha -= origin_idx; | ||||
if (origin_idx < 0) { | |||||
origin_idx = 0; | |||||
alpha = 0; | |||||
} else if (origin_idx + 1 >= size) { | |||||
origin_idx = size - 2; | |||||
alpha = 1; | |||||
if (!cubic) { | |||||
if (origin_idx < 0) { | |||||
origin_idx = 0; | |||||
alpha = 0; | |||||
} else if (origin_idx + 1 >= size) { | |||||
origin_idx = size - 2; | |||||
alpha = 1; | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -147,9 +147,11 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||||
C, IH, IW, OH, OW, stream); | C, IH, IW, OH, OW, stream); | ||||
return; | return; | ||||
} | } | ||||
megdnn_assert(param().imode == Param::InterpolationMode::LINEAR || | |||||
param().imode == Param::InterpolationMode::NEAREST, | |||||
"unsupported interpolation mode for NCHW format"); | |||||
megdnn_assert( | |||||
param().imode == Param::InterpolationMode::LINEAR || | |||||
param().imode == Param::InterpolationMode::NEAREST || | |||||
param().imode == Param::InterpolationMode::INTER_CUBIC, | |||||
"unsupported interpolation mode for NCHW format"); | |||||
if (src.layout.dtype == dtype::Float32{}) { | if (src.layout.dtype == dtype::Float32{}) { | ||||
resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)), | resize::forward_proxy(is_nhwc, resize::get_imode((param().imode)), | ||||
@@ -8,15 +8,20 @@ | |||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/common/rounding_converter.cuh" | |||||
#include "src/common/utils.cuh" | |||||
#include "src/cuda/resize/common.cuh" | #include "src/cuda/resize/common.cuh" | ||||
#include "src/cuda/resize/common.h" | #include "src/cuda/resize/common.h" | ||||
#include "src/common/rounding_converter.cuh" | |||||
#include "src/cuda/resize/resize_cv.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
#include "src/cuda/cv/kernel_common.cuh" | |||||
#include "src/common/resize.cuh" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
using namespace resize; | |||||
using namespace megdnn::cuda::resize; | |||||
using megdnn::resize::interpolate_cubic; | |||||
using megdnn::megcv::saturate; | |||||
namespace { | namespace { | ||||
@@ -81,8 +86,7 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, | |||||
int iw = get_nearest_src(scale_w, IW, ow); | int iw = get_nearest_src(scale_w, IW, ow); | ||||
for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
dst[oh * OW + ow] = output_converter( | |||||
sptr[ih * S_IH + iw * S_IW]); | |||||
dst[oh * OW + ow] = output_converter(sptr[ih * S_IH + iw * S_IW]); | |||||
sptr += S_IC; | sptr += S_IC; | ||||
dst += OH * OW; | dst += OH * OW; | ||||
@@ -91,6 +95,45 @@ __global__ void kern_general_nearest(SrcVisitor src, ctype* __restrict dst, | |||||
} | } | ||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | template <typename ctype, typename SrcVisitor, typename OutputConverter> | ||||
__global__ void kern_general_cubic(SrcVisitor src, ctype* __restrict dst, int C, | |||||
int IH, int IW, int OH, int OW, int S_IN, | |||||
int S_IC, int S_IH, int S_IW, float scale_h, | |||||
float scale_w) { | |||||
OutputConverter output_converter; | |||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
const ctype* __restrict sptr = src.get(blockIdx.z, S_IN); | |||||
dst += blockIdx.z * C * OH * OW; | |||||
if (ow < OW && oh < OH) { | |||||
float alphah, alphaw; | |||||
int ih0, iw0; | |||||
get_origin_coord(scale_h, IH, oh, alphah, ih0, true); | |||||
get_origin_coord(scale_w, IW, ow, alphaw, iw0, true); | |||||
ih0--; | |||||
iw0--; | |||||
float h_coeff[4], w_coeff[4]; | |||||
interpolate_cubic(alphah, h_coeff); | |||||
interpolate_cubic(alphaw, w_coeff); | |||||
for (int c = 0; c < C; ++c) { | |||||
float ret = 0; | |||||
constexpr int ksize = 4; | |||||
for (int kh = 0; kh < ksize; kh++) { | |||||
int ih = saturate(ih0 + kh, 0, IH - 1); | |||||
for (int kw = 0; kw < ksize; kw++) { | |||||
int iw = saturate(iw0 + kw, 0, IW - 1); | |||||
ret += sptr[ih * S_IH + iw * S_IW] * h_coeff[kh] * | |||||
w_coeff[kw]; | |||||
} | |||||
} | |||||
dst[oh * OW + ow] = output_converter(ret); | |||||
sptr += S_IC; | |||||
dst += OH * OW; | |||||
} | |||||
} | |||||
} | |||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | |||||
__global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, | __global__ void kern_general_nhwc(SrcVisitor src, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW, float scale_h, | int IH, int IW, int OH, int OW, float scale_h, | ||||
float scale_w) { | float scale_w) { | ||||
@@ -140,18 +183,31 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, | |||||
<<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, | <<<blocks, threads, 0, stream>>>(src, dst, C, IH, IW, OH, | ||||
OW, scale_h, scale_w); | OW, scale_h, scale_w); | ||||
} else { | } else { | ||||
if (imode == InterpolationMode::INTER_LINEAR) { | |||||
kern_general_linear<ctype, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, | |||||
S_IW, scale_h, scale_w); | |||||
} else if (imode == InterpolationMode::INTER_NEAREST) { | |||||
kern_general_nearest<ctype, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, S_IH, | |||||
S_IW, scale_h, scale_w); | |||||
switch (imode) { | |||||
case InterpolationMode::INTER_LINEAR: | |||||
kern_general_linear<ctype, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||||
S_IH, S_IW, scale_h, scale_w); | |||||
break; | |||||
case InterpolationMode::INTER_NEAREST: | |||||
kern_general_nearest<ctype, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||||
S_IH, S_IW, scale_h, scale_w); | |||||
break; | |||||
case InterpolationMode::INTER_CUBIC: | |||||
kern_general_cubic<ctype, SrcVisitor, | |||||
rounding::RoundingConverter<ctype>> | |||||
<<<blocks, threads, 0, stream>>>( | |||||
src, dst, C, IH, IW, OH, OW, S_IN, S_IC, | |||||
S_IH, S_IW, scale_h, scale_w); | |||||
break; | |||||
default: | |||||
megdnn_throw("unsupported interpolation mode"); | |||||
break; | |||||
} | } | ||||
} | } | ||||
N -= curr_batch_size; | N -= curr_batch_size; | ||||
@@ -162,8 +218,8 @@ void dispatch_with_visitor(bool is_nhwc, InterpolationMode imode, | |||||
template <typename ctype, typename SrcVisitor, typename OutputConverter> | template <typename ctype, typename SrcVisitor, typename OutputConverter> | ||||
__global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, | __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, | ||||
int IH, int IW, int OH, int OW, float scale_h, | |||||
float scale_w) { | |||||
int IH, int IW, int OH, int OW, | |||||
float scale_h, float scale_w) { | |||||
OutputConverter output_converter; | OutputConverter output_converter; | ||||
int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
int oh = blockIdx.y * blockDim.y + threadIdx.y; | int oh = blockIdx.y * blockDim.y + threadIdx.y; | ||||
@@ -188,10 +244,11 @@ __global__ void kern_general_nchw4(SrcVisitor src, ctype* __restrict dst, int C, | |||||
#pragma unroll | #pragma unroll | ||||
for (int c1 = 0; c1 < 4; ++c1) { | for (int c1 = 0; c1 < 4; ++c1) { | ||||
dst[o_coor + c1] = output_converter( | dst[o_coor + c1] = output_converter( | ||||
sptr[i_coor00 + c1] * (1.0f - alphaw) * (1.0f - alphah) + | |||||
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + | |||||
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + | |||||
sptr[i_coor11 + c1] * alphaw * alphah); | |||||
sptr[i_coor00 + c1] * (1.0f - alphaw) * | |||||
(1.0f - alphah) + | |||||
sptr[i_coor01 + c1] * alphaw * (1.0f - alphah) + | |||||
sptr[i_coor10 + c1] * (1.0f - alphaw) * alphah + | |||||
sptr[i_coor11 + c1] * alphaw * alphah); | |||||
} | } | ||||
dst += OH * OW * 4; | dst += OH * OW * 4; | ||||
sptr += IH * IW * 4; | sptr += IH * IW * 4; | ||||
@@ -250,18 +307,18 @@ void forward_proxy_nchw4(const ctype* src, ctype* dst, int N, int C, int IH, | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
#define INST(ctype) \ | |||||
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, int, int, int, \ | |||||
int, int, int, int, int, int, int, \ | |||||
cudaStream_t); | |||||
#define INST(ctype) \ | |||||
template void forward_proxy(bool, InterpolationMode, const ctype*, ctype*, \ | |||||
int, int, int, int, int, int, int, int, int, \ | |||||
int, cudaStream_t); | |||||
INST(float) | INST(float) | ||||
INST(uint8_t) | INST(uint8_t) | ||||
INST(int8_t) | INST(int8_t) | ||||
#undef INST | #undef INST | ||||
#define INST(ctype) \ | |||||
#define INST(ctype) \ | |||||
template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \ | template void forward_proxy_nchw4(const ctype*, ctype*, int, int, int, \ | ||||
int, int, int, cudaStream_t) | |||||
int, int, int, cudaStream_t) | |||||
INST(int8_t); | INST(int8_t); | ||||
#undef INST | #undef INST | ||||
@@ -59,12 +59,14 @@ | |||||
* --------------------------------------------------------------------------- | * --------------------------------------------------------------------------- | ||||
*/ | */ | ||||
#include "src/cuda/cv/kernel_common.cuh" | #include "src/cuda/cv/kernel_common.cuh" | ||||
#include "src/common/resize.cuh" | |||||
#include "src/cuda/resize/resize_cv.cuh" | #include "src/cuda/resize/resize_cv.cuh" | ||||
#include "src/cuda/utils.cuh" | #include "src/cuda/utils.cuh" | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
using namespace megcv; | using namespace megcv; | ||||
using megdnn::resize::interpolate_cubic; | |||||
namespace { | namespace { | ||||
@@ -126,7 +128,7 @@ __global__ void precompute_cubic_coef_f32(float* dst, float scale, | |||||
fr -= sr[tid]; | fr -= sr[tid]; | ||||
float coef[4]; | float coef[4]; | ||||
interpolate_cubic_coefs(fr, coef); | |||||
interpolate_cubic(fr, coef); | |||||
#pragma unroll | #pragma unroll | ||||
for (int j = 0, index = 0; j < 4; j++, index += size) { | for (int j = 0, index = 0; j < 4; j++, index += size) { | ||||
dst[tid + index] = coef[j]; | dst[tid + index] = coef[j]; | ||||
@@ -144,7 +146,7 @@ __global__ void precompute_cubic_coef_u8(short* dst, float scale, size_t size) { | |||||
fr -= sr[tid]; | fr -= sr[tid]; | ||||
float coef[4]; | float coef[4]; | ||||
interpolate_cubic_coefs(fr, coef); | |||||
interpolate_cubic(fr, coef); | |||||
#pragma unroll | #pragma unroll | ||||
for (int j = 0, index = 0; j < 4; j++, index += size) { | for (int j = 0, index = 0; j < 4; j++, index += size) { | ||||
dst[tid + index] = (short)(coef[j] * ONE); | dst[tid + index] = (short)(coef[j] * ONE); | ||||
@@ -406,7 +408,7 @@ __global__ void resize_cubic_32f_kernel_vector( | |||||
int sc = floor(fc); | int sc = floor(fc); | ||||
fc -= sc; | fc -= sc; | ||||
float coef_col[4]; | float coef_col[4]; | ||||
interpolate_cubic_coefs(fc, coef_col); | |||||
interpolate_cubic(fc, coef_col); | |||||
for (int i = 0; i < ELEMENTS_PER_THREADS; i++) { | for (int i = 0; i < ELEMENTS_PER_THREADS; i++) { | ||||
if (dr >= dst_rows) | if (dr >= dst_rows) | ||||
@@ -415,7 +417,7 @@ __global__ void resize_cubic_32f_kernel_vector( | |||||
int sr = floor(fr); | int sr = floor(fr); | ||||
fr -= sr; | fr -= sr; | ||||
float coef_row[4]; | float coef_row[4]; | ||||
interpolate_cubic_coefs(fr, coef_row); | |||||
interpolate_cubic(fr, coef_row); | |||||
float dst_data[CH] = {0}; | float dst_data[CH] = {0}; | ||||
#pragma unroll | #pragma unroll | ||||
for (int offset_r = 0; offset_r < 4; ++offset_r) { | for (int offset_r = 0; offset_r < 4; ++offset_r) { | ||||
@@ -459,7 +461,7 @@ __global__ void resize_cubic_8u_kernel_vector( | |||||
short icoef_col[4] = {0}; | short icoef_col[4] = {0}; | ||||
float coef_col[4]; | float coef_col[4]; | ||||
interpolate_cubic_coefs(fc, coef_col); | |||||
interpolate_cubic(fc, coef_col); | |||||
#pragma unroll | #pragma unroll | ||||
for (int i = 0; i < 4; i++) { | for (int i = 0; i < 4; i++) { | ||||
icoef_col[i] = (short)(coef_col[i] * ONE); | icoef_col[i] = (short)(coef_col[i] * ONE); | ||||
@@ -473,7 +475,7 @@ __global__ void resize_cubic_8u_kernel_vector( | |||||
fr -= sr; | fr -= sr; | ||||
short icoef_row[4]; | short icoef_row[4]; | ||||
float coef_row[4]; | float coef_row[4]; | ||||
interpolate_cubic_coefs(fr, coef_row); | |||||
interpolate_cubic(fr, coef_row); | |||||
#pragma unroll | #pragma unroll | ||||
for (int i = 0; i < 4; i++) { | for (int i = 0; i < 4; i++) { | ||||
icoef_row[i] = (short)(coef_row[i] * ONE); | icoef_row[i] = (short)(coef_row[i] * ONE); | ||||
@@ -118,7 +118,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||||
check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
if (param().format == param::Resize::Format::NCHW4 || | if (param().format == param::Resize::Format::NCHW4 || | ||||
(param().format == param::Resize::Format::NCHW && | (param().format == param::Resize::Format::NCHW && | ||||
param().imode == param::Resize::InterpolationMode::NEAREST)) { | |||||
param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) { | |||||
naive::ResizeImpl::exec(src, dst, workspace); | naive::ResizeImpl::exec(src, dst, workspace); | ||||
return; | return; | ||||
} | } | ||||
@@ -9,18 +9,21 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/naive/resize/opr_impl.h" | |||||
#include "midout.h" | |||||
#include "src/common/cv/enums.h" | |||||
#include "src/common/resize.cuh" | |||||
#include "src/common/rounding_converter.cuh" | #include "src/common/rounding_converter.cuh" | ||||
#include "src/common/utils.cuh" | #include "src/common/utils.cuh" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
#include "src/naive/resize/opr_impl.h" | |||||
#include "src/naive/resize/resize_cv.h" | #include "src/naive/resize/resize_cv.h" | ||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_naive_resize_layout) | MIDOUT_DECL(megdnn_naive_resize_layout) | ||||
MIDOUT_DECL(megdnn_naive_resize_layout_nearest) | |||||
MIDOUT_DECL(megdnn_naive_resize_nchw) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace naive; | using namespace naive; | ||||
using namespace resize; | |||||
template <typename ctype> | template <typename ctype> | ||||
ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors( | ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors( | ||||
@@ -90,20 +93,84 @@ INST(dt_quint8); | |||||
#undef INST | #undef INST | ||||
template <typename ctype> | template <typename ctype> | ||||
void ResizeImpl::kern_nchw_nearest (const KernParam<ctype>& kern_param) { | |||||
void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param, | |||||
InterpolationMode imode) { | |||||
megdnn_assert(kern_param.format == Format::NCHW); | megdnn_assert(kern_param.format == Format::NCHW); | ||||
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); | UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); | ||||
float scale_h = static_cast<float>(OH) / IH; | float scale_h = static_cast<float>(OH) / IH; | ||||
float scale_w = static_cast<float>(OW) / IW; | float scale_w = static_cast<float>(OW) / IW; | ||||
rounding::RoundingConverter<ctype> output_converter; | |||||
rep(n, N) { | rep(n, N) { | ||||
rep(oh, OH) rep(ow, OW) { | rep(oh, OH) rep(ow, OW) { | ||||
auto ih = get_nearest_src(scale_h, IH, oh); | |||||
auto iw = get_nearest_src(scale_w, IW, ow); | |||||
switch (imode) { | |||||
case InterpolationMode::NEAREST: { | |||||
auto ih = get_nearest_src(scale_h, IH, oh); | |||||
auto iw = get_nearest_src(scale_w, IW, ow); | |||||
rep(c, static_cast<int>(C)) { | |||||
dptr[c * OH * OW + oh * OW + ow] = | |||||
sptr[c * S_IC + ih * S_IH + iw * S_IW]; | |||||
} | |||||
break; | |||||
} | |||||
case InterpolationMode::INTER_LINEAR: { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
int ih0 = coord_h.second; | |||||
int ih1 = ih0 + 1; | |||||
int iw0 = coord_w.second; | |||||
int iw1 = iw0 + 1; | |||||
rep(c, static_cast<int>(C)) { | |||||
dptr[c * OH * OW + oh * OW + ow] = output_converter( | |||||
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * | |||||
(1.0f - alphaw) * (1.0f - alphah) + | |||||
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * | |||||
alphaw * (1.0f - alphah) + | |||||
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * | |||||
(1.0f - alphaw) * alphah + | |||||
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * | |||||
alphaw * alphah); | |||||
} | |||||
break; | |||||
} | |||||
case InterpolationMode::INTER_CUBIC: { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh, true); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow, true); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
rep(c, static_cast<int>(C)) { | |||||
dptr[c * OH * OW + oh * OW + ow] = sptr[c * S_IC + ih * S_IH + iw * S_IW]; | |||||
int ih0 = coord_h.second - 1; | |||||
int iw0 = coord_w.second - 1; | |||||
float h_coeff[4], w_coeff[4]; | |||||
interpolate_cubic(alphah, h_coeff); | |||||
interpolate_cubic(alphaw, w_coeff); | |||||
rep(c, static_cast<int>(C)) { | |||||
constexpr int ksize = 4; | |||||
float ret = 0; | |||||
rep(kh, ksize) { | |||||
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); | |||||
rep(kw, ksize) { | |||||
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); | |||||
ret += sptr[c * S_IC + h * S_IH + w * S_IW] * | |||||
h_coeff[kh] * w_coeff[kw]; | |||||
} | |||||
} | |||||
dptr[c * OH * OW + oh * OW + ow] = | |||||
output_converter(ret); | |||||
} | |||||
break; | |||||
} | |||||
default: | |||||
megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||||
break; | |||||
} | } | ||||
} | } | ||||
sptr += S_IN; | sptr += S_IN; | ||||
@@ -131,40 +198,6 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) { | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return; | return; | ||||
} | } | ||||
megdnn_assert(kern_param.format == Format::NCHW); | |||||
UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(kern_param); | |||||
rounding::RoundingConverter<ctype> output_converter; | |||||
float scale_h = static_cast<float>(OH) / IH; | |||||
float scale_w = static_cast<float>(OW) / IW; | |||||
rep(n, N) { | |||||
rep(oh, OH) rep(ow, OW) { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
int ih0 = coord_h.second; | |||||
int ih1 = ih0 + 1; | |||||
int iw0 = coord_w.second; | |||||
int iw1 = iw0 + 1; | |||||
rep(c, static_cast<int>(C)) { | |||||
dptr[c * OH * OW + oh * OW + ow] = output_converter( | |||||
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * | |||||
(1.0f - alphaw) * (1.0f - alphah) + | |||||
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * alphaw * | |||||
(1.0f - alphah) + | |||||
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * | |||||
(1.0f - alphaw) * alphah + | |||||
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * alphaw * | |||||
alphah); | |||||
} | |||||
} | |||||
sptr += S_IN; | |||||
dptr += C * OH * OW; | |||||
} | |||||
} | } | ||||
template <typename ctype> | template <typename ctype> | ||||
@@ -290,18 +323,16 @@ void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) { | |||||
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
check_exec(src.layout, dst.layout, workspace.size); | check_exec(src.layout, dst.layout, workspace.size); | ||||
if (param().format == param::Resize::Format::NCHW && | |||||
param().imode == param::Resize::InterpolationMode::NEAREST) { | |||||
#define cb(dt, ct, _midout_iv) \ | |||||
case DTypeTrait<dt>::enumv: { \ | |||||
MIDOUT_BEGIN(megdnn_naive_resize_layout_nearest, \ | |||||
midout_iv(_midout_iv)) { \ | |||||
auto kparam = KernParam<ct>::from_tensors(param().format, src, \ | |||||
dst, workspace); \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw_nearest(kparam)); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
return; \ | |||||
if (param().format == param::Resize::Format::NCHW) { | |||||
#define cb(dt, ct, _midout_iv) \ | |||||
case DTypeTrait<dt>::enumv: { \ | |||||
MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \ | |||||
auto kparam = KernParam<ct>::from_tensors(param().format, src, \ | |||||
dst, workspace); \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
return; \ | |||||
} | } | ||||
switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
@@ -320,11 +351,9 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, | |||||
} | } | ||||
#undef cb | #undef cb | ||||
#undef cb | |||||
} | } | ||||
if ((param().format == param::Resize::Format::NCHW || | |||||
(src.layout[3] != 1 && src.layout[3] != 3) || | |||||
if (((src.layout[3] != 1 && src.layout[3] != 3) || | |||||
!is_nhwc_contig_wc(src.layout)) || | !is_nhwc_contig_wc(src.layout)) || | ||||
(param().imode == param::Resize::InterpolationMode::LINEAR)) { | (param().imode == param::Resize::InterpolationMode::LINEAR)) { | ||||
#define cb(dt, ct, _midout_iv) \ | #define cb(dt, ct, _midout_iv) \ | ||||
@@ -378,37 +407,73 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); | std::memset(sptr, 0, sizeof(float) * N * C * IH * IW); | ||||
rep(n, N) { | rep(n, N) { | ||||
rep(oh, OH) rep(ow, OW) { | rep(oh, OH) rep(ow, OW) { | ||||
if(param().imode == InterpolationMode::INTER_LINEAR) { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
int ih0 = coord_h.second; | |||||
int ih1 = ih0 + 1; | |||||
int iw0 = coord_w.second; | |||||
int iw1 = iw0 + 1; | |||||
rep(c, C) { | |||||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||||
sptr[c * IH * IW + ih0 * IW + iw0] += | |||||
(1.0f - alphaw) * (1.0f - alphah) * hidden; | |||||
sptr[c * IH * IW + ih1 * IW + iw0] += | |||||
(1.0f - alphaw) * alphah * hidden; | |||||
sptr[c * IH * IW + ih0 * IW + iw1] += | |||||
alphaw * (1.0f - alphah) * hidden; | |||||
sptr[c * IH * IW + ih1 * IW + iw1] += | |||||
alphaw * alphah * hidden; | |||||
switch (param().imode) { | |||||
case InterpolationMode::INTER_LINEAR: { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
int ih0 = coord_h.second; | |||||
int ih1 = ih0 + 1; | |||||
int iw0 = coord_w.second; | |||||
int iw1 = iw0 + 1; | |||||
rep(c, C) { | |||||
float hidden = hptr[c * OH * OW + oh * OW + ow]; | |||||
sptr[c * IH * IW + ih0 * IW + iw0] += | |||||
(1.0f - alphaw) * (1.0f - alphah) * hidden; | |||||
sptr[c * IH * IW + ih1 * IW + iw0] += | |||||
(1.0f - alphaw) * alphah * hidden; | |||||
sptr[c * IH * IW + ih0 * IW + iw1] += | |||||
alphaw * (1.0f - alphah) * hidden; | |||||
sptr[c * IH * IW + ih1 * IW + iw1] += | |||||
alphaw * alphah * hidden; | |||||
} | |||||
break; | |||||
} | } | ||||
} else if (param().imode == InterpolationMode::NEAREST) { | |||||
auto ih = get_nearest_src(scale_h, IH, oh); | |||||
auto iw = get_nearest_src(scale_w, IW, ow); | |||||
rep(c, static_cast<int>(C)) { | |||||
sptr[c * IH * IW + ih * IW + iw] += hptr[c * OH * OW + oh * OW + ow]; | |||||
case InterpolationMode::NEAREST: { | |||||
auto ih = get_nearest_src(scale_h, IH, oh); | |||||
auto iw = get_nearest_src(scale_w, IW, ow); | |||||
rep(c, static_cast<int>(C)) { | |||||
sptr[c * IH * IW + ih * IW + iw] += | |||||
hptr[c * OH * OW + oh * OW + ow]; | |||||
} | |||||
break; | |||||
} | |||||
case InterpolationMode::INTER_CUBIC: { | |||||
auto coord_h = get_origin_coord(scale_h, IH, oh, true); | |||||
auto coord_w = get_origin_coord(scale_w, IW, ow, true); | |||||
float alphah = coord_h.first; | |||||
float alphaw = coord_w.first; | |||||
int ih0 = coord_h.second - 1; | |||||
int iw0 = coord_w.second - 1; | |||||
float h_coeff[4], w_coeff[4]; | |||||
interpolate_cubic(alphah, h_coeff); | |||||
interpolate_cubic(alphaw, w_coeff); | |||||
rep(c, static_cast<int>(C)) { | |||||
constexpr int ksize = 4; | |||||
rep(kh, ksize) { | |||||
int h = saturate<int, int>(ih0 + kh, 0, IH - 1); | |||||
rep(kw, ksize) { | |||||
int w = saturate<int, int>(iw0 + kw, 0, IW - 1); | |||||
sptr[c * IH * IW + h * IW + w] += | |||||
hptr[c * OH * OW + oh * OW + ow] * | |||||
h_coeff[kh] * w_coeff[kw]; | |||||
} | |||||
} | |||||
} | |||||
break; | |||||
} | |||||
default: { | |||||
megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||||
break; | |||||
} | } | ||||
} | } | ||||
else megdnn_throw("unsupported mode in ResizeBackwardImpl"); | |||||
} | } | ||||
sptr += C * IH * IW; | sptr += C * IH * IW; | ||||
hptr += C * OH * OW; | hptr += C * OH * OW; | ||||
@@ -47,7 +47,7 @@ private: | |||||
void kern_naive(const KernParam<ctype>& kern_param); | void kern_naive(const KernParam<ctype>& kern_param); | ||||
template <typename ctype> | template <typename ctype> | ||||
void kern_nchw_nearest(const KernParam<ctype>& kern_param); | |||||
void kern_nchw(const KernParam<ctype>& kern_param, InterpolationMode imode); | |||||
template <typename ctype> | template <typename ctype> | ||||
void kern_naive_nhwc(const KernParam<ctype>& kern_param); | void kern_naive_nhwc(const KernParam<ctype>& kern_param); | ||||
@@ -68,6 +68,7 @@ | |||||
#include "src/common/cv/helper.h" | #include "src/common/cv/helper.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
#include "src/common/resize.cuh" | |||||
MIDOUT_DECL(megdnn_naive_resizecv_imode) | MIDOUT_DECL(megdnn_naive_resizecv_imode) | ||||
MIDOUT_DECL(megdnn_naive_resizecv_dtype) | MIDOUT_DECL(megdnn_naive_resizecv_dtype) | ||||
@@ -75,6 +76,7 @@ MIDOUT_DECL(megdnn_naive_resizecv_dtype) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace naive; | using namespace naive; | ||||
using namespace megcv; | using namespace megcv; | ||||
using namespace megdnn::resize; | |||||
namespace { | namespace { | ||||
@@ -383,14 +385,6 @@ using ResizeAreaFunc = void (*)(const Mat<T>& src, Mat<T>& dst, | |||||
const DecimateAlpha* ytab, int ytab_size, | const DecimateAlpha* ytab, int ytab_size, | ||||
const int* yofs); | const int* yofs); | ||||
static inline void interpolate_cubic(float x, float* coeffs) { | |||||
const float A = -0.75f; | |||||
coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A; | |||||
coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1; | |||||
coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1; | |||||
coeffs[3] = 1.f - coeffs[0] - coeffs[1] - coeffs[2]; | |||||
} | |||||
static inline void interpolate_lanczos4(float x, float* coeffs) { | static inline void interpolate_lanczos4(float x, float* coeffs) { | ||||
static const double s45 = 0.70710678118654752440084436210485; | static const double s45 = 0.70710678118654752440084436210485; | ||||
static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | static const double cs[][2] = {{1, 0}, {-s45, -s45}, {0, 1}, {s45, -s45}, | ||||
@@ -43,7 +43,7 @@ TEST_F(CUDA, RESIZE_CV) { | |||||
TEST_F(CUDA, RESIZE_FORWARD) { | TEST_F(CUDA, RESIZE_FORWARD) { | ||||
using namespace resize; | using namespace resize; | ||||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||||
for (auto imode : modes) { | for (auto imode : modes) { | ||||
std::vector<TestArg> args = get_args(imode); | std::vector<TestArg> args = get_args(imode); | ||||
Checker<Resize> checker(handle_cuda()); | Checker<Resize> checker(handle_cuda()); | ||||
@@ -88,7 +88,7 @@ TEST_F(CUDA, RESIZE_NCHW4) { | |||||
} | } | ||||
TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { | TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { | ||||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||||
for (auto imode : modes) { | for (auto imode : modes) { | ||||
param::Resize param; | param::Resize param; | ||||
param.format = param::Resize::Format::NCHW; | param.format = param::Resize::Format::NCHW; | ||||
@@ -117,7 +117,7 @@ TEST_F(CUDA, RESIZE_NCHW_WITH_STRIDE) { | |||||
} | } | ||||
TEST_F(CUDA, RESIZE_BACKWARD) { | TEST_F(CUDA, RESIZE_BACKWARD) { | ||||
IMode modes[2] = {IMode::INTER_LINEAR, IMode::NEAREST}; | |||||
IMode modes[] = {IMode::INTER_LINEAR, IMode::NEAREST, IMode::INTER_CUBIC}; | |||||
for (auto imode : modes) { | for (auto imode : modes) { | ||||
Checker<ResizeBackward> checker(handle_cuda()); | Checker<ResizeBackward> checker(handle_cuda()); | ||||
param::Resize param; | param::Resize param; | ||||
@@ -574,19 +574,25 @@ def interpolate( | |||||
raise ValueError("under linear mode, size can only be single value") | raise ValueError("under linear mode, size can only be single value") | ||||
dsize = size | dsize = size | ||||
if not align_corners and mode in ("bilinear", "nearest") and inp.ndim in [4, 5]: | |||||
if not align_corners: | |||||
# fastpath for interpolate | # fastpath for interpolate | ||||
op = builtin.Resize( | |||||
imode="linear" if mode == "bilinear" else "nearest", format="NCHW" | |||||
) | |||||
mode_map = { | |||||
"linear": "linear", | |||||
"bilinear": "linear", | |||||
"nearest": "nearest", | |||||
"bicubic": "cubic", | |||||
} | |||||
op = builtin.Resize(imode=mode_map[mode], format="NCHW") | |||||
shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) | shape = astensor1d(dsize, inp, dtype="int32", device=inp.device) | ||||
(result,) = apply(op, inp, shape) | |||||
return result | |||||
oh, ow = dsize[0], dsize[1] | |||||
ih, iw = inp.shape[2], inp.shape[3] | |||||
if align_corners: | |||||
(ret,) = apply(op, inp, shape) | |||||
else: | |||||
assert mode in [ | |||||
"linear", | |||||
"bilinear", | |||||
], "align_corners only support linear or bilinear mode" | |||||
oh, ow = dsize[0], dsize[1] | |||||
ih, iw = inp.shape[2], inp.shape[3] | |||||
hscale = (ih - 1.0) / (oh - 1.0) | hscale = (ih - 1.0) / (oh - 1.0) | ||||
wscale = 1.0 * iw / ow | wscale = 1.0 * iw / ow | ||||
if mode != "linear": | if mode != "linear": | ||||
@@ -607,34 +613,11 @@ def interpolate( | |||||
axis=0, | axis=0, | ||||
).reshape(1, 3, 3) | ).reshape(1, 3, 3) | ||||
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | ||||
else: | |||||
hscale = 1.0 * ih / oh | |||||
wscale = 1.0 * iw / ow | |||||
row0 = concat( | |||||
[wscale, Tensor(0, dtype="float32", device=inp.device), 0.5 * wscale - 0.5], | |||||
axis=0, | |||||
).reshape(1, 3) | |||||
row1 = concat( | |||||
[Tensor(0, dtype="float32", device=inp.device), hscale, 0.5 * hscale - 0.5], | |||||
axis=0, | |||||
).reshape(1, 3) | |||||
weight = concat( | |||||
[row0, row1, Tensor([[0, 0, 1]], dtype="float32", device=inp.device)], | |||||
axis=0, | |||||
).reshape(1, 3, 3) | |||||
weight = broadcast_to(weight, (inp.shape[0], 3, 3)) | |||||
weight = weight.astype("float32") | |||||
if mode in ["linear", "bilinear"]: | |||||
ret = warp_perspective(inp, weight, dsize, interp_mode="linear") | ret = warp_perspective(inp, weight, dsize, interp_mode="linear") | ||||
if mode == "linear": | |||||
ret = reshape(ret, ret.shape[0:3]) | |||||
else: | |||||
# only NHWC format support "cubic" mode | |||||
assert mode == "bicubic" | |||||
inp = transpose(inp, (0, 2, 3, 1)) | |||||
ret = warp_perspective(inp, weight, dsize, format="NHWC", interp_mode="cubic",) | |||||
ret = transpose(ret, (0, 3, 1, 2)) | |||||
if mode == "linear": | |||||
ret = reshape(ret, ret.shape[0:3]) | |||||
return ret | return ret | ||||