Browse Source

feat(dnn): warp perspective support int4

GitOrigin-RevId: 826a43b349
release-1.5
Megvii Engine Team 4 years ago
parent
commit
df1af59b5c
10 changed files with 824 additions and 20 deletions
  1. +10
    -0
      dnn/src/common/rounding_converter.cuh
  2. +22
    -3
      dnn/src/common/warp_perspective.cpp
  3. +7
    -0
      dnn/src/cuda/warp_perspective/common.h
  4. +106
    -16
      dnn/src/cuda/warp_perspective/forward.cpp
  5. +272
    -0
      dnn/src/cuda/warp_perspective/forward.cu
  6. +138
    -0
      dnn/src/naive/warp_perspective/opr_impl.cpp
  7. +11
    -1
      dnn/src/naive/warp_perspective/opr_impl.h
  8. +59
    -0
      dnn/test/common/warp_perspective.h
  9. +91
    -0
      dnn/test/cuda/warp_perspective.cpp
  10. +108
    -0
      dnn/test/naive/warp_perspective.cpp

+ 10
- 0
dnn/src/common/rounding_converter.cuh View File

@@ -76,6 +76,16 @@ struct RoundingConverter<uint8_t> {
}
};

template <>
struct RoundingConverter<dt_qint4> {
__host__ __device__ __forceinline__ dt_qint4 operator()(float x) const {
#if MEGDNN_CC_HOST
using std::round;
#endif
return static_cast<dt_qint4>(round(x));
}
};

} // namespace rounding
} // namespace megdnn



+ 22
- 3
dnn/src/common/warp_perspective.cpp View File

@@ -29,7 +29,8 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
};
MEGDNN_MARK_USED_VAR(errmsg);
if (param().format == param::WarpPerspective::Format::NHWCD4 ||
param().format == param::WarpPerspective::Format::NCHW4) {
param().format == param::WarpPerspective::Format::NCHW4 ||
param().format == param::WarpPerspective::Format::NCHW64) {
megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());

@@ -71,7 +72,8 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
src.dtype.enumv() == DTypeEnum::Quantized8Asymm) ||
src.dtype.enumv() == DTypeEnum::QuantizedS4,
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" DNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
@@ -115,6 +117,22 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else if (param().format == param::WarpPerspective::Format::NCHW64) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS4,
"src expected QuantizedS4, but got %s",
src.dtype.name());
megdnn_assert(mat.dtype == dtype::Float32(),
"matrix dtype expected float, got %s",
mat.dtype.name());
megdnn_assert(src.shape[4] == 64 && dst.shape[4] == 64);
megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());

megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format ==
param::WarpPerspective::Format::NHWCD4);
@@ -288,7 +306,8 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
param().format != Param::Format::NCHW4 &&
param().format != Param::Format::NHWC_NCHW &&
param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
param().format != Param::Format::NCHW_NCHW4_IC_SMALL) {
param().format != Param::Format::NCHW_NCHW4_IC_SMALL &&
param().format != Param::Format::NCHW64) {
megdnn_assert(!mat_idx.ndim,
"mat_idx not supported for current format");
}


+ 7
- 0
dnn/src/cuda/warp_perspective/common.h View File

@@ -35,6 +35,13 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);

template <typename ctype>
void forward_proxy_nchw64(const ctype* src, const float* mat,
const int* mat_idx, ctype* dst, int N_SRC, int N_MAT,
int C, int IH, int IW, int OH, int OW, ctype bval,
BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);

template <typename src_dtype, typename src_ctype, typename dst_ctype>
void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
bool is_nhwc, const src_ctype* src, const float* mat,


+ 106
- 16
dnn/src/cuda/warp_perspective/forward.cpp View File

