diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 85eb3ca6..63540ed9 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -43,6 +43,12 @@ pdef('Axis').add_fields('int32', 'axis', 0) Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), + Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' + 'output tensor is nchw layout'), + Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' + 'output tensor is nchw4 layout, padding c=4'), + Doc('NCHW_NCHW4_IC_SMALL', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, ' + 'output tensor is nchw4 layout, padding c=4'), Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) ) diff --git a/dnn/src/common/warp_perspective.cpp b/dnn/src/common/warp_perspective.cpp index a0d558cb..5c96930f 100644 --- a/dnn/src/common/warp_perspective.cpp +++ b/dnn/src/common/warp_perspective.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megdnn/oprs.h" @@ -14,20 +15,17 @@ namespace megdnn { -void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &dst) -{ +void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& mat_idx, + const TensorLayout& dst) { megdnn_assert_contiguous(mat); megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " + - megdnn_layout_msg(mat) + ", " + - megdnn_layout_msg(mat_idx) + ", " + - megdnn_layout_msg(dst) + ", " + - param_msg(); + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " + + megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) + + ", " + param_msg(); }; MEGDNN_MARK_USED_VAR(errmsg); if (param().format == param::WarpPerspective::Format::NHWCD4 || @@ -35,9 +33,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); + } else if (param().format == + param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || + param().format == + param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { + megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); + megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); } else { megdnn_assert(param().format == param::WarpPerspective::Format::NHWC || - param().format == param::WarpPerspective::Format::NCHW); + param().format == param::WarpPerspective::Format::NCHW || + param().format == + param::WarpPerspective::Format::NHWC_NCHW); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str()); } @@ -45,7 +51,7 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str()); if (mat_idx.ndim) { megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, - "%s", errmsg().c_str()); + "%s", errmsg().c_str()); megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str()); megdnn_assert_contiguous(mat_idx); } else { @@ -54,35 +60,103 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str()); megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str()); - if (param().format == param::WarpPerspective::Format::NCHW) { - megdnn_assert( - src.dtype.enumv() == DTypeEnum::Float32 || - MEGDNN_FLOAT16_SELECT( - (src.dtype.enumv() == DTypeEnum::Float16 || - src.dtype.enumv() == DTypeEnum::BFloat16), - false) || - src.dtype.enumv() == DTypeEnum::Int8 || - src.dtype.enumv() == DTypeEnum::Uint8 || - (src.dtype.enumv() == DTypeEnum::QuantizedS8 || - src.dtype.enumv() == DTypeEnum::Quantized8Asymm), - "WarpPerspective NCHW input dtype should be " - "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT( - "/Float16/BFloat16", "") "."); - megdnn_assert( - (src.dtype.category() == DTypeCategory::FLOAT && - (src.dtype == mat.dtype || - mat.dtype.enumv() == DTypeEnum::Float32)) || - ((src.dtype.category() == DTypeCategory::INT || - src.dtype.category() == DTypeCategory::QUANTIZED) && - mat.dtype.enumv() == DTypeEnum::Float32), - "The input to WarpPerspective is in NCHW format, in this " - "case, if the input dtype is floating point, the " - "transformation matrix should have same dtype as the " - "input, otherwise, it should be in Float32, %s given.", - mat.dtype.name()); + if (src.format == dst.format && dst.dtype == src.dtype) { + if (param().format == param::WarpPerspective::Format::NCHW) { + megdnn_assert( + src.dtype.enumv() == DTypeEnum::Float32 || + MEGDNN_FLOAT16_SELECT( + (src.dtype.enumv() == DTypeEnum::Float16 || + src.dtype.enumv() == DTypeEnum::BFloat16), + false) || + src.dtype.enumv() == DTypeEnum::Int8 || + src.dtype.enumv() == DTypeEnum::Uint8 || + (src.dtype.enumv() == DTypeEnum::QuantizedS8 || + src.dtype.enumv() == DTypeEnum::Quantized8Asymm), + "WarpPerspective NCHW input dtype should be " + "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT( + "/Float16/BFloat16", "") "."); + megdnn_assert( + (src.dtype.category() == DTypeCategory::FLOAT && + (src.dtype == mat.dtype || + mat.dtype.enumv() == DTypeEnum::Float32)) || + ((src.dtype.category() == DTypeCategory::INT || + src.dtype.category() == + DTypeCategory::QUANTIZED) && + mat.dtype.enumv() == DTypeEnum::Float32), + "The input to WarpPerspective is in NCHW format, in this " + "case, if the input dtype is floating point, the " + "transformation matrix should have same dtype as the " + "input, otherwise, it should be in Float32, %s given.", + mat.dtype.name()); + + megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(dst.dtype == src.dtype); - megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); + megdnn_assert(param().imode == + param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::ISOLATED); + } else if (param().format == param::WarpPerspective::Format::NHWC) { + megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); + } else if (param().format == param::WarpPerspective::Format::NCHW4) { + megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8, + "src expected QuantizedS8, but got %s", + src.dtype.name()); + megdnn_assert(mat.dtype == dtype::Float32(), + "matrix dtype expected float, got %s", + mat.dtype.name()); + megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4); + megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); + + megdnn_assert(param().imode == + param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::ISOLATED); + } else { + megdnn_assert(param().format == + param::WarpPerspective::Format::NHWCD4); + megdnn_assert( + src.dtype == dtype::Float32() || + MEGDNN_FLOAT16_SELECT( + (src.dtype == dtype::Float16() || + src.dtype == dtype::BFloat16()), + false) || + src.dtype.enumv() == DTypeEnum::QuantizedS8 || + src.dtype.enumv() == DTypeEnum::Quantized8Asymm, + "WarpPerspective NHWCD4 input dtype should be " + "Float32" MEGDNN_FLOAT16_SELECT( + "/Float16/BFloat16", + "") ",QunatizedS8, Quantized8Asymm."); + megdnn_assert( + (src.dtype == mat.dtype || mat.dtype == dtype::Float32()), + "The input to WarpPerspective is in NHWCD4 format, in this " + "case, if the input dtype is floating point, the " + "transformation matrix should have same dtype as the " + "input, %s given.", + mat.dtype.name()); + //! number of channels is same + megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); + megdnn_assert(param().imode == + param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert(param().bmode != + param::WarpPerspective::BorderMode::ISOLATED); + } + } else if (param().format == + param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || + param().format == + param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { + megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8), + "src expected Quantized8Asymm or Uint8, but got %s", + src.dtype.name()); + megdnn_assert(mat.dtype == dtype::Float32(), + "matrix dtype expected float, got %s", mat.dtype.name()); + megdnn_assert(dst.shape[4] == 4); megdnn_assert(param().imode == param::WarpPerspective::InterpolationMode::LINEAR); @@ -90,16 +164,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, param::WarpPerspective::BorderMode::TRANSPARENT); megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED); - } else if (param().format == param::WarpPerspective::Format::NHWC) { - megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); - } else if (param().format == param::WarpPerspective::Format::NCHW4) { - megdnn_assert(dst.dtype == src.dtype); - megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8, - "src expected QuantizedS8, but got %s", src.dtype.name()); + } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) { + megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8), + "src expected Quantized8Asymm or Uint8, but got %s", + src.dtype.name()); megdnn_assert(mat.dtype == dtype::Float32(), "matrix dtype expected float, got %s", mat.dtype.name()); - megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4); - megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); + megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str()); megdnn_assert(param().imode == param::WarpPerspective::InterpolationMode::LINEAR); @@ -108,40 +180,14 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else { - megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4); - megdnn_assert( - src.dtype == dtype::Float32() || - MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() || - src.dtype == dtype::BFloat16()), - false) || - src.dtype.enumv() == DTypeEnum::QuantizedS8 || - src.dtype.enumv() == DTypeEnum::Quantized8Asymm, - "WarpPerspective NHWCD4 input dtype should be " - "Float32" MEGDNN_FLOAT16_SELECT( - "/Float16/BFloat16", - "") ",QunatizedS8, Quantized8Asymm."); - megdnn_assert( - (src.dtype == mat.dtype || mat.dtype == dtype::Float32()), - "The input to WarpPerspective is in NHWCD4 format, in this " - "case, if the input dtype is floating point, the " - "transformation matrix should have same dtype as the " - "input, %s given.", - mat.dtype.name()); - megdnn_assert(dst.dtype == src.dtype); - //! number of channels is same - megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert(param().format == param::WarpPerspective::Format::NCHW); + megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8) && + dst.dtype.enumv() == DTypeEnum::Float32); } - megdnn_assert(src.format == dst.format); } -std::string WarpPerspectiveBase::param_msg() const -{ +std::string WarpPerspectiveBase::param_msg() const { std::string res; res.append(megdnn_mangle("imode=")); switch (param().imode) { @@ -191,31 +237,25 @@ std::string WarpPerspectiveBase::param_msg() const return res; } -int WarpPerspectiveBase::get_real_coord(int p, int len) -{ +int WarpPerspectiveBase::get_real_coord(int p, int len) { auto bmode = param().bmode; - if( (unsigned)p < (unsigned)len ) + if ((unsigned)p < (unsigned)len) ; - else if( bmode == BorderMode::REPLICATE ) + else if (bmode == BorderMode::REPLICATE) p = p < 0 ? 0 : len - 1; - else if( bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101 ) - { + else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) { int delta = (bmode == BorderMode::REFLECT_101); - if( len == 1 ) + if (len == 1) return 0; - do - { - if( p < 0 ) + do { + if (p < 0) p = -p - 1 + delta; else p = len - 1 - (p - len) - delta; - } - while( (unsigned)p >= (unsigned)len ); - } - else if( bmode == BorderMode::WRAP ) - { - if( p < 0 ) - p -= ((p-len+1)/len)*len; + } while ((unsigned)p >= (unsigned)len); + } else if (bmode == BorderMode::WRAP) { + if (p < 0) + p -= ((p - len + 1) / len) * len; /* if( p >= len ) p %= len; @@ -223,18 +263,16 @@ int WarpPerspectiveBase::get_real_coord(int p, int len) while (p >= len) { p -= len; } - } - else if( bmode == BorderMode::CONSTANT ) + } else if (bmode == BorderMode::CONSTANT) p = -1; return p; } -void WarpPerspectiveForward::check_exec(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void WarpPerspectiveForward::check_exec(const TensorLayout& src, + const TensorLayout& mat, + const TensorLayout& mat_idx, + const TensorLayout& dst, + size_t workspace_in_bytes) { check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes); } @@ -248,7 +286,10 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); if (param().format != Param::Format::NHWC && param().format != Param::Format::NCHW && - param().format != Param::Format::NCHW4) { + param().format != Param::Format::NCHW4 && + param().format != Param::Format::NHWC_NCHW && + param().format != Param::Format::NHWC_NCHW4_IC_SMALL && + param().format != Param::Format::NCHW_NCHW4_IC_SMALL) { megdnn_assert(!mat_idx.ndim, "mat_idx not supported for current format"); } @@ -263,7 +304,8 @@ void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat, megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( || grad.dtype == dtype::BFloat16()), "Backward WarpPerspective only supports Float32/BFloat16."); - auto required_workspace_in_bytes = get_workspace_in_bytes(mat, mat_idx, diff, grad); + auto required_workspace_in_bytes = + get_workspace_in_bytes(mat, mat_idx, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } @@ -283,6 +325,6 @@ void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/warp_perspective/common.h b/dnn/src/cuda/warp_perspective/common.h index 1687238c..982d39f2 100644 --- a/dnn/src/cuda/warp_perspective/common.h +++ b/dnn/src/cuda/warp_perspective/common.h @@ -12,6 +12,7 @@ #pragma once #include #include "src/common/cv/enums.h" +#include "src/cuda/utils.cuh" #include "megcore_cdefs.h" namespace megdnn { @@ -34,6 +35,22 @@ void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx, megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream); +template +void forward_proxy_quint8_dimshuffle_typecvt_nchw4( + bool is_nhwc, const src_ctype* src, const float* mat, + const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, + int IW, int OH, int OW, src_ctype bval, DTypeParamImpl param, + BorderMode bmode, megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream); + +template +void forward_proxy_quint8_dimshuffle_typecvt_nchw( + bool is_nhwc, const src_ctype* src, const float* mat, + const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, + int IW, int OH, int OW, src_ctype bval, DTypeParamImpl param, + BorderMode bmode, megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream); + void backward_data_proxy(const float* mat, const int* midx, const float* diff, float* grad, float* workspace, int N, int N_SRC, int C, int IH, int IW, int OH, int OW, float bval, diff --git a/dnn/src/cuda/warp_perspective/forward.cpp b/dnn/src/cuda/warp_perspective/forward.cpp index 52823132..0f76cfe6 100644 --- a/dnn/src/cuda/warp_perspective/forward.cpp +++ b/dnn/src/cuda/warp_perspective/forward.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/cuda/warp_perspective/opr_impl.h" #include "src/cuda/warp_perspective/warp_perspective_cv.cuh" @@ -166,6 +167,30 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, IW = src.layout.shape[3]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; + } else if (param().format == Param::Format::NHWC_NCHW) { + C = src.layout.shape[3]; + IH = src.layout.shape[1]; + IW = src.layout.shape[2]; + OH = dst.layout.shape[2]; + OW = dst.layout.shape[3]; + } else if (param().format == Param::Format::NHWC_NCHW4_IC_SMALL) { + C = src.layout.shape[3]; + IH = src.layout.shape[1]; + IW = src.layout.shape[2]; + OH = dst.layout.shape[2]; + OW = dst.layout.shape[3]; + megdnn_assert( + (C == 1) || (C == 3), + "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3"); + } else if (param().format == Param::Format::NCHW_NCHW4_IC_SMALL) { + C = src.layout.shape[1]; + IH = src.layout.shape[2]; + IW = src.layout.shape[3]; + OH = dst.layout.shape[2]; + OW = dst.layout.shape[3]; + megdnn_assert( + (C == 1) || (C == 3), + "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3"); } else { megdnn_assert( param().format == param::WarpPerspective::Format::NCHW, @@ -180,55 +205,123 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, "unsupported interpolation mode for NCHW format"); auto bval = param().border_val; auto bmode = warp_perspective::get_bmode(param().bmode); - - if (src.layout.dtype == dtype::Float32{}) { - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, - IH, IW, OH, OW, bval, bmode, async_error_info(handle()), - m_error_tracker, stream); - } else if (MEGDNN_FLOAT16_SELECT( - src.layout.dtype == dtype::Float16(), false)) { + if (src.layout.dtype == dst.layout.dtype) { + if (src.layout.dtype == dtype::Float32{}) { + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), + mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], + C, IH, IW, OH, OW, bval, bmode, + async_error_info(handle()), m_error_tracker, + stream); + } else if (MEGDNN_FLOAT16_SELECT( + src.layout.dtype == dtype::Float16(), + false)) { #ifndef MEGDNN_DISABLE_FLOAT16 - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, - IH, IW, OH, OW, static_cast(bval), bmode, - async_error_info(handle()), m_error_tracker, stream); + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), + mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], + C, IH, IW, OH, OW, static_cast(bval), + bmode, async_error_info(handle()), m_error_tracker, + stream); #endif - } else if (src.layout.dtype == dtype::Uint8()) { - warp_perspective::forward_proxy( - is_nhwc, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, - IH, IW, OH, OW, bval, bmode, async_error_info(handle()), - m_error_tracker, stream); - } else if (src.layout.dtype == dtype::Int8()) { - megdnn_assert( - !is_nhwc, - "WarpPerspective on CUDA does not support NHWC + Int8"); - warp_perspective::forward_proxy( - false, src.ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.ptr(), src.layout[0], mat.layout[0], C, IH, - IW, OH, OW, - bval /* implicit float -> int8 conversion, should be - safe */ - , - bmode, async_error_info(handle()), m_error_tracker, - stream); - } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { - megdnn_assert(param().format == Param::Format::NCHW4, - "WarpPerspective on CUDA supports NCHW4 + " - "QuantizedS8 only"); - warp_perspective::forward_proxy_nchw4( - src.compatible_ptr(), mat.ptr(), - mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, - dst.compatible_ptr(), src.layout[0], - mat.layout[0], C, IH, IW, OH, OW, bval, bmode, - async_error_info(handle()), m_error_tracker, stream); + } else if (src.layout.dtype == dtype::Uint8()) { + warp_perspective::forward_proxy( + is_nhwc, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], + C, IH, IW, OH, OW, bval, bmode, + async_error_info(handle()), m_error_tracker, + stream); + } else if (src.layout.dtype == dtype::Int8()) { + megdnn_assert(!is_nhwc, + "WarpPerspective on CUDA does not support " + "NHWC + Int8"); + warp_perspective::forward_proxy( + false, src.ptr(), mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.ptr(), src.layout[0], mat.layout[0], C, + IH, IW, OH, OW, + bval /* implicit float -> int8 conversion, + should be safe */ + , + bmode, async_error_info(handle()), m_error_tracker, + stream); + } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + megdnn_assert(param().format == Param::Format::NCHW4, + "WarpPerspective on CUDA supports NCHW4 + " + "QuantizedS8 only"); + warp_perspective::forward_proxy_nchw4( + src.compatible_ptr(), + mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, + dst.compatible_ptr(), src.layout[0], + mat.layout[0], C, IH, IW, OH, OW, bval, bmode, + async_error_info(handle()), m_error_tracker, + stream); + } + } else if ((src.layout.dtype.enumv() == + DTypeEnum::Quantized8Asymm || + src.layout.dtype.enumv() == DTypeEnum::Uint8)) { + uint8_t zero_point = 0; + float scale = 1.f; + if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { + zero_point = + src.layout.dtype.param() + .zero_point; + scale = src.layout.dtype.param() + .scale; + } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 && + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { + zero_point = 128; + scale = 1.f; + } + DTypeParamImpl src_dtype_param(scale, zero_point); + + if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 && + dst.layout.dtype.param().scale == + scale) && + ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) || + (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) { + bool is_nhwc_ic_small = + (param().format == + Param::Format::NHWC_NCHW4_IC_SMALL); + warp_perspective:: + forward_proxy_quint8_dimshuffle_typecvt_nchw4< + dt_quint8, dt_uint8, dt_int8>( + is_nhwc_ic_small, + src.compatible_ptr(), + mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.compatible_ptr(), + src.layout[0], mat.layout[0], C, IH, IW, OH, + OW, bval, src_dtype_param, bmode, + async_error_info(handle()), m_error_tracker, + stream); + } else { + megdnn_assert( + ((dst.layout.dtype.enumv() == DTypeEnum::Float32) && + ((param().format == Param::Format::NCHW) || + (param().format == Param::Format::NHWC_NCHW))), + "invalid format for Quantized8Asymm input"); + bool is_nhwc = (param().format == Param::Format::NHWC_NCHW); + warp_perspective:: + forward_proxy_quint8_dimshuffle_typecvt_nchw< + dt_quint8, dt_uint8, dt_float32>( + is_nhwc, src.compatible_ptr(), + mat.ptr(), + mat_idx.raw_ptr ? mat_idx.ptr() + : nullptr, + dst.compatible_ptr(), + src.layout[0], mat.layout[0], C, IH, IW, OH, + OW, bval, src_dtype_param, bmode, + async_error_info(handle()), m_error_tracker, + stream); + } } else { megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name())); diff --git a/dnn/src/cuda/warp_perspective/forward.cu b/dnn/src/cuda/warp_perspective/forward.cu index a16c7830..6a532f59 100644 --- a/dnn/src/cuda/warp_perspective/forward.cu +++ b/dnn/src/cuda/warp_perspective/forward.cu @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/cuda/warp_perspective/common.h" @@ -23,20 +24,19 @@ using namespace warp_perspective; namespace { -template +template struct DirectSrcVisitor { const ctype* ptr; __device__ __forceinline__ const ctype* get(int batch, int im_size) { - return ptr + static_cast(batch) * static_cast(im_size); + return ptr + + static_cast(batch) * static_cast(im_size); } - void move_batch(size_t batch, size_t im_size) { - ptr += batch * im_size; - } + void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; } }; -template +template struct IndexedSrcVisitor { const ctype* ptr; const int* idx; @@ -49,17 +49,17 @@ struct IndexedSrcVisitor { int orig_batch = batch; batch = idx[batch]; if (batch < 0 || batch >= N_SRC) { - set_async_error_info(error_info, error_tracker, + set_async_error_info( + error_info, error_tracker, "mat_idx out of bound: mat_idx[%d]=%d src_batch=%d", orig_batch, batch, N_SRC); batch = 0; } - return ptr + static_cast(batch) * static_cast(im_size); + return ptr + + static_cast(batch) * static_cast(im_size); } - void move_batch(size_t batch, size_t) { - idx += batch; - } + void move_batch(size_t batch, size_t) { idx += batch; } }; template __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, - ctype* __restrict dst, int C, int IH, int IW, - int OH, int OW) { + ctype* __restrict dst, int C, int IH, int IW, + int OH, int OW) { Getter getter; OutputConverter output_converter; int ow = blockIdx.x * blockDim.x + threadIdx.x; @@ -142,21 +142,20 @@ __global__ void kern_general_nchw4(SrcVisitor src, const float* __restrict mat, } } -template -__global__ void kern_const_border( - SrcVisitor src, const float *__restrict mat, ctype *__restrict dst, - int C, int IH, int IW, int OH, int OW, ctype bval) -{ +template +__global__ void kern_const_border(SrcVisitor src, const float* __restrict mat, + ctype* __restrict dst, int C, int IH, int IW, + int OH, int OW, ctype bval) { OutputConverter output_converter; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); - dst += blockIdx.z * C*OH*OW; - mat += blockIdx.z * 3*3; + dst += blockIdx.z * C * OH * OW; + mat += blockIdx.z * 3 * 3; if (ow < OW && oh < OH) { - float denominator = mat[6]*ow + mat[7]*oh + mat[8]; - float iw = (mat[0]*ow + mat[1]*oh + mat[2]) / denominator; - float ih = (mat[3]*ow + mat[4]*oh + mat[5]) / denominator; + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; int iw0 = floor(iw) + 0; int iw1 = floor(iw) + 1; int ih0 = floor(ih) + 0; @@ -170,16 +169,16 @@ __global__ void kern_const_border( float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; for (int c = 0; c < C; ++c) { - ctype v00 = (okh0 && okw0 ? sptr[ih0*IW+iw0] : bval); - ctype v01 = (okh0 && okw1 ? sptr[ih0*IW+iw1] : bval); - ctype v10 = (okh1 && okw0 ? sptr[ih1*IW+iw0] : bval); - ctype v11 = (okh1 && okw1 ? sptr[ih1*IW+iw1] : bval); + ctype v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval); + ctype v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval); + ctype v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval); + ctype v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval); ctype val = output_converter( - v00*nalpha*nbeta + v01*nalpha*pbeta + - v10*palpha*nbeta + v11*palpha*pbeta); - dst[oh*OW+ow] = val; - sptr += IH*IW; - dst += OH*OW; + v00 * nalpha * nbeta + v01 * nalpha * pbeta + + v10 * palpha * nbeta + v11 * palpha * pbeta); + dst[oh * OW + ow] = val; + sptr += IH * IW; + dst += OH * OW; } } } @@ -268,21 +267,21 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* __restrict mat, } } -template -__global__ void kern_const_border_nhwc( - SrcVisitor src, const float *__restrict mat, ctype *__restrict dst, - int C, int IH, int IW, int OH, int OW, ctype bval) -{ +template +__global__ void kern_const_border_nhwc(SrcVisitor src, + const float* __restrict mat, + ctype* __restrict dst, int C, int IH, + int IW, int OH, int OW, ctype bval) { OutputConverter output_converter; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); - dst += blockIdx.z * C*OH*OW; - mat += blockIdx.z * 3*3; + dst += blockIdx.z * C * OH * OW; + mat += blockIdx.z * 3 * 3; if (ow < OW && oh < OH) { - float denominator = mat[6]*ow + mat[7]*oh + mat[8]; - float iw = (mat[0]*ow + mat[1]*oh + mat[2]) / denominator; - float ih = (mat[3]*ow + mat[4]*oh + mat[5]) / denominator; + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; int iw0 = floor(iw) + 0; int iw1 = floor(iw) + 1; int ih0 = floor(ih) + 0; @@ -296,14 +295,14 @@ __global__ void kern_const_border_nhwc( float nalpha = 1.0f - palpha; float nbeta = 1.0f - pbeta; for (int c = 0; c < C; ++c) { - ctype v00 = (okh0 && okw0 ? sptr[(ih0*IW+iw0)*C+c] : bval); - ctype v01 = (okh0 && okw1 ? sptr[(ih0*IW+iw1)*C+c] : bval); - ctype v10 = (okh1 && okw0 ? sptr[(ih1*IW+iw0)*C+c] : bval); - ctype v11 = (okh1 && okw1 ? sptr[(ih1*IW+iw1)*C+c] : bval); + ctype v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * C + c] : bval); + ctype v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * C + c] : bval); + ctype v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * C + c] : bval); + ctype v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * C + c] : bval); ctype val = output_converter( - v00*nalpha*nbeta + v01*nalpha*pbeta + - v10*palpha*nbeta + v11*palpha*pbeta); - dst[(oh*OW+ow)*C+c] = val; + v00 * nalpha * nbeta + v01 * nalpha * pbeta + + v10 * palpha * nbeta + v11 * palpha * pbeta); + dst[(oh * OW + ow) * C + c] = val; } } } @@ -424,22 +423,695 @@ void dispatch_with_visitor_nchw4(SrcVisitor src, const float* mat, ctype* dst, } } -} // anonymous namespace +template +struct CudaTypeCvt; + +template <> +struct CudaTypeCvt { + CudaDTypeParamImpl m_src_param; + CudaTypeCvt(CudaDTypeParamImpl src_param) { + m_src_param = src_param; + }; + inline __device__ int8_t operator()(uint8_t val) { + return val - m_src_param.zero_point; + } +}; + +template <> +struct CudaTypeCvt { + CudaDTypeParamImpl m_src_param; + CudaTypeCvt(CudaDTypeParamImpl src_param) { + m_src_param = src_param; + }; + __device__ __forceinline__ float operator()(uint8_t val) { + return m_src_param.dequantize(dt_quint8(val)); + } +}; + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_general_quint8_nhw_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + CudaTypeCvt type_cvt) { \ + Getter getter; \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = getter(floor(iw) + 0, IW); \ + int iw1 = getter(floor(iw) + 1, IW); \ + int ih0 = getter(floor(ih) + 0, IH); \ + int ih1 = getter(floor(ih) + 1, IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype val_x = \ + warp_out_converter(sptr[ih0 * IW + iw0] * nalpha * nbeta + \ + sptr[ih0 * IW + iw1] * nalpha * pbeta + \ + sptr[ih1 * IW + iw0] * palpha * nbeta + \ + sptr[ih1 * IW + iw1] * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = result.z = result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_const_border_quint8_nhw_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + src_ctype bval, CudaTypeCvt type_cvt) { \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = floor(iw) + 0; \ + int iw1 = floor(iw) + 1; \ + int ih0 = floor(ih) + 0; \ + int ih1 = floor(ih) + 1; \ + bool okw0 = (iw0 >= 0 && iw0 < IW); \ + bool okw1 = (iw1 >= 0 && iw1 < IW); \ + bool okh0 = (ih0 >= 0 && ih0 < IH); \ + bool okh1 = (ih1 >= 0 && ih1 < IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval); \ + src_ctype v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval); \ + src_ctype v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval); \ + src_ctype v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval); \ + src_ctype val_x = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = result.z = result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_general_quint8_n3hw_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + CudaTypeCvt type_cvt) { \ + Getter getter; \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, 3 * IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = getter(floor(iw) + 0, IW); \ + int iw1 = getter(floor(iw) + 1, IW); \ + int ih0 = getter(floor(ih) + 0, IH); \ + int ih1 = getter(floor(ih) + 1, IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype val_x = \ + warp_out_converter(sptr[ih0 * IW + iw0] * nalpha * nbeta + \ + sptr[ih0 * IW + iw1] * nalpha * pbeta + \ + sptr[ih1 * IW + iw0] * palpha * nbeta + \ + sptr[ih1 * IW + iw1] * palpha * pbeta); \ + src_ctype val_y = warp_out_converter( \ + sptr[IW * IH + ih0 * IW + iw0] * nalpha * nbeta + \ + sptr[IW * IH + ih0 * IW + iw1] * nalpha * pbeta + \ + sptr[IW * IH + ih1 * IW + iw0] * palpha * nbeta + \ + sptr[IW * IH + ih1 * IW + iw1] * palpha * pbeta); \ + src_ctype val_z = warp_out_converter( \ + sptr[2 * IW * IH + ih0 * IW + iw0] * nalpha * nbeta + \ + sptr[2 * IW * IH + ih0 * IW + iw1] * nalpha * pbeta + \ + sptr[2 * IW * IH + ih1 * IW + iw0] * palpha * nbeta + \ + sptr[2 * IW * IH + ih1 * IW + iw1] * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = type_cvt(val_y); \ + result.z = type_cvt(val_z); \ + result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_const_border_quint8_n3hw_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + src_ctype bval, CudaTypeCvt type_cvt) { \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, 3 * IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = floor(iw) + 0; \ + int iw1 = floor(iw) + 1; \ + int ih0 = floor(ih) + 0; \ + int ih1 = floor(ih) + 1; \ + bool okw0 = (iw0 >= 0 && iw0 < IW); \ + bool okw1 = (iw1 >= 0 && iw1 < IW); \ + bool okh0 = (ih0 >= 0 && ih0 < IH); \ + bool okh1 = (ih1 >= 0 && ih1 < IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype v00, v01, v10, v11; \ + v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval); \ + v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval); \ + v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval); \ + v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval); \ + src_ctype val_x = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + v00 = (okh0 && okw0 ? sptr[IH * IW + ih0 * IW + iw0] : bval); \ + v01 = (okh0 && okw1 ? sptr[IH * IW + ih0 * IW + iw1] : bval); \ + v10 = (okh1 && okw0 ? sptr[IH * IW + ih1 * IW + iw0] : bval); \ + v11 = (okh1 && okw1 ? sptr[IH * IW + ih1 * IW + iw1] : bval); \ + src_ctype val_y = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + v00 = (okh0 && okw0 ? sptr[2 * IH * IW + ih0 * IW + iw0] : bval); \ + v01 = (okh0 && okw1 ? sptr[2 * IH * IW + ih0 * IW + iw1] : bval); \ + v10 = (okh1 && okw0 ? sptr[2 * IH * IW + ih1 * IW + iw0] : bval); \ + v11 = (okh1 && okw1 ? sptr[2 * IH * IW + ih1 * IW + iw1] : bval); \ + src_ctype val_z = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = type_cvt(val_y); \ + result.z = type_cvt(val_z); \ + result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_general_quint8_nhw3_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + CudaTypeCvt type_cvt) { \ + Getter getter; \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, 3 * IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = getter(floor(iw) + 0, IW); \ + int iw1 = getter(floor(iw) + 1, IW); \ + int ih0 = getter(floor(ih) + 0, IH); \ + int ih1 = getter(floor(ih) + 1, IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype val_x = warp_out_converter( \ + sptr[(ih0 * IW + iw0) * 3] * nalpha * nbeta + \ + sptr[(ih0 * IW + iw1) * 3] * nalpha * pbeta + \ + sptr[(ih1 * IW + iw0) * 3] * palpha * nbeta + \ + sptr[(ih1 * IW + iw1) * 3] * palpha * pbeta); \ + src_ctype val_y = warp_out_converter( \ + sptr[(ih0 * IW + iw0) * 3 + 1] * nalpha * nbeta + \ + sptr[(ih0 * IW + iw1) * 3 + 1] * nalpha * pbeta + \ + sptr[(ih1 * IW + iw0) * 3 + 1] * palpha * nbeta + \ + sptr[(ih1 * IW + iw1) * 3 + 1] * palpha * pbeta); \ + src_ctype val_z = warp_out_converter( \ + sptr[(ih0 * IW + iw0) * 3 + 2] * nalpha * nbeta + \ + sptr[(ih0 * IW + iw1) * 3 + 2] * nalpha * pbeta + \ + sptr[(ih1 * IW + iw0) * 3 + 2] * palpha * nbeta + \ + sptr[(ih1 * IW + iw1) * 3 + 2] * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = type_cvt(val_y); \ + result.z = type_cvt(val_z); \ + result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +#define INST(dst_ctype, vec_dst_type) \ + template \ + __global__ void kern_const_border_quint8_nhw3_nchw4( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int IH, int IW, int OH, int OW, \ + src_ctype bval, CudaTypeCvt type_cvt) { \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, 3 * IH * IW); \ + dst += blockIdx.z * OH * OW * 4; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = floor(iw) + 0; \ + int iw1 = floor(iw) + 1; \ + int ih0 = floor(ih) + 0; \ + int ih1 = floor(ih) + 1; \ + bool okw0 = (iw0 >= 0 && iw0 < IW); \ + bool okw1 = (iw1 >= 0 && iw1 < IW); \ + bool okh0 = (ih0 >= 0 && ih0 < IH); \ + bool okh1 = (ih1 >= 0 && ih1 < IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + vec_dst_type result; \ + src_ctype v00, v01, v10, v11; \ + v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * 3] : bval); \ + v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * 3] : bval); \ + v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * 3] : bval); \ + v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * 3] : bval); \ + src_ctype val_x = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * 3 + 1] : bval); \ + v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * 3 + 1] : bval); \ + v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * 3 + 1] : bval); \ + v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * 3 + 1] : bval); \ + src_ctype val_y = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * 3 + 2] : bval); \ + v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * 3 + 2] : bval); \ + v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * 3 + 2] : bval); \ + v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * 3 + 2] : bval); \ + src_ctype val_z = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + result.x = type_cvt(val_x); \ + result.y = type_cvt(val_y); \ + result.z = type_cvt(val_z); \ + result.w = 0; \ + *((vec_dst_type*)dst + oh * OW + ow) = result; \ + } \ + } + +INST(int8_t, char4) +#undef INST + +template +void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw4( + bool is_nhwc, SrcVisitor src, const float* mat, dst_ctype* dst, int N, + int C, int IH, int IW, int OH, int OW, src_ctype bval, + CudaDTypeParamImpl param, BorderMode bmode, + cudaStream_t stream) { + const int BY = 16, BX = 32; + CudaTypeCvt type_cvt(param); +#define DISPATCH(Getter) \ + do { \ + if (C == 1) { \ + kern_general_quint8_nhw_nchw4 \ + <<>>(src, mat, dst, IH, IW, \ + OH, OW, type_cvt); \ + } else if (is_nhwc) { \ + kern_general_quint8_nhw3_nchw4 \ + <<>>(src, mat, dst, IH, IW, \ + OH, OW, type_cvt); \ + } else { \ + kern_general_quint8_n3hw_nchw4 \ + <<>>(src, mat, dst, IH, IW, \ + OH, OW, type_cvt); \ + } \ + } while (0) + + const int max_batch_size = 65535; + while (N) { + size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; + dim3 threads(BX, BY); + dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); + + switch (bmode) { + case BORDER_REPLICATE: + DISPATCH(ReplicateGetter); + break; + case BORDER_REFLECT: + DISPATCH(ReflectGetter); + break; + case BORDER_REFLECT_101: + DISPATCH(Reflect101Getter); + break; + case BORDER_WRAP: + DISPATCH(WrapGetter); + break; +#undef DISPATCH + case BORDER_CONSTANT: + if (C == 1) { + kern_const_border_quint8_nhw_nchw4 + <<>>(src, mat, dst, IH, + IW, OH, OW, bval, + type_cvt); + } else if (is_nhwc) { + kern_const_border_quint8_nhw3_nchw4 + <<>>(src, mat, dst, IH, + IW, OH, OW, bval, + type_cvt); + } else { + kern_const_border_quint8_n3hw_nchw4 + <<>>(src, mat, dst, IH, + IW, OH, OW, bval, + type_cvt); + } + break; + default: + break; + } + + N -= curr_batch_size; + src.move_batch(curr_batch_size, C * IH * IW); + mat += curr_batch_size * 3 * 3; + dst += curr_batch_size * 4 * OH * OW; + } +} + +#define INST(dst_ctype) \ + template \ + __global__ void kern_general_quint8_nchw( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, \ + CudaTypeCvt type_cvt) { \ + Getter getter; \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); \ + dst += blockIdx.z * C * OH * OW; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = getter(floor(iw) + 0, IW); \ + int iw1 = getter(floor(iw) + 1, IW); \ + int ih0 = getter(floor(ih) + 0, IH); \ + int ih1 = getter(floor(ih) + 1, IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + for (int c = 0; c < C; ++c) { \ + src_ctype val = warp_out_converter( \ + sptr[ih0 * IW + iw0] * nalpha * nbeta + \ + sptr[ih0 * IW + iw1] * nalpha * pbeta + \ + sptr[ih1 * IW + iw0] * palpha * nbeta + \ + sptr[ih1 * IW + iw1] * palpha * pbeta); \ + dst_ctype result; \ + result = type_cvt(val); \ + dst[oh * OW + ow] = result; \ + sptr += IH * IW; \ + dst += OH * OW; \ + } \ + } \ + } + +INST(float) +#undef INST + +#define INST(dst_ctype) \ + template \ + __global__ void kern_const_border_quint8_nchw( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, \ + src_ctype bval, CudaTypeCvt type_cvt) { \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); \ + dst += blockIdx.z * C * OH * OW; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = floor(iw) + 0; \ + int iw1 = floor(iw) + 1; \ + int ih0 = floor(ih) + 0; \ + int ih1 = floor(ih) + 1; \ + bool okw0 = (iw0 >= 0 && iw0 < IW); \ + bool okw1 = (iw1 >= 0 && iw1 < IW); \ + bool okh0 = (ih0 >= 0 && ih0 < IH); \ + bool okh1 = (ih1 >= 0 && ih1 < IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + for (int c = 0; c < C; ++c) { \ + src_ctype v00 = (okh0 && okw0 ? sptr[ih0 * IW + iw0] : bval); \ + src_ctype v01 = (okh0 && okw1 ? sptr[ih0 * IW + iw1] : bval); \ + src_ctype v10 = (okh1 && okw0 ? sptr[ih1 * IW + iw0] : bval); \ + src_ctype v11 = (okh1 && okw1 ? sptr[ih1 * IW + iw1] : bval); \ + src_ctype val = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + dst_ctype result; \ + result = type_cvt(val); \ + dst[oh * OW + ow] = result; \ + sptr += IH * IW; \ + dst += OH * OW; \ + } \ + } \ + } + +INST(float) +#undef INST + +#define INST(dst_ctype) \ + template \ + __global__ void kern_general_quint8_nhwc_nchw( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, \ + CudaTypeCvt type_cvt) { \ + Getter getter; \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); \ + dst += blockIdx.z * C * OH * OW; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = getter(floor(iw) + 0, IW); \ + int iw1 = getter(floor(iw) + 1, IW); \ + int ih0 = getter(floor(ih) + 0, IH); \ + int ih1 = getter(floor(ih) + 1, IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + for (int c = 0; c < C; ++c) { \ + src_ctype val = warp_out_converter( \ + sptr[(ih0 * IW + iw0) * C + c] * nalpha * nbeta + \ + sptr[(ih0 * IW + iw1) * C + c] * nalpha * pbeta + \ + sptr[(ih1 * IW + iw0) * C + c] * palpha * nbeta + \ + sptr[(ih1 * IW + iw1) * C + c] * palpha * pbeta); \ + dst_ctype result; \ + result = type_cvt(val); \ + dst[oh * OW + ow] = result; \ + dst += OH * OW; \ + } \ + } \ + } + +INST(float) +#undef INST + +#define INST(dst_ctype) \ + template \ + __global__ void kern_const_border_quint8_nhwc_nchw( \ + SrcVisitor src, const float* __restrict mat, \ + dst_ctype* __restrict dst, int C, int IH, int IW, int OH, int OW, \ + src_ctype bval, CudaTypeCvt type_cvt) { \ + rounding::RoundingConverter warp_out_converter; \ + int ow = blockIdx.x * blockDim.x + threadIdx.x; \ + int oh = blockIdx.y * blockDim.y + threadIdx.y; \ + const src_ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); \ + dst += blockIdx.z * C * OH * OW; \ + mat += blockIdx.z * 3 * 3; \ + if (ow < OW && oh < OH) { \ + float denominator = mat[6] * ow + mat[7] * oh + mat[8]; \ + float iw = (mat[0] * ow + mat[1] * oh + mat[2]) / denominator; \ + float ih = (mat[3] * ow + mat[4] * oh + mat[5]) / denominator; \ + int iw0 = floor(iw) + 0; \ + int iw1 = floor(iw) + 1; \ + int ih0 = floor(ih) + 0; \ + int ih1 = floor(ih) + 1; \ + bool okw0 = (iw0 >= 0 && iw0 < IW); \ + bool okw1 = (iw1 >= 0 && iw1 < IW); \ + bool okh0 = (ih0 >= 0 && ih0 < IH); \ + bool okh1 = (ih1 >= 0 && ih1 < IH); \ + float palpha = ih - floor(ih); \ + float pbeta = iw - floor(iw); \ + float nalpha = 1.0f - palpha; \ + float nbeta = 1.0f - pbeta; \ + for (int c = 0; c < C; ++c) { \ + src_ctype v00 = (okh0 && okw0 ? sptr[(ih0 * IW + iw0) * C + c] \ + : bval); \ + src_ctype v01 = (okh0 && okw1 ? sptr[(ih0 * IW + iw1) * C + c] \ + : bval); \ + src_ctype v10 = (okh1 && okw0 ? sptr[(ih1 * IW + iw0) * C + c] \ + : bval); \ + src_ctype v11 = (okh1 && okw1 ? sptr[(ih1 * IW + iw1) * C + c] \ + : bval); \ + float val = warp_out_converter( \ + v00 * nalpha * nbeta + v01 * nalpha * pbeta + \ + v10 * palpha * nbeta + v11 * palpha * pbeta); \ + dst_ctype result; \ + result = type_cvt(val); \ + dst[oh * OW + ow] = result; \ + dst += OH * OW; \ + } \ + } \ + } + +INST(float) +#undef INST + +template +void dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw( + bool is_nhwc, SrcVisitor src, const float* mat, dst_ctype* dst, int N, + int C, int IH, int IW, int OH, int OW, src_ctype bval, + CudaDTypeParamImpl param, BorderMode bmode, + cudaStream_t stream) { + const int BY = 16, BX = 32; + CudaTypeCvt type_cvt(param); +#define DISPATCH(Getter) \ + do { \ + if (is_nhwc) { \ + kern_general_quint8_nhwc_nchw \ + <<>>(src, mat, dst, C, IH, IW, \ + OH, OW, type_cvt); \ + } else { \ + kern_general_quint8_nchw \ + <<>>(src, mat, dst, C, IH, IW, \ + OH, OW, type_cvt); \ + } \ + } while (0) + + const int max_batch_size = 65535; + while (N) { + size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; + dim3 threads(BX, BY); + dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); + + switch (bmode) { + case BORDER_REPLICATE: + DISPATCH(ReplicateGetter); + break; + case BORDER_REFLECT: + DISPATCH(ReflectGetter); + break; + case BORDER_REFLECT_101: + DISPATCH(Reflect101Getter); + break; + case BORDER_WRAP: + DISPATCH(WrapGetter); + break; +#undef DISPATCH + case BORDER_CONSTANT: + if (is_nhwc) { + kern_const_border_quint8_nhwc_nchw + <<>>(src, mat, dst, C, + IH, IW, OH, OW, + bval, type_cvt); + } else { + kern_const_border_quint8_nchw + <<>>(src, mat, dst, C, + IH, IW, OH, OW, + bval, type_cvt); + } + break; + default: + break; + } + + N -= curr_batch_size; + src.move_batch(curr_batch_size, C * IH * IW); + mat += curr_batch_size * 3 * 3; + dst += curr_batch_size * C * OH * OW; + } +} + +} // anonymous namespace namespace megdnn { namespace cuda { namespace warp_perspective { -template -void forward_proxy( - bool is_nhwc, - const ctype *src, const float *mat, const int *mat_idx, - ctype *dst, int N_SRC, int N_MAT, - int C, int IH, int IW, int OH, int OW, ctype bval, - BorderMode bmode, - megcore::AsyncErrorInfo* error_info, void* error_tracker, - cudaStream_t stream) -{ +template +void forward_proxy(bool is_nhwc, const ctype* src, const float* mat, + const int* mat_idx, ctype* dst, int N_SRC, int N_MAT, int C, + int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode, + megcore::AsyncErrorInfo* error_info, void* error_tracker, + cudaStream_t stream) { if (mat_idx) { IndexedSrcVisitor visitor; visitor.ptr = src; @@ -447,15 +1119,13 @@ void forward_proxy( visitor.N_SRC = N_SRC; visitor.error_info = error_info; visitor.error_tracker = error_tracker; - dispatch_with_visitor(is_nhwc, - visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, - bmode, stream); + dispatch_with_visitor(is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, + OW, bval, bmode, stream); } else { DirectSrcVisitor visitor; visitor.ptr = src; - dispatch_with_visitor(is_nhwc, - visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, - bmode, stream); + dispatch_with_visitor(is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, + OW, bval, bmode, stream); } after_kernel_launch(); } @@ -506,8 +1176,84 @@ INST(int8_t) INST(int8_t) #undef INST -} // namespace warp_perspective -} // namespace cuda -} // namespace megdnn +template +void forward_proxy_quint8_dimshuffle_typecvt_nchw4( + bool is_nhwc, const src_ctype* src, const float* mat, + const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, + int IW, int OH, int OW, src_ctype bval, DTypeParamImpl param, + BorderMode bmode, megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream) { + CudaDTypeParamImpl dtype_param(param); + if (mat_idx) { + IndexedSrcVisitor visitor; + visitor.ptr = src; + visitor.idx = mat_idx; + visitor.N_SRC = N_SRC; + visitor.error_info = error_info; + visitor.error_tracker = error_tracker; + dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw4( + is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, + dtype_param, bmode, stream); + } else { + DirectSrcVisitor visitor; + visitor.ptr = src; + dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw4( + is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, + dtype_param, bmode, stream); + } + after_kernel_launch(); +} + +#define INST(src_dtype, src_ctype, dst_ctype) \ + template void forward_proxy_quint8_dimshuffle_typecvt_nchw4( \ + bool is_nhwc, const src_ctype*, const float*, const int*, \ + dst_ctype*, int, int, int, int, int, int, int, src_ctype, \ + DTypeParamImpl param, BorderMode, \ + megcore::AsyncErrorInfo*, void*, cudaStream_t); + +INST(dt_quint8, uint8_t, int8_t) +#undef INST + +template +void forward_proxy_quint8_dimshuffle_typecvt_nchw( + bool is_nhwc, const src_ctype* src, const float* mat, + const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, + int IW, int OH, int OW, src_ctype bval, DTypeParamImpl param, + BorderMode bmode, megcore::AsyncErrorInfo* error_info, + void* error_tracker, cudaStream_t stream) { + CudaDTypeParamImpl dtype_param(param); + if (mat_idx) { + IndexedSrcVisitor visitor; + visitor.ptr = src; + visitor.idx = mat_idx; + visitor.N_SRC = N_SRC; + visitor.error_info = error_info; + visitor.error_tracker = error_tracker; + dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw( + is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, + dtype_param, bmode, stream); + } else { + DirectSrcVisitor visitor; + visitor.ptr = src; + dispatch_with_visitor_quint8_dimshuffle_typecvt_nchw( + is_nhwc, visitor, mat, dst, N_MAT, C, IH, IW, OH, OW, bval, + dtype_param, bmode, stream); + } + after_kernel_launch(); +} + +#define INST(src_dtype, src_ctype, dst_ctype) \ + template void forward_proxy_quint8_dimshuffle_typecvt_nchw( \ + bool is_nhwc, const src_ctype*, const float*, const int*, \ + dst_ctype*, int, int, int, int, int, int, int, src_ctype, \ + DTypeParamImpl param, BorderMode, \ + megcore::AsyncErrorInfo*, void*, cudaStream_t); + +INST(dt_quint8, uint8_t, float) +#undef INST + +} // namespace warp_perspective +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/warp_perspective/opr_impl.cpp b/dnn/src/naive/warp_perspective/opr_impl.cpp index ca46d191..370b8dd9 100644 --- a/dnn/src/naive/warp_perspective/opr_impl.cpp +++ b/dnn/src/naive/warp_perspective/opr_impl.cpp @@ -249,6 +249,162 @@ void WarpPerspectiveForwardImpl::kern_naive_nhwcd4( MIDOUT_END(); } +template +void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt( + const KernParam& kern_param, size_t task_id) { + MEGDNN_MARK_USED_VAR(kern_param); + MIDOUT_BEGIN(megdnn_naive_warpperspective, ctype, mtype, midout_iv(2)) { + UNPACK_WARP_PERSPECTIVE_FWD_KERN_PARAM(kern_param); + MEGDNN_MARK_USED_VAR(N_MAT); + //! strides of C, H, W on src and dst + size_t sstrd[3], dstrd[3]; + auto set_sstrd = [&](size_t s0, size_t s1, size_t s2) { + sstrd[0] = s0; + sstrd[1] = s1; + sstrd[2] = s2; + }; + auto set_dstrd = [&](size_t s0, size_t s1, size_t s2) { + dstrd[0] = s0; + dstrd[1] = s1; + dstrd[2] = s2; + }; + switch (kern_param.format) { + case Format::NCHW: + case Format::NCHW_NCHW4_IC_SMALL: + set_sstrd(IH * IW, IW, 1); + set_dstrd(OH * OW, OW, 1); + break; + case Format::NHWC_NCHW: + case Format::NHWC_NCHW4_IC_SMALL: + set_sstrd(1, IW * C, C); + set_dstrd(OH * OW, OW, 1); + break; + default: + megdnn_throw("bad format"); + } + + uint8_t zero_point = 0; + float scale = 1.f; + + bool is_dst_float = kern_param.dst_dtype.enumv() == DTypeEnum::Float32; + if (kern_param.src_dtype.enumv() == + DTypeTrait::enumv) { + auto dtype_param = + kern_param.src_dtype + .template param(); + zero_point = dtype_param.zero_point; + scale = dtype_param.scale; + } else if (kern_param.src_dtype.enumv() == DTypeEnum::Uint8) { + zero_point = + (kern_param.dst_dtype.enumv() == DTypeEnum::QuantizedS8) + ? 128 + : 0; + scale = 1.f; + } + + dst_ctype* dst_ptr = reinterpret_cast(dptr); + + bool is_dst_nchw4 = + (kern_param.format == Format::NCHW_NCHW4_IC_SMALL) || + (kern_param.format == Format::NHWC_NCHW4_IC_SMALL); + auto visit_src = [&sptr, sstrd](size_t c, int h, int w) -> float { + return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w]; + }; + auto visit_src_bd = [&sptr, sstrd, border_val](size_t c, int h, + int w) -> float { + if (h != -1 && w != -1) { + return sptr[sstrd[0] * c + sstrd[1] * h + sstrd[2] * w]; + } else + return border_val; + }; + auto visit_dst = [&dst_ptr, dstrd, is_dst_nchw4](size_t c, int h, + int w) -> dst_ctype& { + if (!is_dst_nchw4) + return dst_ptr[dstrd[0] * c + dstrd[1] * h + dstrd[2] * w]; + else + return dst_ptr[((dstrd[0] * (c >> 2) + dstrd[1] * h + + dstrd[2] * w) + << 2) + + (c & 0b11)]; + }; + + rounding::RoundingConverter output_converter; + auto orig_sptr = sptr; + size_t n = task_id / OH; + size_t oh = task_id % OH; + mptr = mptr + n * 3 * 3; + dst_ptr = is_dst_nchw4 ? (dst_ptr + n * OH * OW * 4) + : (dst_ptr + n * C * OH * OW); + if (midx_ptr) { + size_t idx = midx_ptr[n]; + megdnn_assert( + idx < N_SRC, + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", n, + idx, N_SRC); + sptr = orig_sptr + idx * (C * IH * IW); + } else if (n) { + sptr += n * C * IH * IW; + } + rep(ow, OW) { + float numeratorw = mptr[0] * ow + mptr[1] * oh + mptr[2]; + float numeratorh = mptr[3] * ow + mptr[4] * oh + mptr[5]; + float denominator = mptr[6] * ow + mptr[7] * oh + mptr[8]; + float alphaw = numeratorw / denominator; + float alphah = numeratorh / denominator; + + int iw0 = get_real_coord(std::floor(alphaw) + 0, IW); + int iw1 = get_real_coord(std::floor(alphaw) + 1, IW); + int ih0 = get_real_coord(std::floor(alphah) + 0, IH); + int ih1 = get_real_coord(std::floor(alphah) + 1, IH); + + alphaw -= floor(alphaw); + alphah -= floor(alphah); + if (bmode != BorderMode::CONSTANT) { + rep(c, C) { + auto val = + visit_src(c, ih0, iw0) * (1.0f - alphaw) * + (1.0f - alphah) + + visit_src(c, ih0, iw1) * alphaw * (1.0f - alphah) + + visit_src(c, ih1, iw0) * (1.0f - alphaw) * alphah + + visit_src(c, ih1, iw1) * alphaw * alphah; + val = is_dst_float ? (val - zero_point) * scale + : val - zero_point; + visit_dst(c, oh, ow) = output_converter(val); + } + } else { + rep(c, C) { + auto val = visit_src_bd(c, ih0, iw0) * (1.0f - alphaw) * + (1.0f - alphah) + + visit_src_bd(c, ih0, iw1) * alphaw * + (1.0f - alphah) + + visit_src_bd(c, ih1, iw0) * (1.0f - alphaw) * + alphah + + visit_src_bd(c, ih1, iw1) * alphaw * alphah; + val = std::isfinite(val) ? val : border_val; + val = is_dst_float ? (val - zero_point) * scale + : val - zero_point; + visit_dst(c, oh, ow) = output_converter(val); + } + } + if (is_dst_nchw4) { + for (auto c = C; c < 4; ++c) { + visit_dst(c, oh, ow) = 0; + } + } + } + } + MIDOUT_END(); +} + +#define INST(ctype, drc_ctype, mtype) \ + template void WarpPerspectiveForwardImpl::kern_naive_dimshuffle_typecvt< \ + ctype, drc_ctype, mtype>(const KernParam&, size_t); + +INST(uint8_t, int8_t, float); +INST(uint8_t, float, float); + +#undef INST + void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, @@ -320,6 +476,65 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, src.layout.dtype.name()) .c_str()); } + + bool is_fusion_dtype = src.layout.dtype.enumv() != dst.layout.dtype.enumv(); + bool is_u8_or_qu8_in = + src.layout.dtype.enumv() == DTypeTrait::enumv || + src.layout.dtype.enumv() == + DTypeTrait::enumv; + + if (is_fusion_dtype && is_u8_or_qu8_in && + ((param().format == Format::NCHW_NCHW4_IC_SMALL) || + (param().format == Format::NHWC_NCHW4_IC_SMALL) || + (param().format == Format::NHWC_NCHW) || + (param().format == Format::NCHW))) { + if (src.layout.dtype.enumv() == + DTypeTrait::enumv || + src.layout.dtype.enumv() == DTypeTrait::enumv) { + float scale = 1.f; + + if (src.layout.dtype.enumv() == + DTypeTrait::enumv) { + scale = src.layout.dtype.param().scale; + } + + auto kparam = KernParam::from_tensors( + param().format, param().bmode, param().border_val, src, mat, + mat_idx, dst, workspace); + + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { + auto run = [kparam, this](size_t index, size_t) { + kern_naive_dimshuffle_typecvt(kparam, + index); + }; + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, + kparam.oh * batch); + return; + } else if ((dst.layout.dtype.enumv() == + DTypeTrait::enumv) && + (dst.layout.dtype.param().scale == + scale)) { + auto run = [kparam, this](size_t index, size_t) { + kern_naive_dimshuffle_typecvt( + kparam, index); + }; + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN_OPR(run, + kparam.oh * batch); + return; + } else { + megdnn_throw(ssprintf("Unsupported DType in " + "WarpPerspective Dimshuffle Typecvt: %s", + src.layout.dtype.name()) + .c_str()); + } + } + + megdnn_throw(ssprintf("Unsupported input DType in " + "WarpPerspective: %s", + src.layout.dtype.name()) + .c_str()); + } + if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, param().format)) { MIDOUT_BEGIN(megdnn_naive_warpperspective, void) { @@ -331,12 +546,12 @@ void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in src, megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, dst.layout, param().imode, param().format)); /*! - * We currently use floating point for all WarpPerspective computation, - * so even if the input ctype is one of the integer type, mtype should - * always be float32. + * We currently use floating point for all WarpPerspective + * computation, so even if the input ctype is one of the integer + * type, mtype should always be float32. * - * \warning It's different with \c WarpAffine, with mtype be float16 if - * input type is float16. + * \warning It's different with \c WarpAffine, with mtype be float16 + * if input type is float16. */ DISPATCH_ST(dtype::Float32, float, float, KERN); diff --git a/dnn/src/naive/warp_perspective/opr_impl.h b/dnn/src/naive/warp_perspective/opr_impl.h index 94e3b396..4652c568 100644 --- a/dnn/src/naive/warp_perspective/opr_impl.h +++ b/dnn/src/naive/warp_perspective/opr_impl.h @@ -26,6 +26,7 @@ protected: float border_val; size_t n_src, n_mat, c, ih, iw, oh, ow; ctype *sptr, *dptr; + DType src_dtype, dst_dtype; mtype* mptr; int* midx_ptr; //!< can be null Workspace workspace; @@ -41,6 +42,8 @@ protected: ret.bmode = bmode; ret.border_val = border_val; ret.n_src = src.layout.shape[0]; + ret.src_dtype = src.layout.dtype; + ret.dst_dtype = dst.layout.dtype; if (mat_idx.raw_ptr) { megdnn_assert(mat_idx.layout.ndim == 1); ret.n_mat = mat_idx.layout.shape[0]; @@ -50,7 +53,8 @@ protected: ret.n_mat = ret.n_src; ret.midx_ptr = nullptr; } - if (format == Format::NCHW) { + if (format == Format::NCHW || + format == Format::NCHW_NCHW4_IC_SMALL) { ret.c = src.layout.shape[1]; ret.ih = src.layout.shape[2]; ret.iw = src.layout.shape[3]; @@ -62,6 +66,13 @@ protected: ret.iw = src.layout.shape[2]; ret.oh = dst.layout.shape[1]; ret.ow = dst.layout.shape[2]; + } else if (format == Format::NHWC_NCHW || + format == Format::NHWC_NCHW4_IC_SMALL) { + ret.c = src.layout.shape[3]; + ret.ih = src.layout.shape[1]; + ret.iw = src.layout.shape[2]; + ret.oh = dst.layout.shape[2]; + ret.ow = dst.layout.shape[3]; } else if (format == Format::NCHW4) { ret.c = src.layout.shape[1] * 4; ret.ih = src.layout.shape[2]; @@ -76,15 +87,16 @@ protected: ret.oh = dst.layout.shape[1]; ret.ow = dst.layout.shape[3]; } - if (src.layout.dtype.enumv() == DTypeEnum::Float32 || - MEGDNN_FLOAT16_SELECT( - (src.layout.dtype.enumv() == DTypeEnum::Float16 || - src.layout.dtype.enumv() == DTypeEnum::BFloat16), - false) || - src.layout.dtype.enumv() == DTypeEnum::Int8 || - src.layout.dtype.enumv() == DTypeEnum::Uint8 || - src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 || - src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { + if ((src.layout.dtype.enumv() == DTypeEnum::Float32 || + MEGDNN_FLOAT16_SELECT( + (src.layout.dtype.enumv() == DTypeEnum::Float16 || + src.layout.dtype.enumv() == DTypeEnum::BFloat16), + false) || + src.layout.dtype.enumv() == DTypeEnum::Int8 || + src.layout.dtype.enumv() == DTypeEnum::Uint8 || + src.layout.dtype.enumv() == DTypeEnum::QuantizedS8 || + src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) && + (src.layout.dtype == dst.layout.dtype)) { ret.sptr = src.compatible_ptr(); ret.mptr = mat.ptr(); ret.dptr = dst.compatible_ptr(); @@ -92,6 +104,13 @@ protected: ret.sptr = src.compatible_ptr(); ret.mptr = mat.ptr(); ret.dptr = dst.compatible_ptr(); + } else if ((src.layout.dtype.enumv() == DTypeEnum::Uint8 || + src.layout.dtype.enumv() == + DTypeEnum::Quantized8Asymm) && + src.layout.dtype.enumv() != dst.layout.dtype.enumv()) { + ret.sptr = src.compatible_ptr(); + ret.mptr = mat.ptr(); + ret.dptr = reinterpret_cast(dst.raw_ptr); } else { ret.sptr = nullptr; ret.mptr = nullptr; @@ -122,6 +141,9 @@ private: template void kern_naive_nhwcd4(const KernParam& kern_param, size_t task_id); + template + void kern_naive_dimshuffle_typecvt( + const KernParam& kern_param, size_t task_id); }; class WarpPerspectiveBackwardDataImpl : public WarpPerspectiveBackwardData { diff --git a/dnn/test/cuda/warp_perspective.cpp b/dnn/test/cuda/warp_perspective.cpp index 2911f42c..43241b8a 100644 --- a/dnn/test/cuda/warp_perspective.cpp +++ b/dnn/test/cuda/warp_perspective.cpp @@ -23,8 +23,7 @@ using namespace megdnn; using namespace test; class NanMatRNG : public RNG { - void gen(const TensorND& tensor_) override - { + void gen(const TensorND& tensor_) override { auto& gen = RandomState::generator(); std::uniform_real_distribution pdist3(1.9f, 2.1f); std::uniform_real_distribution pdist(0.9f, 1.1f); @@ -335,6 +334,144 @@ TEST_F(CUDA, WARP_PERSPECTIVE_NCHW4) { } } +TEST_F(CUDA, WARP_PERSPECTIVE_NCHW_NCHW4_IC_SMALL) { + using Param = WarpPerspective::Param; + WarpPerspective::Param param; + Checker checker(handle_cuda()); + WarpPerspectiveMatRNG rng; + + param.format = Param::Format::NCHW_NCHW4_IC_SMALL; + + checker.set_rng(1, &rng); + checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128)); + checker.set_dtype(2, dtype::QuantizedS8(0.1f)); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{2, 3, 10, 11}, {2, 3, 3}, {2, 1, 11, 12, 4}}); + checker.execs({{1, 3, 25, 510}, {1, 3, 3}, {1, 1, 25, 25, 4}}); + checker.execs({{1, 3, 25, 25}, {1, 3, 3}, {1, 1, 51, 51, 4}}); + checker.execs({{1, 3, 51, 51}, {1, 3, 3}, {1, 1, 25, 25, 4}}); + } + { + Checker checker( + handle_cuda()); + constexpr int N_SRC = 5; + UniformIntRNG mat_idx_rng{0, N_SRC - 1}; + checker.set_dtype(0, dtype::Quantized8Asymm(0.1f, 128)); + checker.set_rng(1, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.set_rng(2, &mat_idx_rng); + checker.set_dtype(3, dtype::QuantizedS8(0.1f)); + param.bmode = WarpPerspective::Param::BorderMode::REFLECT; + param.imode = param::WarpPerspective::InterpolationMode::LINEAR; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{N_SRC, 3, 10, 11}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}}); + checker.execs( + {{N_SRC, 3, 17, 13}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}}); + } +} + +TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW4_IC_SMALL) { + using Param = WarpPerspective::Param; + WarpPerspective::Param param; + Checker checker(handle_cuda()); + WarpPerspectiveMatRNG rng; + + param.format = Param::Format::NHWC_NCHW4_IC_SMALL; + + checker.set_rng(1, &rng); + checker.set_dtype(0, dtype::Uint8()); + checker.set_dtype(2, dtype::QuantizedS8(1.f)); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 1, 11, 12, 4}}); + checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}}); + checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 1, 51, 51, 4}}); + checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 1, 25, 25, 4}}); + } + { + Checker checker( + handle_cuda()); + constexpr int N_SRC = 5; + UniformIntRNG mat_idx_rng{0, N_SRC - 1}; + checker.set_dtype(0, dtype::Uint8()); + checker.set_rng(1, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.set_rng(2, &mat_idx_rng); + checker.set_dtype(3, dtype::QuantizedS8(1.f)); + param.bmode = WarpPerspective::Param::BorderMode::REFLECT; + param.imode = param::WarpPerspective::InterpolationMode::LINEAR; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 1, 11, 12, 4}}); + checker.execs( + {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 1, 16, 15, 4}}); + } +} + +TEST_F(CUDA, WARP_PERSPECTIVE_NHWC_NCHW) { + using Param = WarpPerspective::Param; + WarpPerspective::Param param; + Checker checker(handle_cuda()); + WarpPerspectiveMatRNG rng; + + param.format = Param::Format::NHWC_NCHW; + + checker.set_rng(1, &rng); + checker.set_dtype(0, dtype::Uint8()); + checker.set_dtype(2, dtype::Float32()); + for (auto bmode : {WarpPerspective::BorderMode::WRAP, + WarpPerspective::BorderMode::REFLECT, + WarpPerspective::BorderMode::REPLICATE, + WarpPerspective::BorderMode::CONSTANT}) { + param.border_val = 0.3f; + param.bmode = bmode; + param.imode = Param::InterpolationMode::LINEAR; + + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{2, 10, 11, 3}, {2, 3, 3}, {2, 3, 11, 12}}); + checker.execs({{1, 25, 510, 3}, {1, 3, 3}, {1, 3, 25, 25}}); + checker.execs({{1, 25, 25, 3}, {1, 3, 3}, {1, 3, 51, 51}}); + checker.execs({{1, 51, 51, 3}, {1, 3, 3}, {1, 3, 25, 25}}); + } + { + Checker checker( + handle_cuda()); + constexpr int N_SRC = 5; + UniformIntRNG mat_idx_rng{0, N_SRC - 1}; + checker.set_dtype(0, dtype::Uint8()); + checker.set_rng(1, &rng); + checker.set_dtype(2, dtype::Int32()); + checker.set_rng(2, &mat_idx_rng); + checker.set_dtype(3, dtype::Float32()); + param.bmode = WarpPerspective::Param::BorderMode::REFLECT; + param.imode = param::WarpPerspective::InterpolationMode::LINEAR; + checker.set_param(param); + checker.set_epsilon(1 + 1e-3); + checker.execs({{N_SRC, 10, 11, 3}, {2, 3, 3}, {2}, {2, 3, 11, 12}}); + checker.execs( + {{N_SRC, 17, 13, 3}, {123, 3, 3}, {123}, {123, 3, 16, 15}}); + } +} + TEST_F(CUDA, WARP_PERSPECTIVE_FORWARD_NCHW_INT8) { warp_perspective::run_int8_test(handle_cuda()); } diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index e0b80eaa..1238cd27 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "megbrain/gopt/framework.h" @@ -35,13 +36,13 @@ using namespace gopt; /* ================ SubGraph ================ */ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs( - OperatorNodeBase *opr) { - auto &&new_inp = m_opr_new_inp_cache; + OperatorNodeBase* opr) { + auto&& new_inp = m_opr_new_inp_cache; new_inp.clear(); new_inp.reserve(opr->input().size()); bool has_replaced_inp = false; - for (auto i: opr->input()) { + for (auto i : opr->input()) { auto new_var = get_var(i); if (new_var != i) { has_replaced_inp = true; @@ -52,14 +53,14 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs( } if (has_replaced_inp) { - auto new_opr = serialization::copy_opr_shallow( - *opr, new_inp, opr->config()); + auto new_opr = + serialization::copy_opr_shallow(*opr, new_inp, opr->config()); auto &&out0 = opr->output(), &&out1 = new_opr->output(); size_t i = 0; auto err_msg = [opr, new_opr] { - return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}", - opr->cname(), opr->dyn_typeinfo()->name, - new_opr->cname(), new_opr->dyn_typeinfo()->name); + return ssprintf("bad opr copy: src=%s{%s} dst=%s{%s}", opr->cname(), + opr->dyn_typeinfo()->name, new_opr->cname(), + new_opr->dyn_typeinfo()->name); }; MGB_MARK_USED_VAR(err_msg); // opr output size mismatch may be caused by: @@ -67,33 +68,33 @@ OperatorNodeBase* SubGraph::Rewriter::auto_replace_outputs( // 1) other post-insert optimization (e.g. const folding) // we can't handle only usable_output here, since some output var with // volatile flag could be the graph's endpoint (e.g. RemoteSend) - for (; i < std::min(out0.size(), out1.size()); ++ i) { + for (; i < std::min(out0.size(), out1.size()); ++i) { bool v0 = out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), v1 = out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT); mgb_assert(v0 == v1, "%s", err_msg().c_str()); - auto &&ins = m_varmap.insert({out0[i], {true, nullptr}}); + auto&& ins = m_varmap.insert({out0[i], {true, nullptr}}); mgb_assert(ins.second || ins.first->second.first, "opr output already replaced"); // handle repeated call on the same opr ins.first->second.second = out1[i]; on_var_replaced(out0[i], out1[i], nullptr); } - for (; i < out0.size(); ++ i) { + for (; i < out0.size(); ++i) { mgb_assert(out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), - "%s", err_msg().c_str()); + "%s", err_msg().c_str()); } - for (; i < out1.size(); ++ i) { + for (; i < out1.size(); ++i) { mgb_assert(out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT), - "%s", err_msg().c_str()); + "%s", err_msg().c_str()); } return new_opr; } return opr; } -void SubGraph::Rewriter::replace_var( - VarNode *src, VarNode *dst, const char *msg) { +void SubGraph::Rewriter::replace_var(VarNode* src, VarNode* dst, + const char* msg) { if (src == dst) return; @@ -103,19 +104,19 @@ void SubGraph::Rewriter::replace_var( "dst %s maps back to src %s in SubGraph::Rewriter::replace_var", dst->cname(), src->cname()); - auto &&ins = m_varmap.insert({src, {false, dst}}); + auto&& ins = m_varmap.insert({src, {false, dst}}); if (!ins.second) { - auto &&old_rep = ins.first->second; + auto&& old_rep = ins.first->second; mgb_assert(old_rep.first || old_rep.second == dst, - "can not replace a var twice"); + "can not replace a var twice"); old_rep.first = false; old_rep.second = dst; } on_var_replaced(src, dst, msg); } -void SubGraph::Rewriter::on_var_replaced( - VarNode* src, VarNode* dst, const char* msg) { +void SubGraph::Rewriter::on_var_replaced(VarNode* src, VarNode* dst, + const char* msg) { if (auto state = m_owner_graph->owner_opt_state()) { state->on_var_replaced(src, dst, msg); } @@ -124,7 +125,7 @@ void SubGraph::Rewriter::on_var_replaced( void SubGraph::Rewriter::apply_inplace() const { m_owner_graph->m_endpoint_oprs.clear(); m_owner_graph->m_endpoint_vars_set.clear(); - for (auto &&var: m_owner_graph->m_endpoint_vars) { + for (auto&& var : m_owner_graph->m_endpoint_vars) { var = get_var(var.node()); m_owner_graph->m_endpoint_oprs.insert(var.node()->owner_opr()); m_owner_graph->m_endpoint_vars_set.insert(var.node()); @@ -150,33 +151,30 @@ std::pair SubGraph::Rewriter::get_var_internal(VarNode* var) { return it->second = {it_next->second.first & it->second.first, next.second}; } -SubGraph::SubGraph(const SymbolVarArray &endpoint_vars): - m_endpoint_vars(endpoint_vars) -{ +SubGraph::SubGraph(const SymbolVarArray& endpoint_vars) + : m_endpoint_vars(endpoint_vars) { mgb_assert(!endpoint_vars.empty(), "endpoints can not be empty"); m_comp_graph = endpoint_vars[0].node()->owner_graph(); - for (auto i: endpoint_vars) { + for (auto i : endpoint_vars) { m_endpoint_oprs.insert(i.node()->owner_opr()); m_endpoint_vars_set.insert(i.node()); mgb_assert(m_comp_graph == i.node()->owner_graph(), - "endpoints belong to different computing graphs"); + "endpoints belong to different computing graphs"); } } -void SubGraph::iter( - const Callback& cb, - std::shared_ptr extra_dep) const { +void SubGraph::iter(const Callback& cb, + std::shared_ptr extra_dep) const { Callback on_opr; if (m_owner_opt_state) { - on_opr = [state=m_owner_opt_state, &cb](OperatorNodeBase *opr) { + on_opr = [state = m_owner_opt_state, &cb](OperatorNodeBase* opr) { state->m_opr_property_flag = OprPropertyFlag::ALL; state->m_cur_iter_src_opr = cg::get_opr_root_source_opr(opr); state->m_cur_iter_opr_priority = - opr->node_prop().attribute().priority; + opr->node_prop().attribute().priority; state->m_cur_iter_opr_stream_prop_type = - state->m_comp_node_opt.stream_prop_type( - opr->output(0)); + state->m_comp_node_opt.stream_prop_type(opr->output(0)); mgb_assert(state->m_oprs_inserted.empty()); cb(opr); state->m_opr_property_flag = OprPropertyFlag::NONE; @@ -188,19 +186,19 @@ void SubGraph::iter( } cg::DepOprIter dep_iter{on_opr, std::move(extra_dep)}; - for (auto i: m_endpoint_oprs) + for (auto i : m_endpoint_oprs) dep_iter.add(i); } ThinHashMap SubGraph::get_var2nr_val_dep_oprs() const { ThinHashMap ret; - auto cb = [&](OperatorNodeBase *opr) { - for (auto &&i: opr->node_prop().dep_map()) { + auto cb = [&](OperatorNodeBase* opr) { + for (auto&& i : opr->node_prop().dep_map()) { if (OperatorNodeBase::NodeProp::is_device_value_dep(i.second)) { - ++ ret.at(i.first); + ++ret.at(i.first); } } - for (auto i: opr->output()) { + for (auto i : opr->output()) { if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { auto ins = ret.insert({i, 0}); mgb_assert(ins.second); @@ -208,13 +206,13 @@ ThinHashMap SubGraph::get_var2nr_val_dep_oprs() const { } }; iter(cb); - for (auto i: m_endpoint_vars_set) { + for (auto i : m_endpoint_vars_set) { auto iter = ret.find(i); if (iter == ret.end()) { mgb_assert(i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); ret[i] = 1; } else { - ++ ret.at(i); + ++ret.at(i); } } return ret; @@ -222,10 +220,8 @@ ThinHashMap SubGraph::get_var2nr_val_dep_oprs() const { /* ================ UniqReaderCheck ================ */ -UniqReaderCheck::UniqReaderCheck(const SubGraph &graph): - m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} -{ -} +UniqReaderCheck::UniqReaderCheck(const SubGraph& graph) + : m_var2nr_val_dep{graph.get_var2nr_val_dep_oprs()} {} void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr, OperatorNodeBase* repl_opr) { @@ -253,32 +249,30 @@ void UniqReaderCheck::update_on_opr_auto_replace(OperatorNodeBase* opr, /* ================ OptState ================ */ -OptState::OptState( - const GraphOptimizer *owner_optimizer, const SubGraph& graph): - m_owner_optimizer{owner_optimizer}, - m_var_replace_map{ - const_cast*>( - &GraphOptimizer::var_replace_map(*graph.comp_graph()))}, - m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()}, - m_graph{graph} -{ +OptState::OptState(const GraphOptimizer* owner_optimizer, const SubGraph& graph) + : m_owner_optimizer{owner_optimizer}, + m_var_replace_map{const_cast*>( + &GraphOptimizer::var_replace_map(*graph.comp_graph()))}, + m_comp_node_opt{graph.comp_graph()->seq_comp_node_optimizer()}, + m_graph{graph} { mgb_assert(!m_graph.m_owner_opt_state); m_var_replace_map->clear(); m_graph.m_owner_opt_state = this; m_oprs_inserted.clear(); - auto on_opr_insert = [this](const cg::event::OprInserted &ev) { + auto on_opr_insert = [this](const cg::event::OprInserted& ev) { auto need_src_opr = m_opr_property_flag & OprPropertyFlag::SOURCE_OPR, need_priority = m_opr_property_flag & OprPropertyFlag::PRIORITY; if (need_src_opr) - mgb_assert(m_cur_iter_src_opr, "opr %s{%s} created outside from " - "SubGraph::iter", - ev.opr->cname(), ev.opr->dyn_typeinfo()->name); + mgb_assert(m_cur_iter_src_opr, + "opr %s{%s} created outside from " + "SubGraph::iter", + ev.opr->cname(), ev.opr->dyn_typeinfo()->name); if (ev.exc || ev.is_dedup) return; - auto &&new_attr = ev.opr->node_prop().attribute(); - auto &&ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE}); + auto&& new_attr = ev.opr->node_prop().attribute(); + auto&& ins = m_oprs_inserted.insert({ev.opr, OprPropertyFlag::NONE}); mgb_assert(ins.second); if (need_src_opr && !new_attr.src_opr) { @@ -296,20 +290,22 @@ OptState::OptState( auto csp = m_cur_iter_opr_stream_prop_type; if (csp.prop_type != cg::SeqCompNodeOptimizer::StreamPropType::NONE) { - for (auto i: ev.opr->output()) + for (auto i : ev.opr->output()) m_comp_node_opt.register_stream_var(i, csp); } }; - m_on_opr_insert_handler = graph.comp_graph()->event().register_receiver< - cg::event::OprInserted>(on_opr_insert); + m_on_opr_insert_handler = + graph.comp_graph() + ->event() + .register_receiver(on_opr_insert); } -void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { +void OptState::on_var_replaced(VarNode* src, VarNode* dst, const char* msg) { if (src->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { // this can only happen in auto_replace_outputs() mgb_assert(dst->contain_flag(VarNode::Flag::VOLATILE_CONTENT) && - src->owner_opr()->dyn_typeinfo() == - dst->owner_opr()->dyn_typeinfo()); + src->owner_opr()->dyn_typeinfo() == + dst->owner_opr()->dyn_typeinfo()); mgb_assert(!msg); return; } @@ -362,7 +358,7 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { return f & (InferType::RT_STATIC | InferType::CONST); }; if (!(norm(it0.shape) == norm(it1.shape) && - norm(it0.value) <= norm(it1.value))) { + norm(it0.value) <= norm(it1.value))) { suc = false; fail_chks.push_back("infer-type"); } @@ -407,22 +403,21 @@ void OptState::on_var_replaced(VarNode *src, VarNode *dst, const char *msg) { #if MGB_ENABLE_LOGGING if (msg && m_owner_optimizer->verbosity()) { - m_log_msg. - append("\n "). - append(std::to_string(m_log_nr_item)). - append(": "). - append(src->owner_opr()->cname()). - append(" => "). - append(dst->owner_opr()->cname()). - append(" ("). - append(msg). - append(")"); - } - ++ m_log_nr_item; + m_log_msg.append("\n ") + .append(std::to_string(m_log_nr_item)) + .append(": ") + .append(src->owner_opr()->cname()) + .append(" => ") + .append(dst->owner_opr()->cname()) + .append(" (") + .append(msg) + .append(")"); + } + ++m_log_nr_item; #endif } -size_t OptState::flush_log(const char *title) { +size_t OptState::flush_log(const char* title) { if (m_owner_optimizer->verbosity() >= 2) { if (m_log_msg.empty()) { m_log_msg = mgb_cstr_log(" no var replacement logged"); @@ -435,42 +430,40 @@ size_t OptState::flush_log(const char *title) { return ret; } -void OptState::call_with_opr(OperatorNodeBase *opr, thin_function func, +void OptState::call_with_opr(OperatorNodeBase* opr, + thin_function func, OprPropertyFlag opr_property_flag) { auto src_opr = cg::get_opr_root_source_opr(opr); auto opr_priority = opr->node_prop().attribute().priority; auto stream_prop_type = m_comp_node_opt.stream_prop_type(opr->output(0)); ThinHashMap oprs_inserted; - auto swap_properties = [&, - need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR, - need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] { - if (need_src_opr) { - std::swap(m_cur_iter_src_opr, src_opr); - } - if (need_priority) { - std::swap(m_cur_iter_opr_priority, opr_priority); - } - std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type); - std::swap(m_opr_property_flag, opr_property_flag); - std::swap(m_oprs_inserted, oprs_inserted); - }; + auto swap_properties = + [&, need_src_opr = opr_property_flag & OprPropertyFlag::SOURCE_OPR, + need_priority = opr_property_flag & OprPropertyFlag::PRIORITY] { + if (need_src_opr) { + std::swap(m_cur_iter_src_opr, src_opr); + } + if (need_priority) { + std::swap(m_cur_iter_opr_priority, opr_priority); + } + std::swap(m_cur_iter_opr_stream_prop_type, stream_prop_type); + std::swap(m_opr_property_flag, opr_property_flag); + std::swap(m_oprs_inserted, oprs_inserted); + }; MGB_TRY { swap_properties(); func(); - } MGB_FINALLY({ - swap_properties(); - }); + } + MGB_FINALLY({ swap_properties(); }); } /* ================ RecursiveSubGraphRewriteHelper ================ */ -RecursiveSubGraphRewriteHelper:: -~RecursiveSubGraphRewriteHelper() noexcept = default; +RecursiveSubGraphRewriteHelper::~RecursiveSubGraphRewriteHelper() noexcept = + default; -RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState &state): - m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} -{ -} +RecursiveSubGraphRewriteHelper::RecursiveSubGraphRewriteHelper(OptState& state) + : m_opt_state{state}, m_rewriter{state.graph().make_rewriter()} {} void RecursiveSubGraphRewriteHelper::apply() { using namespace std::placeholders; @@ -479,8 +472,8 @@ void RecursiveSubGraphRewriteHelper::apply() { m_rewriter.apply_inplace(); } -void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { - auto on_new_opr = [this](OperatorNodeBase *opr) { +void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase* opr) { + auto on_new_opr = [this](OperatorNodeBase* opr) { auto repl_opr = m_rewriter.auto_replace_outputs(opr); return on_new_opr_check_should_process(opr, repl_opr); }; @@ -493,8 +486,8 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { return; mgb_assert(m_opr_stack.empty()); - m_opr_stack.push_back({ - orig_out, m_rewriter.get_var(orig_out)->owner_opr()}); + m_opr_stack.push_back( + {orig_out, m_rewriter.get_var(orig_out)->owner_opr()}); bool first = true; while (!m_opr_stack.empty()) { @@ -515,9 +508,9 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { if (should_process) { auto trans = process_opr(cur_out); if (trans.valid()) { - m_opr_stack.push_back({ - cur_frame.orig_var, trans->result->owner_opr()}); - for (auto i: reverse_adaptor(trans->internal)) { + m_opr_stack.push_back( + {cur_frame.orig_var, trans->result->owner_opr()}); + for (auto i : reverse_adaptor(trans->internal)) { if (i) m_opr_stack.push_back({i, i->owner_opr()}); } @@ -532,7 +525,7 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { auto src = cur_frame.orig_var; if (m_rewriter.get_var(src) != cur_out) { - const char *msg = nullptr; + const char* msg = nullptr; if (m_opr_stack.empty()) { msg = m_log_msg.c_str(); } @@ -550,11 +543,12 @@ void RecursiveSubGraphRewriteHelper::on_opr(OperatorNodeBase *opr) { GraphOptimizer::~GraphOptimizer() noexcept = default; -class GraphOptimizer::VarReplaceMapStorage :public UserDataContainer::UserData { +class GraphOptimizer::VarReplaceMapStorage + : public UserDataContainer::UserData { MGB_TYPEINFO_OBJ_DECL; - public: - ThinHashMap map; +public: + ThinHashMap map; }; MGB_TYPEINFO_OBJ_IMPL(GraphOptimizer::VarReplaceMapStorage); @@ -565,7 +559,7 @@ GraphOptimizer& GraphOptimizer::add_pass(std::unique_ptr pass) { return *this; } -SubGraph GraphOptimizer::apply(const SubGraph &graph) const { +SubGraph GraphOptimizer::apply(const SubGraph& graph) const { RealTimer timer; OptState state{this, graph}; @@ -574,38 +568,38 @@ SubGraph GraphOptimizer::apply(const SubGraph &graph) const { // first update output var shapes of all oprs state.graph().iter(cg::update_output_var_shapes); - auto &&opt = graph.comp_graph()->options(); + auto&& opt = graph.comp_graph()->options(); auto orig_setting = opt.graph_opt_level; - Pass *cur_pass = nullptr; + Pass* cur_pass = nullptr; MGB_MARK_USED_VAR(cur_pass); MGB_TRY { - for (auto &&i: m_passes) { + for (auto&& i : m_passes) { state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL); cur_pass = i.get(); opt.graph_opt_level = 1; i->apply(state); tot_nr_replace += state.flush_log( - mgb_ssprintf_log( - "apply optimization pass %s:", i->name()).c_str()); + mgb_ssprintf_log("apply optimization pass %s:", i->name()) + .c_str()); } - } MGB_CATCH(std::exception &exc, { + } + MGB_CATCH(std::exception & exc, { mgb_log_error("error while applying optimization pass %s: %s", - cur_pass->name(), exc.what()); + cur_pass->name(), exc.what()); opt.graph_opt_level = orig_setting; throw; }) - MGB_FINALLY( - opt.graph_opt_level = orig_setting - ); + MGB_FINALLY(opt.graph_opt_level = orig_setting); if (verbosity() >= 1) { - mgb_log_debug("graph optimization: applied %zu passes, " + mgb_log_debug( + "graph optimization: applied %zu passes, " "total %zu var(s) replaced; time=%.2fms", m_passes.size(), tot_nr_replace, timer.get_msecs()); } return state.graph(); } -const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const { +const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray& vars) const { if (m_passes.empty()) { // this check is necessary, since OptState would clear // var_replace_map() @@ -613,7 +607,7 @@ const GraphOptimizer& GraphOptimizer::apply_inplace(VarNodeArray &vars) const { } auto g = apply({{vars.begin(), vars.end()}}); - for (size_t i = 0; i < vars.size(); ++ i) { + for (size_t i = 0; i < vars.size(); ++i) { vars[i] = g.endpoint_vars()[i].node(); } return *this; @@ -653,7 +647,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( #if MGB_JIT bool need_jit = false; if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 || - comp_graph_opt->graph_opt.jit)) { + comp_graph_opt->graph_opt.jit)) { need_jit = true; } if (need_jit && after_grad) { @@ -679,7 +673,6 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( add_passes_for_optimize_options(*inference_opt); } - if (inference_opt) { // merge params to reduce loading time and graph overhead add_pass(); @@ -689,15 +682,16 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( } const ThinHashMap& GraphOptimizer::var_replace_map( - ComputingGraph &graph) { - auto storage = graph.options().user_data.get_user_data_or_create< - VarReplaceMapStorage>(); + ComputingGraph& graph) { + auto storage = + graph.options() + .user_data.get_user_data_or_create(); return storage->map; } -VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { - auto &&map = var_replace_map(*(var->owner_graph())); - for (; ; ) { +VarNode* GraphOptimizer::var_replace_lookup(VarNode* var) { + auto&& map = var_replace_map(*(var->owner_graph())); + for (;;) { auto iter = map.find(var); if (iter == map.end()) return var; @@ -705,7 +699,6 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { } } - const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( const cg::GraphCommonOptimizeOptions& options) { return add_passes_for_optimize_options( @@ -723,12 +716,14 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( options.disable_##_option(); \ } \ } - - cb(fuse_preprocess, {add_pass(FuseNCHW4Int8Preprocess::make());}); + + cb(fuse_preprocess, { + add_pass(FuseNCHW4Int8Preprocess::make()); + add_pass(); + }); cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); - cb(nchw4, { add_pass(); add_pass(); @@ -763,6 +758,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( add_pass(); add_pass(); add_pass(FuseNCHW4Int8Preprocess::make()); + add_pass(); }); cb(chwn4, { add_pass(); @@ -790,9 +786,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( /* ================ ConstVarPropogateBase ================ */ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( - OperatorNodeBase *opr) { + OperatorNodeBase* opr) { using ProfFlag = OperatorNodeBase::NodeProp::Flag; - auto &&info = m_oprinfo[opr]; + auto&& info = m_oprinfo[opr]; if (info.processed) return info.result; info.processed = true; @@ -819,15 +815,14 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( if (opr->input().empty()) return make_ret(); - if (opr->node_prop().contain( - ProfFlag::FORCE_UPDATE_INPUT_VAR | - ProfFlag::IMPURE_FUNC)) { + if (opr->node_prop().contain(ProfFlag::FORCE_UPDATE_INPUT_VAR | + ProfFlag::IMPURE_FUNC)) { return make_ret(); } size_t max_input_size = 0; ret.all_const_inp = true; - for (auto i: opr->input()) { + for (auto i : opr->input()) { auto io = i->owner_opr(); auto iter = m_oprinfo.find(io); if (iter == m_oprinfo.end()) { @@ -835,7 +830,7 @@ ConstVarPropogate::AddOprResult ConstVarPropogate::add_opr( iter = m_oprinfo.find(io); mgb_assert(iter != m_oprinfo.end()); } - auto &&src = iter->second; + auto&& src = iter->second; if (src.is_const) { update_max(max_input_size, src.max_size); ret.has_const_inp = true; diff --git a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp index 1923b6a1..28c59f92 100644 --- a/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp +++ b/src/gopt/impl/fuse_nchw4_int8_preprocess.cpp @@ -19,6 +19,7 @@ #include "megbrain/opr/utility.h" #include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/serialization/serializer.h" +#include "megbrain/opr/imgproc.h" using namespace mgb; using namespace gopt; @@ -443,4 +444,244 @@ void FuseNCHW4Int8Preprocess::apply(OptState& state) const { }; state.graph().iter(on_opr); rewriter.apply_inplace(); +} + +/* ==================== FuseWarpPerspectiveDimshufflePass ================= */ +const char* FuseWarpPerspectiveDimshufflePass::name() const { + return mgb_cstr_log("Fuse warp perspective dimshuffle pass"); +} + +void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { + auto rewriter = opt.graph().make_rewriter(); + auto uniq_reader_check = UniqReaderCheck{opt.graph()}; + + auto make_new_warp = [&rewriter](opr::WarpPerspective* warp, + opr::WarpPerspective::Param new_param, + megdnn::DType dst_dtype, + SymbolVar& new_warp) { + OperatorNodeConfig new_config(dst_dtype); + if (warp->input().size() == 3) { + auto src = rewriter.get_var(warp->input(0)), + mat = rewriter.get_var(warp->input(1)), + out_shape = rewriter.get_var(warp->input(2)); + new_warp = opr::WarpPerspective::make(src, mat, out_shape, + new_param, new_config); + } else { + mgb_assert(warp->input().size() == 4); + auto src = rewriter.get_var(warp->input(0)), + mat = rewriter.get_var(warp->input(1)), + mat_idx = rewriter.get_var(warp->input(2)), + out_shape = rewriter.get_var(warp->input(3)); + new_warp = opr::WarpPerspective::make(src, mat, mat_idx, out_shape, + new_param, new_config); + } + }; + + auto is_warp_nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr, + OperatorNodeBase*& top_opr) { + // check warp + auto warp = try_cast_as_op(bottom_opr); + if (warp == nullptr) + return false; + auto inp_dtype = warp->input(0)->dtype(); + bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm || + inp_dtype.enumv() == DTypeEnum::Uint8; + + bool is_nchw = warp->param().format == + megdnn::param::WarpPerspective::Format::NCHW; + if (!(is_u8_or_qu8 && is_nchw)) + return false; + if (!uniq_reader_check(warp->input(0))) + return false; + + top_opr = warp; + return true; + }; + + auto is_warp_nhwc2nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr, + OperatorNodeBase*& top_opr) { + // check shuffle + auto shuffle = try_cast_as_op(bottom_opr); + if (shuffle == nullptr) + return false; + auto&& shuffle_param = shuffle->param(); + if (shuffle_param.pattern_len != 4) + return false; + bool is_nhwc2nchw = shuffle_param.pattern[0] == 0 && + shuffle_param.pattern[1] == 3 && + shuffle_param.pattern[2] == 1 && + shuffle_param.pattern[3] == 2; + if (!is_nhwc2nchw) + return false; + if (!uniq_reader_check(shuffle->input(0))) + return false; + + // check warp + auto warp = try_cast_as_op( + shuffle->input(0)->owner_opr()); + if (warp == nullptr) + return false; + auto inp_dtype = warp->input(0)->dtype(); + bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm || + inp_dtype.enumv() == DTypeEnum::Uint8; + bool is_nhwc = warp->param().format == + megdnn::param::WarpPerspective::Format::NHWC; + if (!(is_u8_or_qu8 && is_nhwc)) + return false; + + top_opr = warp; + return true; + }; + + auto try_warp_nchw_typecvt = [&rewriter, &uniq_reader_check, &is_warp_nchw, + &make_new_warp](OperatorNodeBase* opr) { + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + bool is_to_f32 = + typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32; + if (!is_to_f32) + return false; + if (!uniq_reader_check(typecvt->input(0))) + return false; + + OperatorNodeBase* top_opr = nullptr; + if (!is_warp_nchw(typecvt->input(0)->owner_opr(), top_opr)) + return false; + auto warp = try_cast_as_op(top_opr); + SymbolVar new_warp; + make_new_warp(warp, warp->param(), opr->output()[0]->dtype(), new_warp); + + rewriter.replace_var(opr->output(0), new_warp.node(), + mgb_cstr_log("replace warp + typecvt" + "fuse warp_dimshuffle(NCHW)")); + + return true; + }; + + auto try_warp_nhwc2nchw_typecvt = [&rewriter, &uniq_reader_check, + &is_warp_nhwc2nchw, + &make_new_warp](OperatorNodeBase* opr) { + // check typecvt + auto typecvt = try_cast_as_op(opr); + if (typecvt == nullptr) + return false; + bool is_to_f32 = + typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32; + if (!is_to_f32) + return false; + if (!uniq_reader_check(typecvt->input(0))) + return false; + + OperatorNodeBase* top_opr = nullptr; + if (!is_warp_nhwc2nchw(typecvt->input(0)->owner_opr(), top_opr)) + return false; + auto warp = try_cast_as_op(top_opr); + opr::WarpPerspective::Param new_param = warp->param(); + new_param.format = megdnn::param::WarpPerspective::Format::NHWC_NCHW; + SymbolVar new_warp; + make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); + + rewriter.replace_var( + opr->output(0), new_warp.node(), + mgb_cstr_log("replace conv_bias + dimshuffle + " + "typecvt to warp_dimshuffle(NHWC_NCHW)")); + + return true; + }; + + auto try_warp_nhwc2nchw4_typecvt = [&rewriter, &uniq_reader_check, + &is_warp_nhwc2nchw, + &make_new_warp](OperatorNodeBase* opr) { + // check relayout + auto relayout = try_cast_as_op(opr); + if (relayout == nullptr) + return false; + bool is_to_q8 = + relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; + bool is_to_nchw2nchw4 = relayout->param().mode == + opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; + if (!(is_to_q8 && is_to_nchw2nchw4)) + return false; + if (!uniq_reader_check(relayout->input(0))) + return false; + + OperatorNodeBase* top_opr = nullptr; + if (!is_warp_nhwc2nchw(relayout->input(0)->owner_opr(), top_opr)) + return false; + + auto warp = try_cast_as_op(top_opr); + + bool is_small_chn = warp->input(0)->shape()[3] < 4; + if (!is_small_chn) + return false; + + opr::WarpPerspective::Param new_param = warp->param(); + new_param.format = + megdnn::param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL; + + SymbolVar new_warp; + make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); + + rewriter.replace_var( + opr->output(0), new_warp.node(), + mgb_cstr_log("replace warp + dimshuffle + relayout(NCHW_NCHW4)" + "to warp_dimshuffle(NHWC_NCHW4_IC_SMALL)")); + + return true; + }; + + auto try_warp_nchw2nchw4_typecvt = [&rewriter, &uniq_reader_check, + &is_warp_nchw, + &make_new_warp](OperatorNodeBase* opr) { + // check relayout + auto relayout = try_cast_as_op(opr); + if (relayout == nullptr) + return false; + bool is_to_q8 = + relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; + bool is_to_nchw2nchw4 = relayout->param().mode == + opr::RelayoutFormat::Param::Mode::NCHW_NCHW4; + if (!(is_to_q8 && is_to_nchw2nchw4)) + return false; + if (!uniq_reader_check(relayout->input(0))) + return false; + + OperatorNodeBase* top_opr = nullptr; + if (!is_warp_nchw(relayout->input(0)->owner_opr(), top_opr)) + return false; + + auto warp = try_cast_as_op(top_opr); + + bool is_small_chn = warp->input(0)->shape()[1] < 4; + if (!is_small_chn) + return false; + + opr::WarpPerspective::Param new_param = warp->param(); + new_param.format = + megdnn::param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL; + + SymbolVar new_warp; + make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp); + + rewriter.replace_var( + opr->output(0), new_warp.node(), + mgb_cstr_log("replace warp + relayout(NCHW_NCHW4)" + "to warp_dimshuffle(NCHW_NCHW4_IC_SMALL)")); + + return true; + }; + + auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt, + &try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt, + &rewriter](OperatorNodeBase* opr) { + if (!try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr) && + !try_warp_nhwc2nchw4_typecvt(opr) && + !try_warp_nchw2nchw4_typecvt(opr)) { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); } \ No newline at end of file diff --git a/src/gopt/include/megbrain/gopt/inference.h b/src/gopt/include/megbrain/gopt/inference.h index b309bcb8..b2dd957d 100644 --- a/src/gopt/include/megbrain/gopt/inference.h +++ b/src/gopt/include/megbrain/gopt/inference.h @@ -173,6 +173,16 @@ namespace gopt { }; /*! + * \brief fuse warp perspective and dimshuffle, quint8/uint8 to qint8/float + */ + class FuseWarpPerspectiveDimshufflePass : public Pass { + public: + const char* name() const override; + void apply(OptState& opt) const override; + }; + + + /*! * \brief fuse deconv and typecvt to a deconv opr */ class FuseDeconvCvtPass : public Pass { diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index bf35086d..85875c82 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -1172,7 +1172,8 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) { param.pad_h = param.pad_w = 1; auto w2 = mkcvar("w2", {4, 4, 3, 3}), y = opr::Convolution::make(elem, w2, param), - z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); + z = opr::AxisAddRemove::make( + y, {opr::AxisAddRemove::AxisDesc::make_add(0)}); SymbolVar y_opt, z_opt; auto options = gopt::OptimizeForInferenceOptions{}; @@ -3722,5 +3723,65 @@ TEST(TestGoptInference, PreProcessCase1) { ASSERT_TRUE(y_opt.node()->owner_opr()->same_type()); } + +TEST(TestGoptInference, WarpAndPreProcessCase) { + REQUIRE_GPU(1); + HostTensorGenerator gen(0, 255); + auto cn = CompNode::load("gpu0"); + auto graph = ComputingGraph::make(); + graph->options().graph_opt_level = 0; + + size_t n = 1; + size_t c = 3; + size_t h = 16; + size_t w = 16; + auto host_x1 = gen({n, h, w, c}, cn); + auto x = opr::Host2DeviceCopy::make(*graph, host_x1); + + auto mat_host = std::make_shared(cn, TensorShape{n, 3, 3}, + dtype::Float32()); + warp_perspective_mat_gen(*mat_host, n, h, w); + auto mat = opr::Host2DeviceCopy::make(*graph, mat_host).rename("mat"); + + opr::WarpPerspective::Param warp_param; + warp_param.format = opr::WarpPerspective::Param::Format::NHWC; + auto x_warp = + opr::WarpPerspective::make(x, mat, TensorShape{h, w}, warp_param); + auto x_nchw = opr::Dimshuffle::make(x_warp, {0, 3, 1, 2}, 4, cn); + + auto x_u8 = opr::TypeCvt::make(x_nchw, dtype::Float32(), cn); + auto x_s8 = x_u8 - 128; + auto zero = DTypeScalar(dtype::Float32()); + auto zero_tensor = opr::ImmutableTensor::make(*graph, zero, cn); + auto pad_channel_tensor = + opr::Broadcast::make(zero_tensor, {n, 1, h, w}, cn); + auto paded_x = opr::Concat::make({x_s8, pad_channel_tensor}, 1, cn) + .reshape({n, 1, 4, h, w}); + + auto nchw4_out = opr::Dimshuffle::make(paded_x, {0, 1, 3, 4, 2}, 5, cn); + auto result = opr::TypeCvt::make(nchw4_out, dtype::QuantizedS8(1.f)); + + auto y = result; + SymbolVar y_opt; + auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_fuse_preprocess(); + unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); + + ASSERT_TRUE(y_opt.node()->owner_opr()->same_type()); + + ASSERT_EQ(opr::WarpPerspective::Param::Format::NHWC_NCHW4_IC_SMALL, + find_opr(y_opt).param().format); + + graph->compile({{y_opt, {}}}) + ->to_json() + ->writeto_fpath(output_file( + "TestGoptInference.WarpAndPreProcessCase.json")); + + HostTensorND host_y_opt, host_y; + auto func = graph->compile({make_callback_copy(y, host_y), + make_callback_copy(y_opt, host_y_opt)}); + func->execute(); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5); +} #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/imgproc.cpp b/src/opr/impl/imgproc.cpp index 94887296..eb40c40b 100644 --- a/src/opr/impl/imgproc.cpp +++ b/src/opr/impl/imgproc.cpp @@ -47,7 +47,11 @@ SymbolVar WarpPerspectiveForward::make(SymbolVar i0, SymbolVar i1, SymbolVar i2, } void WarpPerspectiveForward::init_output_dtype() { - output(0)->dtype(input(0)->dtype()); + if (config().output_dtype().valid()) { + output(0)->dtype(config().output_dtype()); + } else { + output(0)->dtype(input(0)->dtype()); + } } void WarpPerspectiveForward::add_input_layout_constraint() { @@ -78,23 +82,40 @@ void WarpPerspectiveForward::outshape_by_symvar_do_get_output_shape( mat_idx_shp.to_string().c_str()); } - //! The index of height, e.g.,[b, h, w, c], the height_idx = 1 - size_t height_idx = 0; - if (param().format == Param::Format::NCHW || - param().format == Param::Format::NCHW4) { - height_idx = 2; - } else { - height_idx = 1; - } - - dest = imgshp; - dest[0] = matshp[0]; - if (param().format == Param::Format::NHWCD4) { - dest.shape[height_idx] = oshp2d.shape[0]; - dest.shape[height_idx + 2] = oshp2d.shape[1]; - } else { - for (int i = 0; i < 2; ++i) - dest.shape[height_idx + i] = oshp2d.shape[i]; + switch (param().format) { + case Param::Format::NCHW_NCHW4_IC_SMALL: + case Param::Format::NHWC_NCHW4_IC_SMALL: + dest.ndim = 5; + dest[0] = matshp[0]; + dest.shape[1] = 1; + dest.shape[2] = oshp2d.shape[0]; + dest.shape[3] = oshp2d.shape[1]; + dest.shape[4] = 4; + break; + case Param::Format::NHWC_NCHW: + dest[0] = matshp[0]; + dest.shape[1] = imgshp.shape[3]; + dest.shape[2] = oshp2d.shape[0]; + dest.shape[3] = oshp2d.shape[1]; + break; + default: + size_t height_idx = 0; + if (param().format == Param::Format::NCHW || + param().format == Param::Format::NCHW4) { + height_idx = 2; + } else { + height_idx = 1; + } + dest = imgshp; + dest[0] = matshp[0]; + if (param().format == Param::Format::NHWCD4) { + dest.shape[height_idx] = oshp2d.shape[0]; + dest.shape[height_idx + 2] = oshp2d.shape[1]; + } else { + for (int i = 0; i < 2; ++i) + dest.shape[height_idx + i] = oshp2d.shape[i]; + } + break; } }