
feat(dnn/cuda): add impl for fusing warp perspective and dimshuffle

GitOrigin-RevId: 51e025973f
release-1.2
Megvii Engine Team, 4 years ago
commit 61f917fb8e
13 changed files with 2023 additions and 417 deletions
  1. dnn/scripts/opr_param_defs.py (+6, -0)
  2. dnn/src/common/warp_perspective.cpp (+150, -108)
  3. dnn/src/cuda/warp_perspective/common.h (+17, -0)
  4. dnn/src/cuda/warp_perspective/forward.cpp (+141, -48)
  5. dnn/src/cuda/warp_perspective/forward.cu (+823, -77)
  6. dnn/src/naive/warp_perspective/opr_impl.cpp (+220, -5)
  7. dnn/src/naive/warp_perspective/opr_impl.h (+32, -10)
  8. dnn/test/cuda/warp_perspective.cpp (+139, -2)
  9. src/gopt/impl/framework.cpp (+143, -148)
  10. src/gopt/impl/fuse_nchw4_int8_preprocess.cpp (+241, -0)
  11. src/gopt/include/megbrain/gopt/inference.h (+10, -0)
  12. src/gopt/test/inference.cpp (+62, -1)
  13. src/opr/impl/imgproc.cpp (+39, -18)

dnn/scripts/opr_param_defs.py (+6, -0)

@@ -43,6 +43,12 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, '
'output tensor is nchw layout'),
Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
'output tensor is nchw4 layout, padding c=4'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'))
)
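Note: the new Format values added above describe layout conversions performed inside the operator itself. As a rough illustration (not part of this diff; the helper name is made up), the *_IC_SMALL variants pad the channel dimension up to a single NCHW4 group of 4:

```cpp
#include <array>
#include <cstddef>

// Sketch only: output layout produced by NHWC_NCHW4_IC_SMALL /
// NCHW_NCHW4_IC_SMALL for an input with C < 4 channels (C == 1 or 3).
// The operator also warps H/W; only the layout/padding rule is shown here.
std::array<std::size_t, 5> nchw4_ic_small_out_shape(std::size_t n,
                                                    std::size_t oh,
                                                    std::size_t ow) {
    return {n, 1, oh, ow, 4};  // channels padded up to one group of 4
}

// Example taken from the tests below: a (2, 10, 11, 3) NHWC Uint8 source
// warped to 11x12 yields a (2, 1, 11, 12, 4) QuantizedS8 tensor, with the
// unused 4th channel slot zero-filled.
```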


dnn/src/common/warp_perspective.cpp (+150, -108)

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs.h"

@@ -14,20 +15,17 @@

namespace megdnn {

void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst)
{
void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
const TensorLayout& mat,
const TensorLayout& mat_idx,
const TensorLayout& dst) {
megdnn_assert_contiguous(mat);
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
auto errmsg = [&]() {
return megdnn_layout_msg(src) + ", " +
megdnn_layout_msg(mat) + ", " +
megdnn_layout_msg(mat_idx) + ", " +
megdnn_layout_msg(dst) + ", " +
param_msg();
return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " +
megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) +
", " + param_msg();
};
MEGDNN_MARK_USED_VAR(errmsg);
if (param().format == param::WarpPerspective::Format::NHWCD4 ||
@@ -35,9 +33,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());

} else if (param().format ==
param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
param().format ==
param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWC ||
param().format == param::WarpPerspective::Format::NCHW);
param().format == param::WarpPerspective::Format::NCHW ||
param().format ==
param::WarpPerspective::Format::NHWC_NCHW);
megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
}
@@ -45,7 +51,7 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
if (mat_idx.ndim) {
megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1,
"%s", errmsg().c_str());
"%s", errmsg().c_str());
megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
megdnn_assert_contiguous(mat_idx);
} else {
@@ -54,35 +60,103 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());

if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
(src.dtype == mat.dtype ||
mat.dtype.enumv() == DTypeEnum::Float32)) ||
((src.dtype.category() == DTypeCategory::INT ||
src.dtype.category() == DTypeCategory::QUANTIZED) &&
mat.dtype.enumv() == DTypeEnum::Float32),
"The input to WarpPerspective is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given.",
mat.dtype.name());
if (src.format == dst.format && dst.dtype == src.dtype) {
if (param().format == param::WarpPerspective::Format::NCHW) {
megdnn_assert(
src.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.dtype.enumv() == DTypeEnum::Float16 ||
src.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.dtype.enumv() == DTypeEnum::Int8 ||
src.dtype.enumv() == DTypeEnum::Uint8 ||
(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm),
"WarpPerspective NCHW input dtype should be "
"Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16", "") ".");
megdnn_assert(
(src.dtype.category() == DTypeCategory::FLOAT &&
(src.dtype == mat.dtype ||
mat.dtype.enumv() == DTypeEnum::Float32)) ||
((src.dtype.category() == DTypeCategory::INT ||
src.dtype.category() ==
DTypeCategory::QUANTIZED) &&
mat.dtype.enumv() == DTypeEnum::Float32),
"The input to WarpPerspective is in NCHW format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, otherwise, it should be in Float32, %s given.",
mat.dtype.name());

megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());

megdnn_assert(dst.dtype == src.dtype);
megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else if (param().format == param::WarpPerspective::Format::NHWC) {
megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
} else if (param().format == param::WarpPerspective::Format::NCHW4) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
"src expected QuantizedS8, but got %s",
src.dtype.name());
megdnn_assert(mat.dtype == dtype::Float32(),
"matrix dtype expected float, got %s",
mat.dtype.name());
megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());

megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format ==
param::WarpPerspective::Format::NHWCD4);
megdnn_assert(
src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT(
(src.dtype == dtype::Float16() ||
src.dtype == dtype::BFloat16()),
false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16",
"") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
(src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
"The input to WarpPerspective is in NHWCD4 format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, %s given.",
mat.dtype.name());
//! number of channels is same
megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
}
} else if (param().format ==
param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
param().format ==
param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
src.dtype.enumv() == DTypeEnum::Uint8),
"src expected Quantized8Asymm or Uint8, but got %s",
src.dtype.name());
megdnn_assert(mat.dtype == dtype::Float32(),
"matrix dtype expected float, got %s", mat.dtype.name());
megdnn_assert(dst.shape[4] == 4);

megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
@@ -90,16 +164,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else if (param().format == param::WarpPerspective::Format::NHWC) {
megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
} else if (param().format == param::WarpPerspective::Format::NCHW4) {
megdnn_assert(dst.dtype == src.dtype);
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
"src expected QuantizedS8, but got %s", src.dtype.name());
} else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) {
megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
src.dtype.enumv() == DTypeEnum::Uint8),
"src expected Quantized8Asymm or Uint8, but got %s",
src.dtype.name());
megdnn_assert(mat.dtype == dtype::Float32(),
"matrix dtype expected float, got %s", mat.dtype.name());
megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str());

megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
@@ -108,40 +180,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src,
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
} else {
megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4);
megdnn_assert(
src.dtype == dtype::Float32() ||
MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
src.dtype == dtype::BFloat16()),
false) ||
src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
"WarpPerspective NHWCD4 input dtype should be "
"Float32" MEGDNN_FLOAT16_SELECT(
"/Float16/BFloat16",
"") ",QunatizedS8, Quantized8Asymm.");
megdnn_assert(
(src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
"The input to WarpPerspective is in NHWCD4 format, in this "
"case, if the input dtype is floating point, the "
"transformation matrix should have same dtype as the "
"input, %s given.",
mat.dtype.name());
megdnn_assert(dst.dtype == src.dtype);
//! number of channels is same
megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
megdnn_assert(param().imode ==
param::WarpPerspective::InterpolationMode::LINEAR);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::TRANSPARENT);
megdnn_assert(param().bmode !=
param::WarpPerspective::BorderMode::ISOLATED);
megdnn_assert(param().format == param::WarpPerspective::Format::NCHW);
megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
src.dtype.enumv() == DTypeEnum::Uint8) &&
dst.dtype.enumv() == DTypeEnum::Float32);
}
megdnn_assert(src.format == dst.format);
}
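Note: as a quick summary of the assertions above (an illustrative sketch, not megdnn code), the rank and shape rules for the fused formats reduce to the following:

```cpp
#include <cstddef>

enum class FusedFormat { NHWC_NCHW, NHWC_NCHW4_IC_SMALL, NCHW_NCHW4_IC_SMALL };

// Sketch of the shape rules enforced by check_layout_fwd for the fused
// formats (names are illustrative): the *_IC_SMALL variants emit a 5-dim
// NCHW4 tensor whose last dim is 4, while NHWC_NCHW keeps a 4-dim NCHW
// output with the same channel count as the NHWC source.
bool fused_format_shapes_ok(FusedFormat fmt, std::size_t src_ndim,
                            std::size_t dst_ndim, const std::size_t* src,
                            const std::size_t* dst) {
    switch (fmt) {
        case FusedFormat::NHWC_NCHW:
            return src_ndim == 4 && dst_ndim == 4 && src[3] == dst[1];
        case FusedFormat::NHWC_NCHW4_IC_SMALL:
        case FusedFormat::NCHW_NCHW4_IC_SMALL:
            return src_ndim == 4 && dst_ndim == 5 && dst[4] == 4;
    }
    return false;
}

// In all three cases the source dtype is Uint8 or Quantized8Asymm, the
// matrix is (N, 3, 3) Float32, imode is LINEAR, and bmode must not be
// TRANSPARENT or ISOLATED.
```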

std::string WarpPerspectiveBase::param_msg() const
{
std::string WarpPerspectiveBase::param_msg() const {
std::string res;
res.append(megdnn_mangle("imode="));
switch (param().imode) {
@@ -191,31 +237,25 @@ std::string WarpPerspectiveBase::param_msg() const
return res;
}

int WarpPerspectiveBase::get_real_coord(int p, int len)
{
int WarpPerspectiveBase::get_real_coord(int p, int len) {
auto bmode = param().bmode;
if( (unsigned)p < (unsigned)len )
if ((unsigned)p < (unsigned)len)
;
else if( bmode == BorderMode::REPLICATE )
else if (bmode == BorderMode::REPLICATE)
p = p < 0 ? 0 : len - 1;
else if( bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101 )
{
else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
int delta = (bmode == BorderMode::REFLECT_101);
if( len == 1 )
if (len == 1)
return 0;
do
{
if( p < 0 )
do {
if (p < 0)
p = -p - 1 + delta;
else
p = len - 1 - (p - len) - delta;
}
while( (unsigned)p >= (unsigned)len );
}
else if( bmode == BorderMode::WRAP )
{
if( p < 0 )
p -= ((p-len+1)/len)*len;
} while ((unsigned)p >= (unsigned)len);
} else if (bmode == BorderMode::WRAP) {
if (p < 0)
p -= ((p - len + 1) / len) * len;
/*
if( p >= len )
p %= len;
@@ -223,18 +263,16 @@ int WarpPerspectiveBase::get_real_coord(int p, int len)
while (p >= len) {
p -= len;
}
}
else if( bmode == BorderMode::CONSTANT )
} else if (bmode == BorderMode::CONSTANT)
p = -1;
return p;
}

void WarpPerspectiveForward::check_exec(const TensorLayout &src,
const TensorLayout &mat,
const TensorLayout &mat_idx,
const TensorLayout &dst,
size_t workspace_in_bytes)
{
void WarpPerspectiveForward::check_exec(const TensorLayout& src,
const TensorLayout& mat,
const TensorLayout& mat_idx,
const TensorLayout& dst,
size_t workspace_in_bytes) {
check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
}

@@ -248,7 +286,10 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
if (param().format != Param::Format::NHWC &&
param().format != Param::Format::NCHW &&
param().format != Param::Format::NCHW4) {
param().format != Param::Format::NCHW4 &&
param().format != Param::Format::NHWC_NCHW &&
param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
param().format != Param::Format::NCHW_NCHW4_IC_SMALL) {
megdnn_assert(!mat_idx.ndim,
"mat_idx not supported for current format");
}
@@ -263,7 +304,8 @@ void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16(
|| grad.dtype == dtype::BFloat16()),
"Backward WarpPerspective only supports Float32/BFloat16.");
auto required_workspace_in_bytes = get_workspace_in_bytes(mat, mat_idx, diff, grad);
auto required_workspace_in_bytes =
get_workspace_in_bytes(mat, mat_idx, diff, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

@@ -283,6 +325,6 @@ void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn
} // namespace megdnn

// vim: syntax=cpp.doxygen
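Note: get_real_coord is only reformatted in this diff, but it drives all border handling in the kernels below. For reference, here is a standalone restatement of its logic with a few worked values (a sketch for illustration; the authoritative version is the one above):

```cpp
#include <cstdio>

enum class BorderMode { REPLICATE, REFLECT, REFLECT_101, WRAP, CONSTANT };

// Restatement of the coordinate remapping for out-of-range indices,
// mirroring WarpPerspectiveBase::get_real_coord above (sketch only).
int remap_coord(int p, int len, BorderMode bmode) {
    if ((unsigned)p < (unsigned)len) {
        // already inside [0, len): nothing to do
    } else if (bmode == BorderMode::REPLICATE) {
        p = p < 0 ? 0 : len - 1;
    } else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
        int delta = (bmode == BorderMode::REFLECT_101);
        if (len == 1)
            return 0;
        do {
            if (p < 0)
                p = -p - 1 + delta;
            else
                p = len - 1 - (p - len) - delta;
        } while ((unsigned)p >= (unsigned)len);
    } else if (bmode == BorderMode::WRAP) {
        if (p < 0)
            p -= ((p - len + 1) / len) * len;
        while (p >= len)
            p -= len;
    } else if (bmode == BorderMode::CONSTANT) {
        p = -1;  // caller substitutes border_val
    }
    return p;
}

int main() {
    // worked examples with len = 5:
    std::printf("%d\n", remap_coord(-1, 5, BorderMode::REFLECT));      // 0
    std::printf("%d\n", remap_coord(-1, 5, BorderMode::REFLECT_101));  // 1
    std::printf("%d\n", remap_coord(-2, 5, BorderMode::WRAP));         // 3
    std::printf("%d\n", remap_coord(7, 5, BorderMode::REPLICATE));     // 4
    std::printf("%d\n", remap_coord(9, 5, BorderMode::CONSTANT));      // -1
}
```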

dnn/src/cuda/warp_perspective/common.h (+17, -0)

@@ -12,6 +12,7 @@
#pragma once
#include <cuda_runtime_api.h>
#include "src/common/cv/enums.h"
#include "src/cuda/utils.cuh"
#include "megcore_cdefs.h"

namespace megdnn {
@@ -34,6 +35,22 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);

template <typename src_dtype, typename src_ctype, typename dst_ctype>
void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
bool is_nhwc, const src_ctype* src, const float* mat,
const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);

template <typename src_dtype, typename src_ctype, typename dst_ctype>
void forward_proxy_quint8_dimshuffle_typecvt_nchw(
bool is_nhwc, const src_ctype* src, const float* mat,
const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
BorderMode bmode, megcore::AsyncErrorInfo* error_info,
void* error_tracker, cudaStream_t stream);

void backward_data_proxy(const float* mat, const int* midx, const float* diff,
float* grad, float* workspace, int N, int N_SRC, int C,
int IH, int IW, int OH, int OW, float bval,


dnn/src/cuda/warp_perspective/forward.cpp (+141, -48)

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/warp_perspective/opr_impl.h"
#include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
@@ -166,6 +167,30 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (param().format == Param::Format::NHWC_NCHW) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
} else if (param().format == Param::Format::NHWC_NCHW4_IC_SMALL) {
C = src.layout.shape[3];
IH = src.layout.shape[1];
IW = src.layout.shape[2];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
megdnn_assert(
(C == 1) || (C == 3),
"NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else if (param().format == Param::Format::NCHW_NCHW4_IC_SMALL) {
C = src.layout.shape[1];
IH = src.layout.shape[2];
IW = src.layout.shape[3];
OH = dst.layout.shape[2];
OW = dst.layout.shape[3];
megdnn_assert(
(C == 1) || (C == 3),
"NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
} else {
megdnn_assert(
param().format == param::WarpPerspective::Format::NCHW,
@@ -180,55 +205,123 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,
"unsupported interpolation mode for NCHW format");
auto bval = param().border_val;
auto bmode = warp_perspective::get_bmode(param().bmode);

if (src.layout.dtype == dtype::Float32{}) {
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float32>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float32>(), src.layout[0], mat.layout[0], C,
IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (MEGDNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(), false)) {
if (src.layout.dtype == dst.layout.dtype) {
if (src.layout.dtype == dtype::Float32{}) {
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float32>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float32>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, bval, bmode,
async_error_info(handle()), m_error_tracker,
stream);
} else if (MEGDNN_FLOAT16_SELECT(
src.layout.dtype == dtype::Float16(),
false)) {
#ifndef MEGDNN_DISABLE_FLOAT16
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float16>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float16>(), src.layout[0], mat.layout[0], C,
IH, IW, OH, OW, static_cast<dt_float16>(bval), bmode,
async_error_info(handle()), m_error_tracker, stream);
warp_perspective::forward_proxy(
is_nhwc, src.ptr<dt_float16>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_float16>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, static_cast<dt_float16>(bval),
bmode, async_error_info(handle()), m_error_tracker,
stream);
#endif
} else if (src.layout.dtype == dtype::Uint8()) {
warp_perspective::forward_proxy<dt_uint8>(
is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0], C,
IH, IW, OH, OW, bval, bmode, async_error_info(handle()),
m_error_tracker, stream);
} else if (src.layout.dtype == dtype::Int8()) {
megdnn_assert(
!is_nhwc,
"WarpPerspective on CUDA does not support NHWC + Int8");
warp_perspective::forward_proxy<dt_int8>(
false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C, IH,
IW, OH, OW,
bval /* implicit float -> int8 conversion, should be
safe */
,
bmode, async_error_info(handle()), m_error_tracker,
stream);
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
megdnn_assert(param().format == Param::Format::NCHW4,
"WarpPerspective on CUDA supports NCHW4 + "
"QuantizedS8 only");
warp_perspective::forward_proxy_nchw4<dt_int8>(
src.compatible_ptr<dt_int8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_int8>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
async_error_info(handle()), m_error_tracker, stream);
} else if (src.layout.dtype == dtype::Uint8()) {
warp_perspective::forward_proxy<dt_uint8>(
is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0],
C, IH, IW, OH, OW, bval, bmode,
async_error_info(handle()), m_error_tracker,
stream);
} else if (src.layout.dtype == dtype::Int8()) {
megdnn_assert(!is_nhwc,
"WarpPerspective on CUDA does not support "
"NHWC + Int8");
warp_perspective::forward_proxy<dt_int8>(
false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C,
IH, IW, OH, OW,
bval /* implicit float -> int8 conversion,
should be safe */
,
bmode, async_error_info(handle()), m_error_tracker,
stream);
} else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
megdnn_assert(param().format == Param::Format::NCHW4,
"WarpPerspective on CUDA supports NCHW4 + "
"QuantizedS8 only");
warp_perspective::forward_proxy_nchw4<dt_int8>(
src.compatible_ptr<dt_int8>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
dst.compatible_ptr<dt_int8>(), src.layout[0],
mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
async_error_info(handle()), m_error_tracker,
stream);
}
} else if ((src.layout.dtype.enumv() ==
DTypeEnum::Quantized8Asymm ||
src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
uint8_t zero_point = 0;
float scale = 1.f;
if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
zero_point =
src.layout.dtype.param<dtype::Quantized8Asymm>()
.zero_point;
scale = src.layout.dtype.param<dtype::Quantized8Asymm>()
.scale;
} else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
zero_point = 128;
scale = 1.f;
}
DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);

if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
scale) &&
((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
(param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
bool is_nhwc_ic_small =
(param().format ==
Param::Format::NHWC_NCHW4_IC_SMALL);
warp_perspective::
forward_proxy_quint8_dimshuffle_typecvt_nchw4<
dt_quint8, dt_uint8, dt_int8>(
is_nhwc_ic_small,
src.compatible_ptr<dt_uint8>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>()
: nullptr,
dst.compatible_ptr<dt_int8>(),
src.layout[0], mat.layout[0], C, IH, IW, OH,
OW, bval, src_dtype_param, bmode,
async_error_info(handle()), m_error_tracker,
stream);
} else {
megdnn_assert(
((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
((param().format == Param::Format::NCHW) ||
(param().format == Param::Format::NHWC_NCHW))),
"invalid format for Quantized8Asymm input");
bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);
warp_perspective::
forward_proxy_quint8_dimshuffle_typecvt_nchw<
dt_quint8, dt_uint8, dt_float32>(
is_nhwc, src.compatible_ptr<dt_uint8>(),
mat.ptr<dt_float32>(),
mat_idx.raw_ptr ? mat_idx.ptr<int>()
: nullptr,
dst.compatible_ptr<dt_float32>(),
src.layout[0], mat.layout[0], C, IH, IW, OH,
OW, bval, src_dtype_param, bmode,
async_error_info(handle()), m_error_tracker,
stream);
}
} else {
megdnn_throw(ssprintf("unsupported dtype: %s",
src.layout.dtype.name()));
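Note: the quantization parameters fed to the two fused proxies are chosen by the branch above. The standalone sketch below (illustrative names, not megdnn API) restates that rule:

```cpp
#include <cstdint>

struct Quant8Param {
    uint8_t zero_point;
    float scale;
};

// Sketch of how the CUDA path above derives the source quantization
// parameters: a Quantized8Asymm input carries its own zero_point/scale,
// while a plain Uint8 input converted to QuantizedS8 is shifted by 128.
Quant8Param pick_src_param(bool src_is_quint8, uint8_t src_zero_point,
                           float src_scale, bool dst_is_qint8) {
    if (src_is_quint8)
        return {src_zero_point, src_scale};
    if (dst_is_qint8)
        return {128, 1.f};  // Uint8 -> QuantizedS8: recenter around zero
    return {0, 1.f};        // Uint8 -> Float32: plain value copy
}

// The destination dtype then selects the proxy: QuantizedS8 (with a scale
// matching the source) goes through the NCHW4 proxy, Float32 through the
// NCHW proxy; anything else is rejected, mirroring the asserts above.
```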


dnn/src/cuda/warp_perspective/forward.cu (+823, -77): file diff suppressed because it is too large


dnn/src/naive/warp_perspective/opr_impl.cpp (+220, -5)

@@ -249,6 +249,162 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4(
MIDOUT_END();
}

template <typename ctype, typename dst_ctype, typename mtype>
void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt(
const KernParam<ctype, mtype>& kern_param, size_t task_id) {
MEGDNN_MARK_USED_VAR(kern_param);
MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(2)) {
UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param);
MEGDNN_MARK_USED_VAR(N_MAT);
//! strides of C, H, W on src and dst
size_t sstrd[3], dstrd[3];
auto set_sstrd = [&](size_t s0, size_t s1, size_t s2) {
sstrd[0] = s0;
sstrd[1] = s1;
sstrd[2] = s2;
};
auto set_dstrd = [&](size_t s0, size_t s1, size_t s2) {
dstrd[0] = s0;
dstrd[1] = s1;
dstrd[2] = s2;
};
switch (kern_param.format) {
case Format::NCHW:
case Format::NCHW_NCHW4_IC_SMALL:
set_sstrd(IH * IW, IW, 1);
set_dstrd(OH * OW, OW, 1);
break;
case Format::NHWC_NCHW:
case Format::NHWC_NCHW4_IC_SMALL:
set_sstrd(1, IW * C, C);
set_dstrd(OH * OW, OW, 1);
break;
default:
megdnn_throw("bad format");
}

uint8_t zero_point = 0;
float scale = 1.f;

bool is_dst_float = kern_param.dst_dtype.enumv() == DTypeEnum::Float32;
if (kern_param.src_dtype.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv) {
auto dtype_param =
kern_param.src_dtype
.template param<dtype::Quantized8Asymm>();
zero_point = dtype_param.zero_point;
scale = dtype_param.scale;
} else if (kern_param.src_dtype.enumv() == DTypeEnum::Uint8) {
zero_point =
(kern_param.dst_dtype.enumv() == DTypeEnum::QuantizedS8)
? 128
: 0;
scale = 1.f;
}

dst_ctype* dst_ptr = reinterpret_cast<dst_ctype*>(dptr);

bool is_dst_nchw4 =
(kern_param.format == Format::NCHW_NCHW4_IC_SMALL) ||
(kern_param.format == Format::NHWC_NCHW4_IC_SMALL);
auto visit_src = [&sptr, sstrd](size_t c, int h, int w) -> float {
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
};
auto visit_src_bd = [&sptr, sstrd, border_val](size_t c, int h,
int w) -> float {
if (h != -1 && w != -1) {
return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w];
} else
return border_val;
};
auto visit_dst = [&dst_ptr, dstrd, is_dst_nchw4](size_t c, int h,
int w) -> dst_ctype& {
if (!is_dst_nchw4)
return dst_ptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w];
else
return dst_ptr[((dstrd[0] * (c >> 2) + dstrd[1] * h +
dstrd[2] * w)
<< 2) +
(c & 0b11)];
};

rounding::RoundingConverter<dst_ctype> output_converter;
auto orig_sptr = sptr;
size_t n = task_id / OH;
size_t oh = task_id % OH;
mptr = mptr + n * 3 * 3;
dst_ptr = is_dst_nchw4 ? (dst_ptr + n * OH * OW * 4)
: (dst_ptr + n * C * OH * OW);
if (midx_ptr) {
size_t idx = midx_ptr[n];
megdnn_assert(
idx < N_SRC,
"mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", n,
idx, N_SRC);
sptr = orig_sptr + idx * (C * IH * IW);
} else if (n) {
sptr += n * C * IH * IW;
}
rep(ow, OW) {
float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2];
float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5];
float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8];
float alphaw = numeratorw / denominator;
float alphah = numeratorh / denominator;

int iw0 = get_real_coord(std::floor(alphaw) + 0, IW);
int iw1 = get_real_coord(std::floor(alphaw) + 1, IW);
int ih0 = get_real_coord(std::floor(alphah) + 0, IH);
int ih1 = get_real_coord(std::floor(alphah) + 1, IH);

alphaw -= floor(alphaw);
alphah -= floor(alphah);
if (bmode != BorderMode::CONSTANT) {
rep(c, C) {
auto val =
visit_src(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src(c, ih0, iw1) * alphaw * (1.0f - alphah) +
visit_src(c, ih1, iw0) * (1.0f - alphaw) * alphah +
visit_src(c, ih1, iw1) * alphaw * alphah;
val = is_dst_float ? (val - zero_point) * scale
: val - zero_point;
visit_dst(c, oh, ow) = output_converter(val);
}
} else {
rep(c, C) {
auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) *
(1.0f - alphah) +
visit_src_bd(c, ih0, iw1) * alphaw *
(1.0f - alphah) +
visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) *
alphah +
visit_src_bd(c, ih1, iw1) * alphaw * alphah;
val = std::isfinite(val) ? val : border_val;
val = is_dst_float ? (val - zero_point) * scale
: val - zero_point;
visit_dst(c, oh, ow) = output_converter(val);
}
}
if (is_dst_nchw4) {
for (auto c = C; c < 4; ++c) {
visit_dst(c, oh, ow) = 0;
}
}
}
}
MIDOUT_END();
}

#define INST(ctype, drc_ctype, mtype) \
template void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt< \
ctype, drc_ctype, mtype>(const KernParam<ctype, mtype>&, size_t);

INST(uint8_t, int8_t, float);
INST(uint8_t, float, float);

#undef INST
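Note: per output pixel, kern_naive_dimshuffle_typecvt amounts to a bilinear sample followed by the dequantize/zero-point shift. A simplified single-channel restatement (sketch only; coordinates are assumed already remapped by get_real_coord and an NCHW-style source stride is used):

```cpp
#include <cstdint>

// Sketch of the inner computation of kern_naive_dimshuffle_typecvt for one
// output pixel of one channel, after the four source coordinates have been
// remapped by get_real_coord. `is_dst_float` selects between the Float32
// output (dequantize) and the QuantizedS8 output (zero-point shift).
float warp_sample_and_convert(const uint8_t* src, int IW, int ih0, int ih1,
                              int iw0, int iw1, float alphah, float alphaw,
                              uint8_t zero_point, float scale,
                              bool is_dst_float) {
    auto at = [&](int h, int w) -> float { return src[h * IW + w]; };
    float val = at(ih0, iw0) * (1.f - alphaw) * (1.f - alphah) +
                at(ih0, iw1) * alphaw * (1.f - alphah) +
                at(ih1, iw0) * (1.f - alphaw) * alphah +
                at(ih1, iw1) * alphaw * alphah;
    // fused "dimshuffle + typecvt" part: either dequantize to float or
    // recenter to the signed 8-bit range (the caller then rounds/clamps).
    return is_dst_float ? (val - zero_point) * scale : val - zero_point;
}
```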

void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in mat,
_megdnn_tensor_in mat_idx,
@@ -320,6 +476,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
src.layout.dtype.name())
.c_str());
}

bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv();
bool is_u8_or_qu8_in =
src.layout.dtype.enumv() == DTypeTrait<dtype::Uint8>::enumv ||
src.layout.dtype.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv;

if (is_fusion_dtype && is_u8_or_qu8_in &&
((param().format == Format::NCHW_NCHW4_IC_SMALL) ||
(param().format == Format::NHWC_NCHW4_IC_SMALL) ||
(param().format == Format::NHWC_NCHW) ||
(param().format == Format::NCHW))) {
if (src.layout.dtype.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv ||
src.layout.dtype.enumv() == DTypeTrait<dtype::Uint8>::enumv) {
float scale = 1.f;

if (src.layout.dtype.enumv() ==
DTypeTrait<dtype::Quantized8Asymm>::enumv) {
scale = src.layout.dtype.param<dtype::Quantized8Asymm>().scale;
}

auto kparam = KernParam<uint8_t, float>::from_tensors(
param().format, param().bmode, param().border_val, src, mat,
mat_idx, dst, workspace);

if (dst.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) {
auto run = [kparam, this](size_t index, size_t) {
kern_naive_dimshuffle_typecvt<uint8_t, float, float>(kparam,
index);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run,
kparam.oh * batch);
return;
} else if ((dst.layout.dtype.enumv() ==
DTypeTrait<dtype::QuantizedS8>::enumv) &&
(dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
scale)) {
auto run = [kparam, this](size_t index, size_t) {
kern_naive_dimshuffle_typecvt<uint8_t, int8_t, float>(
kparam, index);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run,
kparam.oh * batch);
return;
} else {
megdnn_throw(ssprintf("Unsupported DType in "
"WarpPerspective Dimshuffle Typecvt: %s",
src.layout.dtype.name())
.c_str());
}
}

megdnn_throw(ssprintf("Unsupported input DType in "
"WarpPerspective: %s",
src.layout.dtype.name())
.c_str());
}

if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode,
param().format)) {
MIDOUT_BEGIN(megdnn_naive_warpperspective, void) {
@@ -331,12 +546,12 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src,
megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, dst.layout,
param().imode, param().format));
/*!
* We currently use floating point for all WarpPerspective computation,
* so even if the input ctype is one of the integer type, mtype should
* always be float32.
* We currently use floating point for all WarpPerspective
* computation, so even if the input ctype is one of the integer
* type, mtype should always be float32.
*
* \warning It's different with \c WarpAffine, with mtype be float16 if
* input type is float16.
* \warning It's different with \c WarpAffine, with mtype be float16
* if input type is float16.
*/

DISPATCH_ST(dtype::Float32, float, float, KERN);


dnn/src/naive/warp_perspective/opr_impl.h (+32, -10)

@@ -26,6 +26,7 @@ protected:
float border_val;
size_t n_src, n_mat, c, ih, iw, oh, ow;
ctype *sptr, *dptr;
DType src_dtype, dst_dtype;
mtype* mptr;
int* midx_ptr; //!< can be null
Workspace workspace;
@@ -41,6 +42,8 @@ protected:
ret.bmode = bmode;
ret.border_val = border_val;
ret.n_src = src.layout.shape[0];
ret.src_dtype = src.layout.dtype;
ret.dst_dtype = dst.layout.dtype;
if (mat_idx.raw_ptr) {
megdnn_assert(mat_idx.layout.ndim == 1);
ret.n_mat = mat_idx.layout.shape[0];
@@ -50,7 +53,8 @@ protected:
ret.n_mat = ret.n_src;
ret.midx_ptr = nullptr;
}
if (format == Format::NCHW) {
if (format == Format::NCHW ||
format == Format::NCHW_NCHW4_IC_SMALL) {
ret.c = src.layout.shape[1];
ret.ih = src.layout.shape[2];
ret.iw = src.layout.shape[3];
@@ -62,6 +66,13 @@ protected:
ret.iw = src.layout.shape[2];
ret.oh = dst.layout.shape[1];
ret.ow = dst.layout.shape[2];
} else if (format == Format::NHWC_NCHW ||
format == Format::NHWC_NCHW4_IC_SMALL) {
ret.c = src.layout.shape[3];
ret.ih = src.layout.shape[1];
ret.iw = src.layout.shape[2];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else if (format == Format::NCHW4) {
ret.c = src.layout.shape[1] * 4;
ret.ih = src.layout.shape[2];
@@ -76,15 +87,16 @@ protected:
ret.oh = dst.layout.shape[1];
ret.ow = dst.layout.shape[3];
}
if (src.layout.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.layout.dtype.enumv() == DTypeEnum::Float16 ||
src.layout.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.layout.dtype.enumv() == DTypeEnum::Int8 ||
src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
if ((src.layout.dtype.enumv() == DTypeEnum::Float32 ||
MEGDNN_FLOAT16_SELECT(
(src.layout.dtype.enumv() == DTypeEnum::Float16 ||
src.layout.dtype.enumv() == DTypeEnum::BFloat16),
false) ||
src.layout.dtype.enumv() == DTypeEnum::Int8 ||
src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 ||
src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) &&
(src.layout.dtype == dst.layout.dtype)) {
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
@@ -92,6 +104,13 @@ protected:
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = dst.compatible_ptr<ctype>();
} else if ((src.layout.dtype.enumv() == DTypeEnum::Uint8 ||
src.layout.dtype.enumv() ==
DTypeEnum::Quantized8Asymm) &&
src.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
ret.sptr = src.compatible_ptr<ctype>();
ret.mptr = mat.ptr<mtype>();
ret.dptr = reinterpret_cast<ctype*>(dst.raw_ptr);
} else {
ret.sptr = nullptr;
ret.mptr = nullptr;
@@ -122,6 +141,9 @@ private:
template <typename ctype, typename mtype>
void kern_naive_nhwcd4(const KernParam<ctype, mtype>& kern_param,
size_t task_id);
template <typename ctype, typename dst_ctype, typename mtype>
void kern_naive_dimshuffle_typecvt(
const KernParam<ctype, mtype>& kern_param, size_t task_id);
};

class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData {


dnn/test/cuda/warp_perspective.cpp (+139, -2)

@@ -23,8 +23,7 @@ using namespace megdnn;
using namespace test;

class NanMatRNG : public RNG {
void gen(const TensorND& tensor_) override
{
void gen(const TensorND& tensor_) override {
auto& gen = RandomState::generator();
std::uniform_real_distribution<dt_float32> pdist3(1.9f, 2.1f);
std::uniform_real_distribution<dt_float32> pdist(0.9f, 1.1f);
@@ -335,6 +334,144 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW4) {
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_NCHW_NCHW4_IC_SMALL) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG rng;

param.format = Param::Format::NCHW_NCHW4_IC_SMALL;

checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
checker.set_dtype(2, dtype::QuantizedS8(0.1f));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{2, 3, 10, 11}, {2, 3, 3}, {2, 1, 11, 12, 4}});
checker.execs({{1, 3, 25, 510}, {1, 3, 3}, {1, 1, 25, 25, 4}});
checker.execs({{1, 3, 25, 25}, {1, 3, 3}, {1, 1, 51, 51, 4}});
checker.execs({{1, 3, 51, 51}, {1, 3, 3}, {1, 1, 25, 25, 4}});
}
{
Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
handle_cuda());
constexpr int N_SRC = 5;
UniformIntRNG mat_idx_rng{0, N_SRC - 1};
checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128));
checker.set_rng(1, &rng);
checker.set_dtype(2, dtype::Int32());
checker.set_rng(2, &mat_idx_rng);
checker.set_dtype(3, dtype::QuantizedS8(0.1f));
param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{N_SRC, 3, 10, 11}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
checker.execs(
{{N_SRC, 3, 17, 13}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW4_IC_SMALL) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG rng;

param.format = Param::Format::NHWC_NCHW4_IC_SMALL;

checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Uint8());
checker.set_dtype(2, dtype::QuantizedS8(1.f));
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 1, 11, 12, 4}});
checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 1, 51, 51, 4}});
checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}});
}
{
Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
handle_cuda());
constexpr int N_SRC = 5;
UniformIntRNG mat_idx_rng{0, N_SRC - 1};
checker.set_dtype(0, dtype::Uint8());
checker.set_rng(1, &rng);
checker.set_dtype(2, dtype::Int32());
checker.set_rng(2, &mat_idx_rng);
checker.set_dtype(3, dtype::QuantizedS8(1.f));
param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}});
checker.execs(
{{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}});
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW) {
using Param = WarpPerspective::Param;
WarpPerspective::Param param;
Checker<WarpPerspectiveForward> checker(handle_cuda());
WarpPerspectiveMatRNG rng;

param.format = Param::Format::NHWC_NCHW;

checker.set_rng(1, &rng);
checker.set_dtype(0, dtype::Uint8());
checker.set_dtype(2, dtype::Float32());
for (auto bmode : {WarpPerspective::BorderMode::WRAP,
WarpPerspective::BorderMode::REFLECT,
WarpPerspective::BorderMode::REPLICATE,
WarpPerspective::BorderMode::CONSTANT}) {
param.border_val = 0.3f;
param.bmode = bmode;
param.imode = Param::InterpolationMode::LINEAR;

checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 3, 11, 12}});
checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 3, 25, 25}});
checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 3, 51, 51}});
checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 3, 25, 25}});
}
{
Checker<WarpPerspective, WarpPerspectiveMatIdxProxy> checker(
handle_cuda());
constexpr int N_SRC = 5;
UniformIntRNG mat_idx_rng{0, N_SRC - 1};
checker.set_dtype(0, dtype::Uint8());
checker.set_rng(1, &rng);
checker.set_dtype(2, dtype::Int32());
checker.set_rng(2, &mat_idx_rng);
checker.set_dtype(3, dtype::Float32());
param.bmode = WarpPerspective::Param::BorderMode::REFLECT;
param.imode = param::WarpPerspective::InterpolationMode::LINEAR;
checker.set_param(param);
checker.set_epsilon(1 + 1e-3);
checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 3, 11, 12}});
checker.execs(
{{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 3, 16, 15}});
}
}

TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_NCHW_INT8) {
warp_perspective::run_int8_test(handle_cuda());
}


src/gopt/impl/framework.cpp (+143, -148)

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/gopt/framework.h"
@@ -35,13 +36,13 @@ using namespace gopt;
/* ================ SubGraph ================ */

OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
OperatorNodeBase *opr) {
auto &&new_inp = m_opr_new_inp_cache;
OperatorNodeBase* opr) {
auto&& new_inp = m_opr_new_inp_cache;
new_inp.clear();
new_inp.reserve(opr->input().size());
bool has_replaced_inp = false;

for (auto i: opr->input()) {
for (auto i : opr->input()) {
auto new_var = get_var(i);
if (new_var != i) {
has_replaced_inp = true;
@@ -52,14 +53,14 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
}

if (has_replaced_inp) {
auto new_opr = serialization::copy_opr_shallow(
*opr, new_inp, opr->config());
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
auto &&out0 = opr->output(), &&out1 = new_opr->output();
size_t i = 0;
auto err_msg = [opr, new_opr] {
return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}",
opr->cname(), opr->dyn_typeinfo()->name,
new_opr->cname(), new_opr->dyn_typeinfo()->name);
return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(),
opr->dyn_typeinfo()->name, new_opr->cname(),
new_opr->dyn_typeinfo()->name);
};
MGB_MARK_USED_VAR(err_msg);
// opr output size mismatch may be caused by:
@@ -67,33 +68,33 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs(
// 1) other post-insert optimization (e.g. const folding)
// we can't handle only usable_output here, since some output var with
// volatile flag could be the graph's endpoint (e.g. RemoteSend)
for (; i < std::min(out0.size(), out1.size()); ++ i) {
for (; i < std::min(out0.size(), out1.size()); ++i) {
bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT);
mgb_assert(v0 == v1, "%s", err_msg().c_str());

auto &&ins = m_varmap.insert({out0[i], {true, nullptr}});
auto&& ins = m_varmap.insert({out0[i], {true, nullptr}});
mgb_assert(ins.second || ins.first->second.first,
"opr output already replaced");
// handle repeated call on the same opr
ins.first->second.second = out1[i];
on_var_replaced(out0[i], out1[i], nullptr);
}
for (; i < out0.size(); ++ i) {
for (; i < out0.size(); ++i) {
mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
"%s", err_msg().c_str());
"%s", err_msg().c_str());
}
for (; i < out1.size(); ++ i) {
for (; i < out1.size(); ++i) {
mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
"%s", err_msg().c_str());
"%s", err_msg().c_str());
}
return new_opr;
}
return opr;
}

void SubGraph::Rewriter::replace_var(
VarNode *src, VarNode *dst, const char *msg) {
void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst,
const char* msg) {
if (src == dst)
return;

@@ -103,19 +104,19 @@ void SubGraph::Rewriter::replace_var(
"dst %s maps back to src %s in SubGraph::Rewriter::replace_var",
dst->cname(), src->cname());

auto &&ins = m_varmap.insert({src, {false, dst}});
auto&& ins = m_varmap.insert({src, {false, dst}});
if (!ins.second) {
auto &&old_rep = ins.first->second;
auto&& old_rep = ins.first->second;
mgb_assert(old_rep.first || old_rep.second == dst,
"can not replace a var twice");
"can not replace a var twice");
old_rep.first = false;
old_rep.second = dst;
}
on_var_replaced(src, dst, msg);
}

void SubGraph::Rewriter::on_var_replaced(
VarNode* src, VarNode* dst, const char* msg) {
void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst,
const char* msg) {
if (auto state = m_owner_graph->owner_opt_state()) {
state->on_var_replaced(src, dst, msg);
}
@@ -124,7 +125,7 @@ void SubGraph::Rewriter::on_var_replaced(
void SubGraph::Rewriter::apply_inplace() const {
m_owner_graph->m_endpoint_oprs.clear();
m_owner_graph->m_endpoint_vars_set.clear();
for (auto &&var: m_owner_graph->m_endpoint_vars) {
for (auto&& var : m_owner_graph->m_endpoint_vars) {
var = get_var(var.node());
m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr());
m_owner_graph->m_endpoint_vars_set.insert(var.node());
@@ -150,33 +151,30 @@ std::pair<bool, VarNode*> SubGraph::Rewriter::get_var_internal(VarNode* var) {
return it->second = {it_next->second.first & it->second.first, next.second};
}

SubGraph::SubGraph(const SymbolVarArray &endpoint_vars):
m_endpoint_vars(endpoint_vars)
{
SubGraph::SubGraph(const SymbolVarArray& endpoint_vars)
: m_endpoint_vars(endpoint_vars) {
mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty");
m_comp_graph = endpoint_vars[0].node()->owner_graph();
for (auto i: endpoint_vars) {
for (auto i : endpoint_vars) {
m_endpoint_oprs.insert(i.node()->owner_opr());
m_endpoint_vars_set.insert(i.node());
mgb_assert(m_comp_graph == i.node()->owner_graph(),
"endpoints belong to different computing graphs");
"endpoints belong to different computing graphs");
}
}

void SubGraph::iter(
const Callback& cb,
std::shared_ptr<ExtraDep> extra_dep) const {
void SubGraph::iter(const Callback& cb,
std::shared_ptr<ExtraDep> extra_dep) const {
Callback on_opr;

if (m_owner_opt_state) {
on_opr = [state=m_owner_opt_state, &cb](OperatorNodeBase *opr) {
on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) {
state->m_opr_property_flag = OprPropertyFlag::ALL;
state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr);
state->m_cur_iter_opr_priority =
opr->node_prop().attribute().priority;
opr->node_prop().attribute().priority;
state->m_cur_iter_opr_stream_prop_type =
state->m_comp_node_opt.stream_prop_type(
opr->output(0));
state->m_comp_node_opt.stream_prop_type(opr->output(0));
mgb_assert(state->m_oprs_inserted.empty());
cb(opr);
state->m_opr_property_flag = OprPropertyFlag::NONE;
@@ -188,19 +186,19 @@ void SubGraph::iter(
}

cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)};
for (auto i: m_endpoint_oprs)
for (auto i : m_endpoint_oprs)
dep_iter.add(i);
}

ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
ThinHashMap<VarNode*, size_t> ret;
auto cb = [&](OperatorNodeBase *opr) {
for (auto &&i: opr->node_prop().dep_map()) {
auto cb = [&](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) {
++ ret.at(i.first);
++ret.at(i.first);
}
}
for (auto i: opr->output()) {
for (auto i : opr->output()) {
if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
auto ins = ret.insert({i, 0});
mgb_assert(ins.second);
@@ -208,13 +206,13 @@ ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {
}
};
iter(cb);
for (auto i: m_endpoint_vars_set) {
for (auto i : m_endpoint_vars_set) {
auto iter = ret.find(i);
if (iter == ret.end()) {
mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
ret[i] = 1;
} else {
++ ret.at(i);
++ret.at(i);
}
}
return ret;
@@ -222,10 +220,8 @@ ThinHashMap<VarNode*, size_t> SubGraph::get_var2nr_val_dep_oprs() const {

/* ================ UniqReaderCheck ================ */

UniqReaderCheck::UniqReaderCheck(const SubGraph &graph):
m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()}
{
}
UniqReaderCheck::UniqReaderCheck(const SubGraph& graph)
: m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {}

void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr,
OperatorNodeBase* repl_opr) {
@@ -253,32 +249,30 @@ void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr,

/* ================ OptState ================ */

OptState::OptState(
const GraphOptimizer *owner_optimizer, const SubGraph& graph):
m_owner_optimizer{owner_optimizer},
m_var_replace_map{
const_cast<ThinHashMap<VarNode*, VarNode*>*>(
&GraphOptimizer::var_replace_map(*graph.comp_graph()))},
m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
m_graph{graph}
{
OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph)
: m_owner_optimizer{owner_optimizer},
m_var_replace_map{const_cast<ThinHashMap<VarNode*, VarNode*>*>(
&GraphOptimizer::var_replace_map(*graph.comp_graph()))},
m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()},
m_graph{graph} {
mgb_assert(!m_graph.m_owner_opt_state);
m_var_replace_map->clear();
m_graph.m_owner_opt_state = this;
m_oprs_inserted.clear();

auto on_opr_insert = [this](const cg::event::OprInserted &ev) {
auto on_opr_insert = [this](const cg::event::OprInserted& ev) {
auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR,
need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY;
if (need_src_opr)
mgb_assert(m_cur_iter_src_opr, "opr %s{%s} created outside from "
"SubGraph::iter",
ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
mgb_assert(m_cur_iter_src_opr,
"opr %s{%s} created outside from "
"SubGraph::iter",
ev.opr->cname(), ev.opr->dyn_typeinfo()->name);
if (ev.exc || ev.is_dedup)
return;

auto &&new_attr = ev.opr->node_prop().attribute();
auto &&ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
auto&& new_attr = ev.opr->node_prop().attribute();
auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE});
mgb_assert(ins.second);

if (need_src_opr && !new_attr.src_opr) {
@@ -296,20 +290,22 @@ OptState::OptState(

auto csp = m_cur_iter_opr_stream_prop_type;
if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) {
for (auto i: ev.opr->output())
for (auto i : ev.opr->output())
m_comp_node_opt.register_stream_var(i, csp);
}
};
m_on_opr_insert_handler = graph.comp_graph()->event().register_receiver<
cg::event::OprInserted>(on_opr_insert);
m_on_opr_insert_handler =
graph.comp_graph()
->event()
.register_receiver<cg::event::OprInserted>(on_opr_insert);
}

void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) {
void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) {
if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
// this can only happen in auto_replace_outputs()
mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) &&
src->owner_opr()->dyn_typeinfo() ==
dst->owner_opr()->dyn_typeinfo());
src->owner_opr()->dyn_typeinfo() ==
dst->owner_opr()->dyn_typeinfo());
mgb_assert(!msg);
return;
}
@@ -362,7 +358,7 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) {
return f & (InferType::RT_STATIC | InferType::CONST);
};
if (!(norm(it0.shape) == norm(it1.shape) &&
norm(it0.value) <= norm(it1.value))) {
norm(it0.value) <= norm(it1.value))) {
suc = false;
fail_chks.push_back("infer-type");
}
@@ -407,22 +403,21 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) {

#if MGB_ENABLE_LOGGING
if (msg && m_owner_optimizer->verbosity()) {
m_log_msg.
append("\n ").
append(std::to_string(m_log_nr_item)).
append(": ").
append(src->owner_opr()->cname()).
append(" => ").
append(dst->owner_opr()->cname()).
append(" (").
append(msg).
append(")");
}
++ m_log_nr_item;
m_log_msg.append("\n ")
.append(std::to_string(m_log_nr_item))
.append(": ")
.append(src->owner_opr()->cname())
.append(" => ")
.append(dst->owner_opr()->cname())
.append(" (")
.append(msg)
.append(")");
}
++m_log_nr_item;
#endif
}

size_t OptState::flush_log(const char *title) {
size_t OptState::flush_log(const char* title) {
if (m_owner_optimizer->verbosity() >= 2) {
if (m_log_msg.empty()) {
m_log_msg = mgb_cstr_log(" no var replacement logged");
@@ -435,42 +430,40 @@ size_t OptState::flush_log(const char *title) {
return ret;
}

void OptState::call_with_opr(OperatorNodeBase *opr, thin_function<void(void)> func,
void OptState::call_with_opr(OperatorNodeBase* opr,
thin_function<void(void)> func,
OprPropertyFlag opr_property_flag) {
auto src_opr = cg::get_opr_root_source_opr(opr);
auto opr_priority = opr->node_prop().attribute().priority;
auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0));
ThinHashMap<OperatorNodeBase*, OprPropertyFlag> oprs_inserted;

auto swap_properties = [&,
need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
if (need_src_opr) {
std::swap(m_cur_iter_src_opr, src_opr);
}
if (need_priority) {
std::swap(m_cur_iter_opr_priority, opr_priority);
}
std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
std::swap(m_opr_property_flag, opr_property_flag);
std::swap(m_oprs_inserted, oprs_inserted);
};
auto swap_properties =
[&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR,
need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] {
if (need_src_opr) {
std::swap(m_cur_iter_src_opr, src_opr);
}
if (need_priority) {
std::swap(m_cur_iter_opr_priority, opr_priority);
}
std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type);
std::swap(m_opr_property_flag, opr_property_flag);
std::swap(m_oprs_inserted, oprs_inserted);
};
MGB_TRY {
swap_properties();
func();
} MGB_FINALLY({
swap_properties();
});
}
MGB_FINALLY({ swap_properties(); });
}

/* ================ RecursiveSubGraphRewriteHelper ================ */
RecursiveSubGraphRewriteHelper::
~RecursiveSubGraphRewriteHelper() noexcept = default;
RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept =
default;

RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState &state):
m_opt_state{state}, m_rewriter{state.graph().make_rewriter()}
{
}
RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state)
: m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {}

void RecursiveSubGraphRewriteHelper::apply() {
using namespace std::placeholders;
@@ -479,8 +472,8 @@ void RecursiveSubGraphRewriteHelper::apply() {
m_rewriter.apply_inplace();
}

void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {
auto on_new_opr = [this](OperatorNodeBase *opr) {
void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) {
auto on_new_opr = [this](OperatorNodeBase* opr) {
auto repl_opr = m_rewriter.auto_replace_outputs(opr);
return on_new_opr_check_should_process(opr, repl_opr);
};
@@ -493,8 +486,8 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {
return;

mgb_assert(m_opr_stack.empty());
m_opr_stack.push_back({
orig_out, m_rewriter.get_var(orig_out)->owner_opr()});
m_opr_stack.push_back(
{orig_out, m_rewriter.get_var(orig_out)->owner_opr()});

bool first = true;
while (!m_opr_stack.empty()) {
@@ -515,9 +508,9 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {
if (should_process) {
auto trans = process_opr(cur_out);
if (trans.valid()) {
m_opr_stack.push_back({
cur_frame.orig_var, trans->result->owner_opr()});
for (auto i: reverse_adaptor(trans->internal)) {
m_opr_stack.push_back(
{cur_frame.orig_var, trans->result->owner_opr()});
for (auto i : reverse_adaptor(trans->internal)) {
if (i)
m_opr_stack.push_back({i, i->owner_opr()});
}
@@ -532,7 +525,7 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {

auto src = cur_frame.orig_var;
if (m_rewriter.get_var(src) != cur_out) {
const char *msg = nullptr;
const char* msg = nullptr;
if (m_opr_stack.empty()) {
msg = m_log_msg.c_str();
}
@@ -550,11 +543,12 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) {

GraphOptimizer::~GraphOptimizer() noexcept = default;

class GraphOptimizer::VarReplaceMapStorage :public UserDataContainer::UserData {
class GraphOptimizer::VarReplaceMapStorage
: public UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;

public:
ThinHashMap<VarNode*, VarNode*> map;
public:
ThinHashMap<VarNode*, VarNode*> map;
};
MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage);

@@ -565,7 +559,7 @@ GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr<Pass> pass) {
return *this;
}

SubGraph GraphOptimizer::apply(const SubGraph &graph) const {
SubGraph GraphOptimizer::apply(const SubGraph& graph) const {
RealTimer timer;
OptState state{this, graph};

@@ -574,38 +568,38 @@ SubGraph GraphOptimizer::apply(const SubGraph &graph) const {
// first update output var shapes of all oprs
state.graph().iter(cg::update_output_var_shapes);

auto &&opt = graph.comp_graph()->options();
auto&& opt = graph.comp_graph()->options();
auto orig_setting = opt.graph_opt_level;
Pass *cur_pass = nullptr;
Pass* cur_pass = nullptr;
MGB_MARK_USED_VAR(cur_pass);
MGB_TRY {
for (auto &&i: m_passes) {
for (auto&& i : m_passes) {
state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL);
cur_pass = i.get();
opt.graph_opt_level = 1;
i->apply(state);
tot_nr_replace += state.flush_log(
mgb_ssprintf_log(
"apply optimization pass %s:", i->name()).c_str());
mgb_ssprintf_log("apply optimization pass %s:", i->name())
.c_str());
}
} MGB_CATCH(std::exception &exc, {
}
MGB_CATCH(std::exception & exc, {
mgb_log_error("error while applying optimization pass %s: %s",
cur_pass->name(), exc.what());
cur_pass->name(), exc.what());
opt.graph_opt_level = orig_setting;
throw;
})
MGB_FINALLY(
opt.graph_opt_level = orig_setting
);
MGB_FINALLY(opt.graph_opt_level = orig_setting);
if (verbosity() >= 1) {
mgb_log_debug("graph optimization: applied %zu passes, "
mgb_log_debug(
"graph optimization: applied %zu passes, "
"total %zu var(s) replaced; time=%.2fms",
m_passes.size(), tot_nr_replace, timer.get_msecs());
}
return state.graph();
}

const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const {
const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const {
if (m_passes.empty()) {
// this check is necessary, since OptState would clear
// var_replace_map()
@@ -613,7 +607,7 @@ const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const {
}

auto g = apply({{vars.begin(), vars.end()}});
for (size_t i = 0; i < vars.size(); ++ i) {
for (size_t i = 0; i < vars.size(); ++i) {
vars[i] = g.endpoint_vars()[i].node();
}
return *this;
@@ -653,7 +647,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
#if MGB_JIT
bool need_jit = false;
if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
comp_graph_opt->graph_opt.jit)) {
comp_graph_opt->graph_opt.jit)) {
need_jit = true;
}
if (need_jit && after_grad) {
@@ -679,7 +673,6 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
add_passes_for_optimize_options(*inference_opt);
}


if (inference_opt) {
// merge params to reduce loading time and graph overhead
add_pass<ParamMergePass>();
@@ -689,15 +682,16 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
}

const ThinHashMap<VarNode*, VarNode*>& GraphOptimizer::var_replace_map(
ComputingGraph &graph) {
auto storage = graph.options().user_data.get_user_data_or_create<
VarReplaceMapStorage>();
ComputingGraph& graph) {
auto storage =
graph.options()
.user_data.get_user_data_or_create<VarReplaceMapStorage>();
return storage->map;
}

VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
auto &&map = var_replace_map(*(var->owner_graph()));
for (; ; ) {
VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) {
auto&& map = var_replace_map(*(var->owner_graph()));
for (;;) {
auto iter = map.find(var);
if (iter == map.end())
return var;
@@ -705,7 +699,6 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
}
}


const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options) {
return add_passes_for_optimize_options(
@@ -723,12 +716,14 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
options.disable_##_option(); \
} \
}
cb(fuse_preprocess, {add_pass(FuseNCHW4Int8Preprocess::make());});

cb(fuse_preprocess, {
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>();
});
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });


cb(nchw4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
@@ -763,6 +758,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>();
});
cb(chwn4, {
add_pass<FuseConvBiasNonlinPass>();
@@ -790,9 +786,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
/* ================ ConstVarPropogateBase ================ */

ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
OperatorNodeBase *opr) {
OperatorNodeBase* opr) {
using ProfFlag = OperatorNodeBase::NodeProp::Flag;
auto &&info = m_oprinfo[opr];
auto&& info = m_oprinfo[opr];
if (info.processed)
return info.result;
info.processed = true;
@@ -819,15 +815,14 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
if (opr->input().empty())
return make_ret();

if (opr->node_prop().contain(
ProfFlag::FORCE_UPDATE_INPUT_VAR |
ProfFlag::IMPURE_FUNC)) {
if (opr->node_prop().contain(ProfFlag::FORCE_UPDATE_INPUT_VAR |
ProfFlag::IMPURE_FUNC)) {
return make_ret();
}

size_t max_input_size = 0;
ret.all_const_inp = true;
for (auto i: opr->input()) {
for (auto i : opr->input()) {
auto io = i->owner_opr();
auto iter = m_oprinfo.find(io);
if (iter == m_oprinfo.end()) {
@@ -835,7 +830,7 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr(
iter = m_oprinfo.find(io);
mgb_assert(iter != m_oprinfo.end());
}
auto &&src = iter->second;
auto&& src = iter->second;
if (src.is_const) {
update_max(max_input_size, src.max_size);
ret.has_const_inp = true;
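
With the fuse_preprocess and nchw4 hooks above now registering
FuseWarpPerspectiveDimshufflePass, the fusion is reached through the normal
inference-optimization entry point. A minimal sketch, mirroring the new unit
test further below (the graph endpoint y and the test helper unpack_vector are
assumed to be available):

    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_fuse_preprocess();  // adds FuseNCHW4Int8Preprocess followed
                                       // by FuseWarpPerspectiveDimshufflePass
    SymbolVar y_opt;
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
    // y_opt now ends in a single fused WarpPerspective when the patterns match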


+ 241
- 0
src/gopt/impl/fuse_nchw4_int8_preprocess.cpp View File

@@ -19,6 +19,7 @@
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/opr/imgproc.h"

using namespace mgb;
using namespace gopt;
@@ -443,4 +444,244 @@ void FuseNCHW4Int8Preprocess::apply(OptState& state) const {
};
state.graph().iter(on_opr);
rewriter.apply_inplace();
}

/* ==================== FuseWarpPerspectiveDimshufflePass ================= */
const char* FuseWarpPerspectiveDimshufflePass::name() const {
return mgb_cstr_log("Fuse warp perspective dimshuffle pass");
}

void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
auto rewriter = opt.graph().make_rewriter();
auto uniq_reader_check = UniqReaderCheck{opt.graph()};

auto make_new_warp = [&rewriter](opr::WarpPerspective* warp,
opr::WarpPerspective::Param new_param,
megdnn::DType dst_dtype,
SymbolVar& new_warp) {
OperatorNodeConfig new_config(dst_dtype);
if (warp->input().size() == 3) {
auto src = rewriter.get_var(warp->input(0)),
mat = rewriter.get_var(warp->input(1)),
out_shape = rewriter.get_var(warp->input(2));
new_warp = opr::WarpPerspective::make(src, mat, out_shape,
new_param, new_config);
} else {
mgb_assert(warp->input().size() == 4);
auto src = rewriter.get_var(warp->input(0)),
mat = rewriter.get_var(warp->input(1)),
mat_idx = rewriter.get_var(warp->input(2)),
out_shape = rewriter.get_var(warp->input(3));
new_warp = opr::WarpPerspective::make(src, mat, mat_idx, out_shape,
new_param, new_config);
}
};

auto is_warp_nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr,
OperatorNodeBase*& top_opr) {
// check warp
auto warp = try_cast_as_op<opr::WarpPerspective>(bottom_opr);
if (warp == nullptr)
return false;
auto inp_dtype = warp->input(0)->dtype();
bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
inp_dtype.enumv() == DTypeEnum::Uint8;

bool is_nchw = warp->param().format ==
megdnn::param::WarpPerspective::Format::NCHW;
if (!(is_u8_or_qu8 && is_nchw))
return false;
if (!uniq_reader_check(warp->input(0)))
return false;

top_opr = warp;
return true;
};

auto is_warp_nhwc2nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr,
OperatorNodeBase*& top_opr) {
// check shuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(bottom_opr);
if (shuffle == nullptr)
return false;
auto&& shuffle_param = shuffle->param();
if (shuffle_param.pattern_len != 4)
return false;
bool is_nhwc2nchw = shuffle_param.pattern[0] == 0 &&
shuffle_param.pattern[1] == 3 &&
shuffle_param.pattern[2] == 1 &&
shuffle_param.pattern[3] == 2;
if (!is_nhwc2nchw)
return false;
if (!uniq_reader_check(shuffle->input(0)))
return false;

// check warp
auto warp = try_cast_as_op<opr::WarpPerspective>(
shuffle->input(0)->owner_opr());
if (warp == nullptr)
return false;
auto inp_dtype = warp->input(0)->dtype();
bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
inp_dtype.enumv() == DTypeEnum::Uint8;
bool is_nhwc = warp->param().format ==
megdnn::param::WarpPerspective::Format::NHWC;
if (!(is_u8_or_qu8 && is_nhwc))
return false;

top_opr = warp;
return true;
};

auto try_warp_nchw_typecvt = [&rewriter, &uniq_reader_check, &is_warp_nchw,
&make_new_warp](OperatorNodeBase* opr) {
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
bool is_to_f32 =
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!is_to_f32)
return false;
if (!uniq_reader_check(typecvt->input(0)))
return false;

OperatorNodeBase* top_opr = nullptr;
if (!is_warp_nchw(typecvt->input(0)->owner_opr(), top_opr))
return false;
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
SymbolVar new_warp;
make_new_warp(warp, warp->param(), opr->output()[0]->dtype(), new_warp);

rewriter.replace_var(opr->output(0), new_warp.node(),
mgb_cstr_log("replace warp + typecvt"
"fuse warp_dimshuffle(NCHW)"));

return true;
};

auto try_warp_nhwc2nchw_typecvt = [&rewriter, &uniq_reader_check,
&is_warp_nhwc2nchw,
&make_new_warp](OperatorNodeBase* opr) {
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
bool is_to_f32 =
typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (!is_to_f32)
return false;
if (!uniq_reader_check(typecvt->input(0)))
return false;

OperatorNodeBase* top_opr = nullptr;
if (!is_warp_nhwc2nchw(typecvt->input(0)->owner_opr(), top_opr))
return false;
auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
opr::WarpPerspective::Param new_param = warp->param();
new_param.format = megdnn::param::WarpPerspective::Format::NHWC_NCHW;
SymbolVar new_warp;
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

rewriter.replace_var(
opr->output(0), new_warp.node(),
mgb_cstr_log("replace conv_bias + dimshuffle + "
"typecvt to warp_dimshuffle(NHWC_NCHW)"));

return true;
};

auto try_warp_nhwc2nchw4_typecvt = [&rewriter, &uniq_reader_check,
&is_warp_nhwc2nchw,
&make_new_warp](OperatorNodeBase* opr) {
// check relayout
auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
if (relayout == nullptr)
return false;
bool is_to_q8 =
relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
bool is_to_nchw2nchw4 = relayout->param().mode ==
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
if (!(is_to_q8 && is_to_nchw2nchw4))
return false;
if (!uniq_reader_check(relayout->input(0)))
return false;

OperatorNodeBase* top_opr = nullptr;
if (!is_warp_nhwc2nchw(relayout->input(0)->owner_opr(), top_opr))
return false;

auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);

bool is_small_chn = warp->input(0)->shape()[3] < 4;
if (!is_small_chn)
return false;

opr::WarpPerspective::Param new_param = warp->param();
new_param.format =
megdnn::param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL;

SymbolVar new_warp;
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

rewriter.replace_var(
opr->output(0), new_warp.node(),
mgb_cstr_log("replace warp + dimshuffle + relayout(NCHW_NCHW4)"
"to warp_dimshuffle(NHWC_NCHW4_IC_SMALL)"));

return true;
};

auto try_warp_nchw2nchw4_typecvt = [&rewriter, &uniq_reader_check,
&is_warp_nchw,
&make_new_warp](OperatorNodeBase* opr) {
// check relayout
auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
if (relayout == nullptr)
return false;
bool is_to_q8 =
relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
bool is_to_nchw2nchw4 = relayout->param().mode ==
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
if (!(is_to_q8 && is_to_nchw2nchw4))
return false;
if (!uniq_reader_check(relayout->input(0)))
return false;

OperatorNodeBase* top_opr = nullptr;
if (!is_warp_nchw(relayout->input(0)->owner_opr(), top_opr))
return false;

auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);

bool is_small_chn = warp->input(0)->shape()[1] < 4;
if (!is_small_chn)
return false;

opr::WarpPerspective::Param new_param = warp->param();
new_param.format =
megdnn::param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL;

SymbolVar new_warp;
make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);

rewriter.replace_var(
opr->output(0), new_warp.node(),
mgb_cstr_log("replace warp + relayout(NCHW_NCHW4)"
"to warp_dimshuffle(NCHW_NCHW4_IC_SMALL)"));

return true;
};

auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt,
&try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt,
&rewriter](OperatorNodeBase* opr) {
if (!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr) &&
!try_warp_nhwc2nchw4_typecvt(opr) &&
!try_warp_nchw2nchw4_typecvt(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
}
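
Each matcher above walks bottom-up from a TypeCvt or RelayoutFormat, through an
optional Dimshuffle, to a uint8/quint8 WarpPerspective, and rewrites the whole
chain into a single WarpPerspective with a fused format. As an illustration,
the subgraph recognized by try_warp_nhwc2nchw_typecvt can be built like this
(a sketch only; x, mat, h and w are assumed as in the unit test below):

    opr::WarpPerspective::Param wparam;
    wparam.format = opr::WarpPerspective::Param::Format::NHWC;     // uint8 NHWC input
    auto warped = opr::WarpPerspective::make(x, mat, TensorShape{h, w}, wparam);
    auto nchw = opr::Dimshuffle::make(warped, {0, 3, 1, 2}, 4);    // NHWC -> NCHW
    auto f32 = opr::TypeCvt::make(nchw, dtype::Float32());
    // after the pass: one WarpPerspective with param().format == NHWC_NCHW and
    // a float32 output replaces all three oprs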

+ 10
- 0
src/gopt/include/megbrain/gopt/inference.h View File

@@ -173,6 +173,16 @@ namespace gopt {
};

/*!
 * \brief fuse warp perspective with a following dimshuffle and/or
 * typecvt/relayout, converting quint8/uint8 input to qint8/float output
*/
class FuseWarpPerspectiveDimshufflePass : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};


/*!
* \brief fuse deconv and typecvt to a deconv opr
*/
class FuseDeconvCvtPass : public Pass {
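
The new pass is registered automatically by the fuse_preprocess and nchw4
presets in framework.cpp; it can also be attached to a GraphOptimizer by hand.
A small sketch, assuming vars is the VarNodeArray of endpoint vars of the
graph to optimize:

    gopt::GraphOptimizer optimizer;
    optimizer.add_pass(gopt::FuseNCHW4Int8Preprocess::make());  // run first, as in the presets
    optimizer.add_pass<gopt::FuseWarpPerspectiveDimshufflePass>();
    optimizer.apply_inplace(vars);  // rewrites the endpoint vars in place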


+ 62
- 1
src/gopt/test/inference.cpp View File

@@ -1172,7 +1172,8 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
param.pad_h = param.pad_w = 1;
auto w2 = mkcvar("w2", {4, 4, 3, 3}),
y = opr::Convolution::make(elem, w2, param),
z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)});
z = opr::AxisAddRemove::make(
y, {opr::AxisAddRemove::AxisDesc::make_add(0)});

SymbolVar y_opt, z_opt;
auto options = gopt::OptimizeForInferenceOptions{};
@@ -3722,5 +3723,65 @@ TEST(TestGoptInference, PreProcessCase1) {

ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::RelayoutFormat>());
}

TEST(TestGoptInference, WarpAndPreProcessCase) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;

size_t n = 1;
size_t c = 3;
size_t h = 16;
size_t w = 16;
auto host_x1 = gen({n, h, w, c}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);

auto mat_host = std::make_shared<HostTensorND>(cn, TensorShape{n, 3, 3},
dtype::Float32());
warp_perspective_mat_gen(*mat_host, n, h, w);
auto mat = opr::Host2DeviceCopy::make(*graph, mat_host).rename("mat");

opr::WarpPerspective::Param warp_param;
warp_param.format = opr::WarpPerspective::Param::Format::NHWC;
auto x_warp =
opr::WarpPerspective::make(x, mat, TensorShape{h, w}, warp_param);
auto x_nchw = opr::Dimshuffle::make(x_warp, {0, 3, 1, 2}, 4, cn);

auto x_u8 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn);
auto x_s8 = x_u8 - 128;
auto zero = DTypeScalar(dtype::Float32());
auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn);
auto pad_channel_tensor =
opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn);
auto paded_x = opr::Concat::make({x_s8, pad_channel_tensor}, 1, cn)
.reshape({n, 1, 4, h, w});

auto nchw4_out = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn);
auto result = opr::TypeCvt::make(nchw4_out, dtype::QuantizedS8(1.f));

auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_preprocess();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

ASSERT_TRUE(y_opt.node()->owner_opr()->same_type<opr::WarpPerspective>());

ASSERT_EQ(opr::WarpPerspective::Param::Format::NHWC_NCHW4_IC_SMALL,
find_opr<opr::WarpPerspective>(y_opt).param().format);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.WarpAndPreProcessCase.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
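
Note that the padding tail built in this test (TypeCvt to float32, subtract
128, Concat with a zero channel, reshape and Dimshuffle to NCHW4, TypeCvt to
QuantizedS8) is first collapsed into a RelayoutFormat(NCHW_NCHW4) by
FuseNCHW4Int8Preprocess (compare PreProcessCase1 above); only then does
FuseWarpPerspectiveDimshufflePass fold that relayout together with the NHWC
warp and the dimshuffle, which is why the assertion expects
NHWC_NCHW4_IC_SMALL. Schematically (a sketch of the expected rewrite stages,
not output of the test):

    // before : warp(NHWC, u8) -> dimshuffle(0,3,1,2) -> typecvt(f32) -> ...pad to 4 channels... -> qint8 NCHW4
    // stage 1: warp(NHWC, u8) -> dimshuffle(0,3,1,2) -> RelayoutFormat(NCHW_NCHW4) -> qint8 NCHW4
    // stage 2: warp(NHWC_NCHW4_IC_SMALL): u8 NHWC in, qint8 NCHW4 out -- the single opr asserted above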

+ 39
- 18
src/opr/impl/imgproc.cpp View File

@@ -47,7 +47,11 @@ SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,
}

void WarpPerspectiveForward::init_output_dtype() {
output(0)->dtype(input(0)->dtype());
if (config().output_dtype().valid()) {
output(0)->dtype(config().output_dtype());
} else {
output(0)->dtype(input(0)->dtype());
}
}

void WarpPerspectiveForward::add_input_layout_constraint() {
@@ -78,23 +82,40 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape(
mat_idx_shp.to_string().c_str());
}

//! The index of height, e.g.,[b, h, w, c], the height_idx = 1
size_t height_idx = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4) {
height_idx = 2;
} else {
height_idx = 1;
}

dest = imgshp;
dest[0] = matshp[0];
if (param().format == Param::Format::NHWCD4) {
dest.shape[height_idx] = oshp2d.shape[0];
dest.shape[height_idx + 2] = oshp2d.shape[1];
} else {
for (int i = 0; i < 2; ++i)
dest.shape[height_idx + i] = oshp2d.shape[i];
switch (param().format) {
case Param::Format::NCHW_NCHW4_IC_SMALL:
case Param::Format::NHWC_NCHW4_IC_SMALL:
dest.ndim = 5;
dest[0] = matshp[0];
dest.shape[1] = 1;
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1];
dest.shape[4] = 4;
break;
case Param::Format::NHWC_NCHW:
dest[0] = matshp[0];
dest.shape[1] = imgshp.shape[3];
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1];
break;
default:
size_t height_idx = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4) {
height_idx = 2;
} else {
height_idx = 1;
}
dest = imgshp;
dest[0] = matshp[0];
if (param().format == Param::Format::NHWCD4) {
dest.shape[height_idx] = oshp2d.shape[0];
dest.shape[height_idx + 2] = oshp2d.shape[1];
} else {
for (int i = 0; i < 2; ++i)
dest.shape[height_idx + i] = oshp2d.shape[i];
}
break;
}
}
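
For concreteness, the new output-shape branches map shapes as follows (values
chosen to match the test case above):

    // imgshp = {1, 16, 16, 3} (NHWC, C < 4), matshp = {1, 3, 3}, oshp2d = {16, 16}
    //   Format::NHWC_NCHW           -> dest = {1, 3, 16, 16}     (plain NCHW)
    //   Format::NHWC_NCHW4_IC_SMALL -> dest = {1, 1, 16, 16, 4}  (NCHW4, channels padded to 4)
    // imgshp = {1, 3, 16, 16} (NCHW, C < 4), same mat and output size:
    //   Format::NCHW_NCHW4_IC_SMALL -> dest = {1, 1, 16, 16, 4}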