@@ -22,6 +22,43 @@
namespace megdnn {
namespace cuda {

namespace {
inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
const TensorLayout& src_layout,
TensorLayout& dst_layout,
RelayoutFormat::Param::Mode mode,
const int oc = 0, const int group = 1) {
if (src_layout.ndim > 0) {
RelayoutFormat::Param trans_param;
trans_param.mode = mode;
trans_param.oc = oc;
trans_param.group = group;
relayout->param() = trans_param;
relayout->deduce_layout(src_layout, dst_layout);
} else {
dst_layout = src_layout;
}
}

void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
TensorLayout& inner_src, TensorLayout& inner_dst,
Handle* handle,
WarpPerspectiveForwardImpl::Param::Format format) {
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
dst.dtype.enumv() == DTypeEnum::QuantizedS4 &&
format == param::WarpPerspective::Format::NCHW) {
auto relayout_opr = handle->create_operator<RelayoutFormat>();
deduce_reformat_layout(relayout_opr, src, inner_src,
RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
deduce_reformat_layout(relayout_opr, dst, inner_dst,
RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
} else {
megdnn_assert(0, "not support");
}
}

} // namespace

namespace warp_perspective {

void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
@@ -93,15 +130,22 @@ WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
TensorLayout fsrc = src;
TensorLayout fmat = mat;
TensorLayout fdst = dst;
auto get_workspace = [&sizes](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
sizes.push_back(layout.span().dist_byte());
}
};
get_workspace(fsrc);
get_workspace(fmat);
get_workspace(fdst);
if (src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
param().format == param::WarpPerspective::Format::NCHW) {
get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
sizes.push_back(fsrc.span().dist_byte());
sizes.push_back(fdst.span().dist_byte());
} else {
auto get_workspace = [&sizes](TensorLayout& layout) {
if (layout.dtype == dtype::BFloat16()) {
layout.dtype = dtype::Float32();
sizes.push_back(layout.span().dist_byte());
}
};
get_workspace(fsrc);
get_workspace(fmat);
get_workspace(fdst);
}
if (param().format == param::WarpPerspective::Format::NHWC) {
//! use double for the workspace dtype as float may cause
//! accuracy problems
@@ -123,6 +167,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
TensorND mat = smat;
TensorND mat_idx = smat_idx;
TensorND dst = sdst;
Param::Format inner_format = param().format;
auto bundle =
get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout,
smat_idx.layout, sdst.layout);
@@ -132,11 +177,24 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
ctypecvt.src_to_comp_type(ssrc, src)
.src_to_comp_type(smat, mat)
.src_to_comp_type(sdst, dst);
} else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&
param().format == Param::Format::NCHW) {
auto handle_ptr = handle();
get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
handle_ptr, param().format);
src.raw_ptr = bundle.get(0);
dst.raw_ptr = bundle.get(1);
auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
relayout_opr->param() = trans_param;
relayout_opr->exec(ssrc, src, {});
inner_format = Param::Format::NCHW64;
}

{
auto stream = cuda_stream(this->handle());
bool is_nhwc = param().format == param::WarpPerspective::Format::NHWC;
bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;

if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
// use opencv impl only for nhwc and non-linear interp
@@ -152,7 +210,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
} else {
megdnn_assert(warp::is_dnn_available(src.layout, mat.layout,
dst.layout, param().imode,
param().format));
inner_format));
size_t C, IH, IW, OH, OW;
if (is_nhwc) {
C = src.layout.shape[3];
@@ -160,19 +218,19 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
IW = src.layout.shape[2];
OH = dst.layout.shape[1];
OW = dst.layout.shape[2];
} else if (param().format == Param::Format::NCHW4) {
} else if (inner_format == Param::Format::NCHW4) {
C = src.layout.shape[1] * 4;
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (param().format == Param::Format::NHWC_NCHW) {
} else if (inner_format == Param::Format::NHWC_NCHW) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (param().format == Param::Format::NHWC_NCHW4_IC_SMALL) {
} else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
@@ -181,7 +239,7 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
megdnn_assert(
(C == 1) || (C == 3),
"NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else if (param().format == Param::Format::NCHW_NCHW4_IC_SMALL) {
} else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
@@ -190,9 +248,15 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
megdnn_assert(
(C == 1) || (C == 3),
"NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else if (inner_format == Param::Format::NCHW64) {
C = src.layout.shape[1] * 64;
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else {
megdnn_assert(
param().format == param::WarpPerspective::Format::NCHW,
inner_format == param::WarpPerspective::Format::NCHW,
"invalid warp_perspective format");
C = src.layout.shape[1];
IH = src.layout.shape[2];
@@ -261,6 +325,32 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
async_error_info(handle()), m_error_tracker,
stream);
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
megdnn_assert(
param().format == Param::Format::NCHW64 ||
param().format == Param::Format::NCHW,
"WarpPerspective on CUDA supports NCHW64 or NCHW+ "
"QuantizedS4 only");
bval = roundf(bval);
bval = fmin(fmax(-8.f, bval), 7.f);
warp_perspective::forward_proxy_nchw64<dt_qint4>(
src.compatible_ptr<dt_qint4>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_qint4>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW,
static_cast<dt_qint4>(bval), bmode,
async_error_info(handle()), m_error_tracker,
stream);
if (param().format == Param::Format::NCHW) {
auto relayout_opr =
handle()->create_operator<RelayoutFormat>();
RelayoutFormat::Param trans_param;
trans_param.mode =
RelayoutFormat::Param::Mode::NCHW64_NCHW;
relayout_opr->param() = trans_param;
relayout_opr->exec(dst, sdst, {});
}
}
} else if ((src.layout.dtype.enumv() ==
DTypeEnum::Quantized8Asymm ||


+ 272
- 0
dnn/src/cuda/warp_perspective/forward.cu View File

@@ -142,6 +142,92 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat,
}
}

#define warp_perspective_transform(idx) \
static_cast<int>(output_converter(s00[idx] * nalpha * nbeta + \
s01[idx] * nalpha * pbeta + \
s10[idx] * palpha * nbeta + \
s11[idx] * palpha * pbeta) \
.as_int8())

#define pack_output \
transform_int8_to_int4x8( \
warp_perspective_transform(0), warp_perspective_transform(1), \
warp_perspective_transform(2), warp_perspective_transform(3), \
warp_perspective_transform(4), warp_perspective_transform(5), \
warp_perspective_transform(6), warp_perspective_transform(7))

template <typename ctype, typename Getter, typename SrcVisitor,
typename OutputConverter>
__global__ void kern_general_nchw64(SrcVisitor src, const float* __restrict mat,
ctype* __restrict dst, int C, int IH,
int IW, int OH, int OW) {
Getter getter;
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int c1 = ow % 2;
ow = ow / 2;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2);
dst += blockIdx.z * C * OH * OW / 2;
mat += blockIdx.z * 3 * 3;
const int4* sptr_int4 = reinterpret_cast<const int4*>(sptr);
int4* dst_int4 = reinterpret_cast<int4*>(dst);
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = getter(floor(iw) + 0, IW);
int iw1 = getter(floor(iw) + 1, IW);
int ih0 = getter(floor(ih) + 0, IH);
int ih1 = getter(floor(ih) + 1, IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
int o_coor = (oh * OW + ow) << 1;
int i_coor_00 = (ih0 * IW + iw0) << 1;
int i_coor_01 = (ih0 * IW + iw1) << 1;
int i_coor_10 = (ih1 * IW + iw0) << 1;
int i_coor_11 = (ih1 * IW + iw1) << 1;
int s00[8], s01[8], s10[8], s11[8];
int4 s[4], d;
for (int c0 = 0, nr_chan = C / 64; c0 < nr_chan; ++c0) {
s[0] = __ldg(sptr_int4 + i_coor_00 + c1);
s[1] = __ldg(sptr_int4 + i_coor_01 + c1);
s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
s[3] = __ldg(sptr_int4 + i_coor_11 + c1);

transform_int4x8_to_int8(s00, s[0].x);
transform_int4x8_to_int8(s01, s[1].x);
transform_int4x8_to_int8(s10, s[2].x);
transform_int4x8_to_int8(s11, s[3].x);
d.x = pack_output;

transform_int4x8_to_int8(s00, s[0].y);
transform_int4x8_to_int8(s01, s[1].y);
transform_int4x8_to_int8(s10, s[2].y);
transform_int4x8_to_int8(s11, s[3].y);
d.y = pack_output;

transform_int4x8_to_int8(s00, s[0].z);
transform_int4x8_to_int8(s01, s[1].z);
transform_int4x8_to_int8(s10, s[2].z);
transform_int4x8_to_int8(s11, s[3].z);
d.z = pack_output;

transform_int4x8_to_int8(s00, s[0].w);
transform_int4x8_to_int8(s01, s[1].w);
transform_int4x8_to_int8(s10, s[2].w);
transform_int4x8_to_int8(s11, s[3].w);
d.w = pack_output;

dst_int4[o_coor + c1] = d;
sptr_int4 += IH * IW * 2;
dst_int4 += OH * OW * 2;
}
}
}

template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border(SrcVisitor src, const float* __restrict mat,
ctype* __restrict dst, int C, int IH, int IW,
@@ -233,6 +319,107 @@ __global__ void kern_const_border_nchw4(SrcVisitor src,
}
}

template <typename ctype, typename SrcVisitor, typename OutputConverter>
__global__ void kern_const_border_nchw64(SrcVisitor src,
const float* __restrict mat,
ctype* __restrict dst, int C, int IH,
int IW, int OH, int OW, ctype bval) {
OutputConverter output_converter;
int ow = blockIdx.x * blockDim.x + threadIdx.x;
int c1 = ow %2;
ow = ow / 2;
int oh = blockIdx.y * blockDim.y + threadIdx.y;
const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW / 2);
dst += blockIdx.z * C * OH * OW / 2;
mat += blockIdx.z * 3 * 3;
const int4* sptr_int4 = reinterpret_cast<const int4*>(sptr);
int4* dst_int4 = reinterpret_cast<int4*>(dst);
if (ow < OW && oh < OH) {
float denominator = mat[6] * ow + mat[7] * oh + mat[8];
float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator;
float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator;
int iw0 = floor(iw) + 0;
int iw1 = floor(iw) + 1;
int ih0 = floor(ih) + 0;
int ih1 = floor(ih) + 1;
bool okw0 = (iw0 >= 0 && iw0 < IW);
bool okw1 = (iw1 >= 0 && iw1 < IW);
bool okh0 = (ih0 >= 0 && ih0 < IH);
bool okh1 = (ih1 >= 0 && ih1 < IH);
float palpha = ih - floor(ih);
float pbeta = iw - floor(iw);
float nalpha = 1.0f - palpha;
float nbeta = 1.0f - pbeta;
int o_coor = (oh * OW + ow) << 1;
int i_coor_00 = (ih0 * IW + iw0) << 1;
int i_coor_01 = (ih0 * IW + iw1) << 1;
int i_coor_10 = (ih1 * IW + iw0) << 1;
int i_coor_11 = (ih1 * IW + iw1) << 1;
bool flag00 = okh0 && okw0, flag01 = okh0 && okw1,
flag10 = okh1 && okw0, flag11 = okh1 && okw1;
int8_t bval_4 = bval.as_int8() & 0xF;
int bval_8 = transform_int8_to_int4x8(bval_4, bval_4, bval_4, bval_4,
bval_4, bval_4, bval_4, bval_4);
int4 bval_int4;
bval_int4.x = bval_8;
bval_int4.y = bval_8;
bval_int4.z = bval_8;
bval_int4.w = bval_8;
int s00[8], s01[8], s10[8], s11[8];
int4 s[4], d;
for (int c0 = 0, nr_chan = C / 64; c0 < nr_chan; ++c0) {
if (flag00) {
s[0] = __ldg(sptr_int4 + i_coor_00 + c1);
} else {
s[0] = bval_int4;
}
if (flag01) {
s[1] = __ldg(sptr_int4 + i_coor_01 + c1);
} else {
s[1] = bval_int4;
}
if (flag10) {
s[2] = __ldg(sptr_int4 + i_coor_10 + c1);
} else {
s[2] = bval_int4;
}
if (flag11) {
s[3] = __ldg(sptr_int4 + i_coor_11 + c1);
} else {
s[3] = bval_int4;
}

transform_int4x8_to_int8(s00, s[0].x);
transform_int4x8_to_int8(s01, s[1].x);
transform_int4x8_to_int8(s10, s[2].x);
transform_int4x8_to_int8(s11, s[3].x);
d.x = pack_output;

transform_int4x8_to_int8(s00, s[0].y);
transform_int4x8_to_int8(s01, s[1].y);
transform_int4x8_to_int8(s10, s[2].y);
transform_int4x8_to_int8(s11, s[3].y);
d.y = pack_output;

transform_int4x8_to_int8(s00, s[0].z);
transform_int4x8_to_int8(s01, s[1].z);
transform_int4x8_to_int8(s10, s[2].z);
transform_int4x8_to_int8(s11, s[3].z);
d.z = pack_output;

transform_int4x8_to_int8(s00, s[0].w);
transform_int4x8_to_int8(s01, s[1].w);
transform_int4x8_to_int8(s10, s[2].w);
transform_int4x8_to_int8(s11, s[3].w);
d.w = pack_output;

dst_int4[o_coor + c1] = d;
sptr_int4 += IH * IW * 2;
dst_int4 += OH * OW * 2;
}
}
}

template <typename ctype, typename Getter, typename SrcVisitor,
typename OutputConverter>
__global__ void kern_general_nhwc(SrcVisitor src, const float* __restrict mat,
@@ -423,6 +610,58 @@ void dispatch_with_visitor_nchw4(SrcVisitor src, const float* mat, ctype* dst,
}
}

template <typename ctype, typename SrcVisitor>
void dispatch_with_visitor_nchw64(SrcVisitor src, const float* mat, ctype* dst,
int N, int C, int IH, int IW, int OH, int OW,
ctype bval, BorderMode bmode,
cudaStream_t stream) {
const int BY = 16, BX = 32;
#define DISPATCH(Getter) \
do { \
kern_general_nchw64<ctype, Getter, SrcVisitor, \
rounding::RoundingConverter<ctype>> \
<<<blocks, threads, 0, stream>>>(src, mat, dst, C, IH, IW, OH, \
OW); \
} while (0)

const int max_batch_size = 65535;
while (N) {
size_t curr_batch_size = N < max_batch_size ? N : max_batch_size;
dim3 threads(BX, BY);
dim3 blocks((OW * 2 + BX - 1) / BX, (OH + BY - 1) / BY,
curr_batch_size);

switch (bmode) {
case BORDER_REPLICATE:
DISPATCH(ReplicateGetter);
break;
case BORDER_REFLECT:
DISPATCH(ReflectGetter);
break;
case BORDER_REFLECT_101:
DISPATCH(Reflect101Getter);
break;
case BORDER_WRAP:
DISPATCH(WrapGetter);
break;
#undef DISPATCH
case BORDER_CONSTANT:
kern_const_border_nchw64<ctype, SrcVisitor,
rounding::RoundingConverter<ctype>>
<<<blocks, threads, 0, stream>>>(src, mat, dst, C, IH,
IW, OH, OW, bval);
break;
default:
break;
}

N -= curr_batch_size;
src.move_batch(curr_batch_size, C * IH * IW / 2);
mat += curr_batch_size * 3 * 3;
dst += curr_batch_size * C * OH * OW / 2;
}
}

template <typename SrcType, typename DstType>
struct CudaTypeCvt;

@@ -1154,6 +1393,30 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
after_kernel_launch();
}

template <typename ctype>
void forward_proxy_nchw64(const ctype* src, const float* mat, const int* mat_idx,
ctype* dst, int N_SRC, int N_MAT, int C, int IH,
int IW, int OH, int OW, ctype bval, BorderMode bmode,
megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream) {
if (mat_idx) {
IndexedSrcVisitor<ctype> visitor;
visitor.ptr = src;
visitor.idx = mat_idx;
visitor.N_SRC = N_SRC;
visitor.error_info = error_info;
visitor.error_tracker = error_tracker;
dispatch_with_visitor_nchw64(visitor, mat, dst, N_MAT, C, IH, IW, OH, OW,
bval, bmode, stream);
} else {
DirectSrcVisitor<ctype> visitor;
visitor.ptr = src;
dispatch_with_visitor_nchw64(visitor, mat, dst, N_MAT, C, IH, IW, OH, OW,
bval, bmode, stream);
}
after_kernel_launch();
}

#define INST(ctype) \
template void forward_proxy(bool, const ctype*, const float*, const int*, \
ctype*, int, int, int, int, int, int, int, \
@@ -1176,6 +1439,15 @@ INST(int8_t)
INST(int8_t)
#undef INST

#define INST(ctype) \
template void forward_proxy_nchw64( \
const ctype*, const float*, const int*, ctype*, int, int, int, \
int, int, int, int, ctype, BorderMode, megcore::AsyncErrorInfo*, \
void*, cudaStream_t);

INST(dt_qint4)
#undef INST

template <typename src_dtype, typename src_ctype, typename dst_ctype>
void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
bool is_nhwc, const src_ctype* src, const float* mat,


+ 138
- 0
dnn/src/naive/warp_perspective/opr_impl.cpp View File

@@ -249,6 +249,127 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
MIDOUT_END();
}

template <typename ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_int4(
const KernParam<ctype, mtype>& kern_param, size_t task_id) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(0)) {
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
MEGDNN_MARK_USED_VAR(N_MAT);
uint8_t c_shift, c_mask, iw_shift = 0, ow_shift = 0;
switch (param().format) {
case Format::NCHW:
c_shift = 0;
c_mask = 0;
iw_shift = IW % 2;
ow_shift = OW % 2;
break;
case Format::NCHW64:
c_shift = 6;
c_mask = 0x3F;
break;
default:
megdnn_throw("bad format");
break;
}
//! strides of C, H, W on src and dst
size_t sstrd[2] = {IH * (IW + iw_shift), IW + iw_shift},
dstrd[2] = {OH * (OW + ow_shift), OW + ow_shift};
static constexpr uint8_t mask = (uint8_t)((1 << 4) - 1);
auto visit_src = [&sptr, sstrd, c_shift, c_mask](size_t c, int h,
int w) -> float {
size_t index = ((sstrd[0] * (c >> c_shift) + sstrd[1] * h + w)
<< c_shift) +
(c & c_mask);
uint8_t result =
(sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF;
return result & uint8_t(1 << 3) ? result | ~mask : result;
};
auto visit_src_bd = [&sptr, sstrd, border_val, c_shift, c_mask](
size_t c, int h, int w) -> float {
if (h != -1 && w != -1) {
size_t index = ((sstrd[0] * (c >> c_shift) + sstrd[1] * h + w)
<< c_shift) +
(c & c_mask);
uint8_t result =
(sptr[index / 2].as_int8() >> (4 * (index % 2))) & 0xF;
return result & uint8_t(1 << 3) ? result | ~mask : result;
} else
return border_val;
};
auto set_visit_dst = [&dptr, dstrd, c_shift, c_mask](size_t c, int h,
int w, ctype v) {
size_t index = ((dstrd[0] * (c >> c_shift) + dstrd[1] * h + w)
<< c_shift) +
(c & c_mask);
dptr[index / 2] =
(dptr[index / 2].as_int8() & (0xF0 >> (4 * (index % 2)))) |
(v.as_int8() << (4 * (index % 2)));
};

rounding::RoundingConverter<ctype> output_converter;
auto orig_sptr = sptr;
size_t n = task_id / OH;
size_t oh = task_id % OH;
mptr = mptr + n * 3 * 3;
dptr = dptr + n * C * OH * OW / 2;
if (midx_ptr) {
size_t idx = midx_ptr[n];
megdnn_assert(
idx < N_SRC,
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", n,
idx, N_SRC);
sptr = orig_sptr + idx * (C * IH * IW) / 2;
} else if (n) {
sptr += n * C * IH * IW / 2;
}
rep(ow, OW) {
float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2];
float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5];
float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8];
float alphaw = numeratorw / denominator;
float alphah = numeratorh / denominator;

int iw0 = get_real_coord(std::floor(alphaw) + 0, IW);
int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
int ih1 = get_real_coord(std::floor(alphah) + 1, IH);

alphaw -= floor(alphaw);
alphah -= floor(alphah);
if (bmode != BorderMode::CONSTANT) {
rep(c, C) {
set_visit_dst(
c, oh, ow,
output_converter(
visit_src(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src(c, ih0, iw1) * alphaw *
(1.0f - alphah) +
visit_src(c, ih1, iw0) * (1.0f - alphaw) *
alphah +
visit_src(c, ih1, iw1) * alphaw * alphah));
}
} else {
rep(c, C) {
auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src_bd(c, ih0, iw1) * alphaw *
(1.0f - alphah) +
visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) *
alphah +
visit_src_bd(c, ih1, iw1) * alphaw * alphah;
set_visit_dst(
c, oh, ow,
output_converter(std::isfinite(val) ? val
: border_val));
}
}
}
}
MIDOUT_END();
}

template <typename ctype, typename dst_ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt(
const KernParam<ctype, mtype>& kern_param, size_t task_id) {
@@ -444,6 +565,15 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch);

#define KERN_INT4(ct, mct) \
auto kparam = KernParam<ct, mct>::from_tensors( \
param().format, param().bmode, param().border_val, src, mat, \
mat_idx, dst, workspace); \
auto run = [kparam, this](size_t index, size_t) { \
kern_naive_int4(kparam, index); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, kparam.oh* batch);

#define DISPATCH_ST(dt, ct, mct, kern) \
if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
kern(ct, mct); \
@@ -477,6 +607,14 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
.c_str());
}

if (src.layout.dtype.enumv() == DTypeTrait<dtype::QuantizedS4>::enumv) {
DISPATCH_ST(dtype::QuantizedS4, dt_qint4, float, KERN_INT4);
megdnn_throw(ssprintf("Unsupported input DType in "
"WarpPerspective: %s",
src.layout.dtype.name())
.c_str());
}

bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv();
bool is_u8_or_qu8_in =
src.layout.dtype.enumv() == DTypeTrait<dtype::Uint8>::enumv ||


+ 11
- 1
dnn/src/naive/warp_perspective/opr_impl.h View File

@@ -79,6 +79,12 @@ protected:
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else if (format == Format::NCHW64) {
ret.c = src.layout.shape[1] * 64;
ret.ih = src.layout.shape[2];
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else {
megdnn_assert(format == Format::NHWCD4);
ret.c = src.layout.shape[2] * 4;
@@ -100,7 +106,8 @@ protected:
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
@@ -141,6 +148,9 @@ private:
template <typename ctype, typename mtype>
void kern_naive_nhwcd4(const KernParam<ctype, mtype>& kern_param,
size_t task_id);
template <typename ctype, typename mtype>
void kern_naive_int4(const KernParam<ctype, mtype>& kern_param,
size_t task_id);
template <typename ctype, typename dst_ctype, typename mtype>
void kern_naive_dimshuffle_typecvt(
const KernParam<ctype, mtype>& kern_param, size_t task_id);


+ 59
- 0
dnn/test/common/warp_perspective.h View File

@@ -55,6 +55,65 @@ private:
size_t idx;
};

class WarpPerspectiveMatRNG_V2 final : public IIDRNG {
public:
WarpPerspectiveMatRNG_V2() : idx(0) {}
void set_hw(size_t h, size_t w) {
ih = h;
iw = w;
idx = 0;
rng.seed(time(NULL));
}
dt_float32 gen_single_val() override {
auto rand_real = [&](double lo, double hi) {
return rng() / (std::mt19937::max() + 1.0) * (hi - lo) + lo;
};
auto rand_real2 = [&](double range) {
return rand_real(-range, range);
};
dt_float32 res;
switch (idx) {
case 0:
rot = rand_real(0, M_PI * 2);
scale = rand_real(0.8, 1.2);
sheer = rand_real(0.9, 1.1);
res = cos(rot) * scale;
break;
case 1:
res = -sin(rot) * scale;
break;
case 2:
res = rand_real2(iw * 0.5);
break;
case 3:
res = sin(rot) * scale * sheer;
break;
case 4:
res = cos(rot) * scale * sheer;
break;
case 5:
res = rand_real2(ih * 0.5);
break;
case 6:
res = rand_real2(0.1 / iw);
break;
case 7:
res = rand_real2(0.1 / ih);
break;
case 8:
res = rand_real2(0.1) + 1;
break;
}
idx = (idx + 1) % 9;
return res;
}

private:
size_t idx, ih, iw;
float rot, scale, sheer;
std::mt19937 rng;
};

namespace warp_perspective {

struct TestArg {


+ 91
- 0
dnn/test/cuda/warp_perspective.cpp View File

@@ -622,6 +622,31 @@ TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_BFLOAT16) {
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_QINT4) {
using Param = WarpPerspective::Param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::QuantizedS4(0.1f))
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::QuantizedS4(0.1f));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
WarpPerspective::Param param;
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

param.format = Param::Format::NCHW;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{1, 64, 11, 11}, {1, 3, 3}, {1, 64, 11, 11}});
checker.execs({{20, 640, 11, 12}, {20, 3, 3}, {20, 640, 11, 12}});
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_BACKWARD_DATA_BFLOAT16) {
Checker<WarpPerspectiveBackwardData> checker(handle_cuda());
WarpPerspectiveMatRNG rng;
@@ -676,6 +701,72 @@ TEST_F(CUDA, WARP_PERSPECTIVE_MAT_IDX) {
warp_perspective::run_mat_idx_test(handle_cuda());
}

TEST_F(CUDA, WARP_PERSPECTIVE_NCHW64) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG_V2 rng;
checker.set_dtype(0, dtype::QuantizedS4(0.1f));
checker.set_dtype(2, dtype::QuantizedS4(0.1f));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

param.format = Param::Format::NCHW64;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
rng.set_hw(10, 11);
checker.set_rng(1, &rng);
checker.execs({{2, 1, 10, 11, 64}, {2, 3, 3}, {2, 1, 11, 12, 64}});
checker.execs(
{{20, 300, 10, 11, 64}, {20, 3, 3}, {20, 300, 11, 12, 64}});
checker.execs(
{{2200, 3, 10, 11, 64}, {2200, 3, 3}, {2200, 3, 11, 12, 64}});
rng.set_hw(25, 25);
checker.set_rng(1, &rng);
checker.execs({{1, 25, 25, 25, 64}, {1, 3, 3}, {1, 25, 25, 51, 64}});
rng.set_hw(25, 510);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 25, 510, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}});
rng.set_hw(25, 25);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 25, 25, 64}, {1, 3, 3}, {1, 1, 51, 51, 64}});
rng.set_hw(51, 51);
checker.set_rng(1, &rng);
checker.execs({{1, 1, 51, 51, 64}, {1, 3, 3}, {1, 1, 25, 25, 64}});
}
{
Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
handle_cuda());
constexpr int N_SRC = 5;
UniformIntRNG mat_idx_rng{0, N_SRC - 1};
checker.set_dtype(0, dtype::QuantizedS4(0.1f));
checker.set_rng(1, &rng);
checker.set_dtype(2, dtype::Int32());
checker.set_rng(2, &mat_idx_rng);
checker.set_dtype(3, dtype::QuantizedS4(0.1f));
param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
rng.set_hw(10, 11);
checker.set_rng(1, &rng);
checker.execs(
{{N_SRC, 3, 10, 11, 64}, {2, 3, 3}, {2}, {2, 3, 11, 12, 64}});
rng.set_hw(17, 13);
checker.set_rng(1, &rng);
checker.execs({{N_SRC, 14, 17, 13, 64},
{123, 3, 3},
{123},
{123, 14, 16, 15, 64}});
}
}

#if MEGDNN_WITH_BENCHMARK

TEST_F(CUDA, BENCHMARK_WARP_PERSPECTIVE_NCHW4) {


+ 108
- 0
dnn/test/naive/warp_perspective.cpp View File

@@ -189,6 +189,29 @@ TEST_F(NAIVE, WARP_PERSPECTIVE) {
{156, 183, 181, 195})});
}

TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW_QINT4) {
Checker<WarpPerspective> checker(handle(), false);
WarpPerspective::Param param;
param.bmode = WarpPerspective::Param::BorderMode::BORDER_REFLECT;
param.imode = WarpPerspective::Param::InterpolationMode::LINEAR;
param.format = WarpPerspective::Param::Format::NCHW;

std::vector<int> input_values = {1, 3, 2, 2, 0, 0, 0, 0, 2},
output_values = {1, 2, 2, 2};

checker.set_param(param).exect(
Testcase{TensorValueLowbit4({1, 1, 3, 3}, dtype::QuantizedS4(0.1),
input_values),
TensorValue({1, 3, 3}, dtype::Float32{},
{1.2f, 1.2f, 0.6f, -1.05f, -2.0f, -0.7f, 1.3f,
1.5f, 3.0f}),
{}},
Testcase{{},
{},
TensorValueLowbit4({1, 1, 2, 2}, dtype::QuantizedS4(0.1),
output_values)});
}

TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_NCHW4) {
using Param = WarpPerspective::Param;

@@ -518,4 +541,89 @@ TEST_F(NAIVE, WARP_PERSPECTIVE_BACKWARD_MAT_BFLOAT16) {
{1000, 3, 3}});
}

TEST_F(NAIVE, WARP_PERSPECTIVE_NCHW64) {
using Param = WarpPerspective::Param;

auto convert_true_format = [](const TensorLayout& layout) {
if (layout.ndim == 4)
return layout
.reshape({layout[0], layout[1] / 64, layout[2], layout[3],
64})
.dimshuffle({0, 1, 4, 2, 3});
else
return layout;
};

WarpPerspective::Param param;
auto extra_impl = [&param, this,
convert_true_format](const TensorNDArray& tensors) {
auto warp_perspective = handle()->create_operator<WarpPerspective>();
warp_perspective->param() = param;
warp_perspective->param().format = Param::Format::NCHW;

TensorNDArray nchw_tensors;
for (size_t i = 0; i < tensors.size(); ++i) {
auto layout = tensors[i].layout;
if (layout.dtype.enumv() == DTypeEnum::QuantizedS4)
layout.dtype = dtype::QuantizedS4();
if (layout.ndim == 5) {
layout = layout.reshape({layout[0], layout[1] * layout[4],
layout[2], layout[3]});
}
nchw_tensors.emplace_back(malloc(layout.span().dist_byte()),
layout);
}
TensorNDArray nchw64_tensors;
for (size_t i = 0; i < tensors.size(); ++i) {
auto layout = convert_true_format(nchw_tensors[i].layout);
nchw64_tensors.emplace_back(tensors[i].raw_ptr, std::move(layout));
}

auto workspace_size = warp_perspective->get_workspace_in_bytes(
tensors[0].layout, tensors[1].layout, tensors[2].layout);
dt_byte* workspace_ptr = static_cast<dt_byte*>(malloc(workspace_size));
Workspace workspace{workspace_ptr, workspace_size};

auto relayout = handle()->create_operator<RelayoutForward>();
relayout->exec(nchw64_tensors[0], nchw_tensors[0]);
relayout->exec(nchw64_tensors[1], nchw_tensors[1]);

warp_perspective->exec(nchw_tensors[0], nchw_tensors[1],
nchw_tensors[2], workspace);

relayout->exec(nchw_tensors[2], nchw64_tensors[2]);

free(workspace_ptr);
for (auto&& tensor : nchw_tensors) {
free(tensor.raw_ptr);
}
};

Checker<WarpPerspectiveForward> checker(handle());
WarpPerspectiveMatRNG rng;
checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::QuantizedS4(0.1f));
checker.set_dtype(2, dtype::QuantizedS4(0.1f));
checker.set_extra_opr_impl(extra_impl);
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

param.format = Param::Format::NCHW64;
checker.set_param(param);
checker.execs({{2, 1, 10, 10, 64}, {2, 3, 3}, {2, 1, 10, 12, 64}});
checker.execs(
{{20, 30, 10, 12, 64}, {20, 3, 3}, {20, 30, 11, 12, 64}});
checker.execs(
{{220, 3, 10, 10, 64}, {220, 3, 3}, {220, 3, 10, 12, 64}});
checker.execs({{1, 25, 25, 24, 64}, {1, 3, 3}, {1, 25, 25, 510, 64}});
checker.execs({{1, 25, 25, 510, 64}, {1, 3, 3}, {1, 25, 25, 24, 64}});
checker.execs({{1, 25, 25, 24, 64}, {1, 3, 3}, {1, 25, 51, 50, 64}});
checker.execs({{1, 25, 51, 50, 64}, {1, 3, 3}, {1, 25, 25, 24, 64}});
}
}
// vim: syntax=cpp.doxygen

Loading…
Cancel
Save