|
- /**
- * \file dnn/src/cuda/warp_perspective/forward.cpp
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- * implied.
- */
- #include "src/cuda/warp_perspective/opr_impl.h"
- #include "src/cuda/warp_perspective/warp_perspective_cv.cuh"
-
- #include "src/cuda/utils.h"
- #include "src/cuda/warp_perspective/common.h"
- #include "src/cuda/warp_perspective/helper.h"
-
- #include "src/common/cv/common.h"
- #include "src/common/warp_common.h"
-
- namespace megdnn {
- namespace cuda {
-
- namespace {
- /**
-  * Configure \p relayout for \p mode and deduce the reformatted layout.
-  *
-  * If \p src_layout.ndim is not greater than zero the operator is left
-  * untouched and \p dst_layout becomes a plain copy of \p src_layout.
-  *
-  * \param relayout RelayoutFormat operator whose param() is overwritten
-  * \param src_layout layout before reformatting
-  * \param dst_layout [out] receives the deduced (or copied) layout
-  * \param mode relayout mode stored into the operator param
-  * \param oc value stored into the param's `oc` field
-  * \param group value stored into the param's `group` field
-  */
- inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
-                                    const TensorLayout& src_layout,
-                                    TensorLayout& dst_layout,
-                                    RelayoutFormat::Param::Mode mode,
-                                    const int oc = 0, const int group = 1) {
-     if (!(src_layout.ndim > 0)) {
-         // nothing to reformat: hand the layout through unchanged
-         dst_layout = src_layout;
-         return;
-     }
-
-     // start from a default-constructed Param so that every field which is
-     // not set explicitly here is reset on the operator as well
-     RelayoutFormat::Param reformat_param;
-     reformat_param.group = group;
-     reformat_param.oc = oc;
-     reformat_param.mode = mode;
-
-     relayout->param() = reformat_param;
-     relayout->deduce_layout(src_layout, dst_layout);
- }
-
- /**
-  * Deduce the layouts used internally for a QuantizedS4 + NCHW request.
-  *
-  * Both \p src and \p dst are run through an NCHW -> NCHW64 RelayoutFormat
-  * layout deduction. Any other dtype/format combination triggers
-  * megdnn_assert with "not support".
-  *
-  * \param src caller-visible source layout
-  * \param dst caller-visible destination layout
-  * \param inner_src [out] deduced NCHW64 layout for \p src
-  * \param inner_dst [out] deduced NCHW64 layout for \p dst
-  * \param handle handle used to create the RelayoutFormat operator
-  * \param format warp perspective format requested by the caller
-  */
- void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
-                       TensorLayout& inner_src, TensorLayout& inner_dst,
-                       Handle* handle,
-                       WarpPerspectiveForwardImpl::Param::Format format) {
-     const bool qint4_nchw =
-             src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
-             dst.dtype.enumv() == DTypeEnum::QuantizedS4 &&
-             format == param::WarpPerspective::Format::NCHW;
-     if (!qint4_nchw) {
-         megdnn_assert(0, "not support");
-         return;
-     }
-
-     const auto to_nchw64 = RelayoutFormat::Param::Mode::NCHW_NCHW64;
-     auto reformatter = handle->create_operator<RelayoutFormat>();
-     // oc = 0 and group = 1 are passed explicitly, matching the defaults
-     deduce_reformat_layout(reformatter, src, inner_src, to_nchw64, 0, 1);
-     deduce_reformat_layout(reformatter, dst, inner_dst, to_nchw64, 0, 1);
- }
-
- } // namespace
-
- namespace warp_perspective {
-
- /**
-  * Per-image warp perspective through the megcv proxy routines.
-  *
-  * Walks over the batch dimension (src.layout.shape[0]). For every image it
-  * builds Mat views of \p src and \p dst and calls warp_perspective_cv_proxy,
-  * selected by the dst dtype (Float32 or Uint8) and by the channel count of
-  * the source Mat (1, otherwise 3).
-  *
-  * \param src input tensor; layout[3] must be 1 or 3 and the dtype must be
-  *      Float32 or Uint8
-  * \param mat float32 transform data; 3 * 3 values are consumed per image
-  * \param dst output tensor; its dtype selects the element type of the call
-  * \param border_val border fill value (cast to uchar on the Uint8 path)
-  * \param bmode border mode forwarded to the proxy
-  * \param imode interpolation mode; INTER_AREA is replaced by INTER_LINEAR
-  * \param workspace scratch memory read as double; 3 * 3 doubles are
-  *      consumed per image
-  * \param stream CUDA stream forwarded to the proxy
-  *
-  * Throws through megdnn_throw when the dst dtype is neither Float32 nor
-  * Uint8.
-  */
- void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat,
-                               _megdnn_tensor_in dst, float border_val,
-                               BorderMode bmode, InterpolationMode imode,
-                               _megdnn_workspace workspace,
-                               cudaStream_t stream) {
-     megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3,
-                   "unsupported src channel");
-     // The previous check used `!=` on both sides of `||`, which holds for
-     // every dtype and therefore never fired; accept exactly the two
-     // supported dtypes instead.
-     megdnn_assert(src.layout.dtype == dtype::Float32() ||
-                           src.layout.dtype == dtype::Uint8(),
-                   "unsupported src dtype");
-     if (imode == InterpolationMode::INTER_AREA) {
-         imode = InterpolationMode::INTER_LINEAR;
-     }
-     using namespace megcv;
-     const float* trans_ptr = mat.ptr<dt_float32>();
-     double* workspace_ptr = workspace.ptr<double>();
-     for (size_t i = 0; i < src.layout.shape[0]; ++i) {
-         if (dst.layout.dtype == dtype::Float32()) {
-             Mat<float> src_mat = TensorND2Mat<float>(src, i);
-             Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
-             if (src_mat.channels() == 1) {
-                 warp_perspective_cv_proxy<float, 1>(
-                         src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
-                         src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
-                         src_mat.step(), dst_mat.step(), bmode, imode,
-                         trans_ptr, border_val, workspace_ptr, stream);
-             } else {
-                 warp_perspective_cv_proxy<float, 3>(
-                         src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
-                         src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
-                         src_mat.step(), dst_mat.step(), bmode, imode,
-                         trans_ptr, border_val, workspace_ptr, stream);
-             }
-         } else if (dst.layout.dtype == dtype::Uint8()) {
-             Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
-             Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
-             if (src_mat.channels() == 1) {
-                 warp_perspective_cv_proxy<uchar, 1>(
-                         src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
-                         src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
-                         src_mat.step(), dst_mat.step(), bmode, imode,
-                         trans_ptr, static_cast<uchar>(border_val),
-                         workspace_ptr, stream);
-             } else {
-                 warp_perspective_cv_proxy<uchar, 3>(
-                         src_mat.ptr(), dst_mat.ptr(), src_mat.rows(),
-                         src_mat.cols(), dst_mat.rows(), dst_mat.cols(),
-                         src_mat.step(), dst_mat.step(), bmode, imode,
-                         trans_ptr, static_cast<uchar>(border_val),
-                         workspace_ptr, stream);
-             }
-
-         } else {
-             megdnn_throw("Unsupported datatype of WarpPerspective optr.");
-         }
-
-         // step to the next image's 3x3 transform and its workspace slot
-         trans_ptr += 3 * 3;
-         workspace_ptr += 3 * 3;
-     }
- }
-
- } // namespace warp_perspective
-
- /**
-  * Describe how the workspace starting at \p ptr is split into chunks.
-  *
-  * Chunks are appended in this order:
-  *  - QuantizedS4 src with NCHW format: one chunk for the inner src layout
-  *    and one for the inner dst layout, as deduced by get_inner_layout;
-  *  - otherwise: one chunk for each of src, mat, dst (in that order) whose
-  *    dtype is BFloat16, sized as if the layout had dtype Float32;
-  *  - NHWC format: one extra chunk of mat.total_nr_elems() doubles.
-  *
-  * \param ptr base pointer stored in the returned bundle
-  * \param src source layout
-  * \param mat transform matrix layout
-  * \param mat_idx unused
-  * \param dst destination layout
-  * \return bundle built from \p ptr and the collected chunk sizes
-  */
- WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle(
-         void* ptr, const TensorLayout& src, const TensorLayout& mat,
-         const TensorLayout& mat_idx, const TensorLayout& dst) const {
-     MEGDNN_MARK_USED_VAR(mat_idx);
-
-     const auto format = param().format;
-     SmallVector<size_t> chunk_bytes;
-
-     const bool qint4_nchw = src.dtype.enumv() == DTypeEnum::QuantizedS4 &&
-                             format == param::WarpPerspective::Format::NCHW;
-     if (qint4_nchw) {
-         TensorLayout packed_src = src;
-         TensorLayout packed_dst = dst;
-         get_inner_layout(src, dst, packed_src, packed_dst, handle(), format);
-         chunk_bytes.push_back(packed_src.span().dist_byte());
-         chunk_bytes.push_back(packed_dst.span().dist_byte());
-     } else {
-         // visit src, mat, dst in this exact order: chunk order matters
-         const TensorLayout* candidates[] = {&src, &mat, &dst};
-         for (const TensorLayout* origin : candidates) {
-             if (origin->dtype == dtype::BFloat16()) {
-                 TensorLayout promoted = *origin;
-                 promoted.dtype = dtype::Float32();
-                 chunk_bytes.push_back(promoted.span().dist_byte());
-             }
-         }
-     }
-
-     if (format == param::WarpPerspective::Format::NHWC) {
-         //! the workspace is kept in double because float may cause
-         //! accuracy problems
-         chunk_bytes.push_back(mat.total_nr_elems() * sizeof(double));
-     }
-
-     return {ptr, std::move(chunk_bytes)};
- }
-
- void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc,  // Forward warp perspective; ssrc: caller's source tensor
- _megdnn_tensor_in smat,  // caller's transform matrices (read as float32 below)
- _megdnn_tensor_in smat_idx,  // optional index tensor; a null raw_ptr means "not given"
- _megdnn_tensor_out sdst,  // caller's destination tensor
- _megdnn_workspace sworkspace) {  // scratch memory, split up via get_workspace_bundle
- check_exec_allow_nhwc_mat_idx(ssrc.layout, smat.layout, smat_idx.layout,  // defined elsewhere; presumably validates layouts and workspace size
- sdst.layout, sworkspace.size);
-
- TensorND src = ssrc;  // working copies: these may be redirected to workspace buffers below
- TensorND mat = smat;
- TensorND mat_idx = smat_idx;
- TensorND dst = sdst;
- Param::Format inner_format = param().format;  // format the dispatch below works on; set to NCHW64 on the QuantizedS4 + NCHW path
- auto bundle =
- get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout,
- smat_idx.layout, sdst.layout);
- auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(  // BFloat16 <-> Float32 converter working on the bundle
- concrete_handle(this->handle()), &bundle);
- if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {  // BFloat16 input: src/mat/dst are routed through the converter
- ctypecvt.src_to_comp_type(ssrc, src)
- .src_to_comp_type(smat, mat)
- .src_to_comp_type(sdst, dst);
- } else if (ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 &&  // QuantizedS4 + NCHW: compute in NCHW64 inside the workspace
- param().format == Param::Format::NCHW) {
- auto handle_ptr = handle();
- get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
- handle_ptr, param().format);
- src.raw_ptr = bundle.get(0);  // bundle slot 0 holds the relayouted source
- dst.raw_ptr = bundle.get(1);  // bundle slot 1 holds the NCHW64 result
- auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
- RelayoutFormat::Param trans_param;
- trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
- relayout_opr->param() = trans_param;
- relayout_opr->exec(ssrc, src, {});  // ssrc (NCHW) -> src (NCHW64)
- inner_format = Param::Format::NCHW64;
- }
-
- {
- auto stream = cuda_stream(this->handle());  // all calls below are issued on this stream
- bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC;
-
- if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) {
- // use opencv impl only for nhwc and non-linear interp
- megdnn_assert(!mat_idx.raw_ptr,  // this path cannot take a mat_idx
- "mat_idx is not supported in NHWC case with "
- "non-linear interpolation");
- warp_perspective::warp_perspective_cv_exec(
- src, mat, dst, param().border_val,
- warp_perspective::get_bmode(param().bmode),
- warp_perspective::get_imode(param().imode),
- ctypecvt.workspace(), stream);  // workspace handed over is whatever ctypecvt exposes
-
- } else {
- megdnn_assert(warp::is_dnn_available(src.layout, mat.layout,
- dst.layout, param().imode,
- inner_format));
- size_t C, IH, IW, OH, OW;  // read from format-specific shape axes in the chain below
- if (is_nhwc) {  // NHWC: C from src axis 3, H/W from axes 1/2 of src and dst
- C = src.layout.shape[3];
- IH = src.layout.shape[1];
- IW = src.layout.shape[2];
- OH = dst.layout.shape[1];
- OW = dst.layout.shape[2];
- } else if (inner_format == Param::Format::NCHW4) {
- C = src.layout.shape[1] * 4;  // src axis 1 times 4 is used as the channel count
- IH = src.layout.shape[2];
- IW = src.layout.shape[3];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- } else if (inner_format == Param::Format::NHWC_NCHW) {  // src indexed as NHWC, dst H/W from axes 2/3
- C = src.layout.shape[3];
- IH = src.layout.shape[1];
- IW = src.layout.shape[2];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- } else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) {
- C = src.layout.shape[3];
- IH = src.layout.shape[1];
- IW = src.layout.shape[2];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- megdnn_assert(
- (C == 1) || (C == 3),
- "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3");
- } else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) {
- C = src.layout.shape[1];
- IH = src.layout.shape[2];
- IW = src.layout.shape[3];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- megdnn_assert(
- (C == 1) || (C == 3),
- "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3");
- } else if (inner_format == Param::Format::NCHW64) {
- C = src.layout.shape[1] * 64;  // src axis 1 times 64 is used as the channel count
- IH = src.layout.shape[2];
- IW = src.layout.shape[3];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- } else {
- megdnn_assert(  // anything left over must be plain NCHW
- inner_format == param::WarpPerspective::Format::NCHW,
- "invalid warp_perspective format");
- C = src.layout.shape[1];
- IH = src.layout.shape[2];
- IW = src.layout.shape[3];
- OH = dst.layout.shape[2];
- OW = dst.layout.shape[3];
- }
- megdnn_assert(param().imode == Param::InterpolationMode::LINEAR,  // only LINEAR gets past this point
- "unsupported interpolation mode for NCHW format");
- auto bval = param().border_val;
- auto bmode = warp_perspective::get_bmode(param().bmode);
- if (src.layout.dtype == dst.layout.dtype) {  // same src/dst dtype: dispatch on that dtype
- if (src.layout.dtype == dtype::Float32{}) {
- warp_perspective::forward_proxy(
- is_nhwc, src.ptr<dt_float32>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,  // nullptr when no mat_idx was supplied
- dst.ptr<dt_float32>(), src.layout[0], mat.layout[0],
- C, IH, IW, OH, OW, bval, bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- } else if (DNN_FLOAT16_SELECT(
- src.layout.dtype == dtype::Float16(),
- false)) {
- #ifndef MEGDNN_DISABLE_FLOAT16
- warp_perspective::forward_proxy(
- is_nhwc, src.ptr<dt_float16>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
- dst.ptr<dt_float16>(), src.layout[0], mat.layout[0],
- C, IH, IW, OH, OW, static_cast<dt_float16>(bval),
- bmode, async_error_info(handle()), m_error_tracker,
- stream);
- #endif
- } else if (src.layout.dtype == dtype::Uint8()) {
- warp_perspective::forward_proxy<dt_uint8>(
- is_nhwc, src.ptr<dt_uint8>(), mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
- dst.ptr<dt_uint8>(), src.layout[0], mat.layout[0],
- C, IH, IW, OH, OW, bval, bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- } else if (src.layout.dtype == dtype::Int8()) {
- megdnn_assert(!is_nhwc,
- "WarpPerspective on CUDA does not support "
- "NHWC + Int8");
- warp_perspective::forward_proxy<dt_int8>(
- false, src.ptr<dt_int8>(), mat.ptr<dt_float32>(),  // is_nhwc is known to be false here
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
- dst.ptr<dt_int8>(), src.layout[0], mat.layout[0], C,
- IH, IW, OH, OW,
- bval /* implicit float -> int8 conversion,
- should be safe */
- ,
- bmode, async_error_info(handle()), m_error_tracker,
- stream);
- } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
- megdnn_assert(param().format == Param::Format::NCHW4,
- "WarpPerspective on CUDA supports NCHW4 + "
- "QuantizedS8 only");
- warp_perspective::forward_proxy_nchw4<dt_int8>(
- src.compatible_ptr<dt_int8>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
- dst.compatible_ptr<dt_int8>(), src.layout[0],
- mat.layout[0], C, IH, IW, OH, OW, bval, bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) {
- megdnn_assert(
- param().format == Param::Format::NCHW64 ||
- param().format == Param::Format::NCHW,
- "WarpPerspective on CUDA supports NCHW64 or NCHW+ "
- "QuantizedS4 only");
- bval = roundf(bval);  // round, then clamp to [-8, 7] before the cast to dt_qint4
- bval = fmin(fmax(-8.f, bval), 7.f);
- warp_perspective::forward_proxy_nchw64<dt_qint4>(
- src.compatible_ptr<dt_qint4>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>() : nullptr,
- dst.compatible_ptr<dt_qint4>(), src.layout[0],
- mat.layout[0], C, IH, IW, OH, OW,
- static_cast<dt_qint4>(bval), bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- if (param().format == Param::Format::NCHW) {  // caller asked for NCHW: relayout the workspace result into sdst
- auto relayout_opr =
- handle()->create_operator<RelayoutFormat>();
- RelayoutFormat::Param trans_param;
- trans_param.mode =
- RelayoutFormat::Param::Mode::NCHW64_NCHW;
- relayout_opr->param() = trans_param;
- relayout_opr->exec(dst, sdst, {});  // dst (NCHW64) -> sdst (NCHW)
- }
- }
- } else if ((src.layout.dtype.enumv() ==  // dtypes differ: only Quantized8Asymm or Uint8 sources are handled
- DTypeEnum::Quantized8Asymm ||
- src.layout.dtype.enumv() == DTypeEnum::Uint8)) {
- uint8_t zero_point = 0;  // these defaults stay in effect for Uint8 src with a non-QuantizedS8 dst
- float scale = 1.f;
- if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) {  // take zero_point/scale from the src dtype param
- zero_point =
- src.layout.dtype.param<dtype::Quantized8Asymm>()
- .zero_point;
- scale = src.layout.dtype.param<dtype::Quantized8Asymm>()
- .scale;
- } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 &&
- dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) {
- zero_point = 128;  // Uint8 -> QuantizedS8 uses zero_point 128 and scale 1
- scale = 1.f;
- }
- DTypeParamImpl<dt_quint8> src_dtype_param(scale, zero_point);
-
- if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 &&  // QuantizedS8 dst with equal scale and an *_NCHW4_IC_SMALL format
- dst.layout.dtype.param<dtype::QuantizedS8>().scale ==
- scale) &&
- ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) ||
- (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) {
- bool is_nhwc_ic_small =
- (param().format ==
- Param::Format::NHWC_NCHW4_IC_SMALL);
- warp_perspective::
- forward_proxy_quint8_dimshuffle_typecvt_nchw4<
- dt_quint8, dt_uint8, dt_int8>(
- is_nhwc_ic_small,
- src.compatible_ptr<dt_uint8>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>()
- : nullptr,
- dst.compatible_ptr<dt_int8>(),
- src.layout[0], mat.layout[0], C, IH, IW, OH,
- OW, bval, src_dtype_param, bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- } else {
- megdnn_assert(  // otherwise dst must be Float32 with NCHW or NHWC_NCHW format
- ((dst.layout.dtype.enumv() == DTypeEnum::Float32) &&
- ((param().format == Param::Format::NCHW) ||
- (param().format == Param::Format::NHWC_NCHW))),
- "invalid format for Quantized8Asymm input");
- bool is_nhwc = (param().format == Param::Format::NHWC_NCHW);  // NOTE: shadows the outer is_nhwc
- warp_perspective::
- forward_proxy_quint8_dimshuffle_typecvt_nchw<
- dt_quint8, dt_uint8, dt_float32>(
- is_nhwc, src.compatible_ptr<dt_uint8>(),
- mat.ptr<dt_float32>(),
- mat_idx.raw_ptr ? mat_idx.ptr<int>()
- : nullptr,
- dst.compatible_ptr<dt_float32>(),
- src.layout[0], mat.layout[0], C, IH, IW, OH,
- OW, bval, src_dtype_param, bmode,
- async_error_info(handle()), m_error_tracker,
- stream);
- }
- } else {
- megdnn_throw(ssprintf("unsupported dtype: %s",
- src.layout.dtype.name()));
- }
- }
- }
- if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {  // BFloat16 input: convert the computed dst back into sdst
- ctypecvt.comp_to_dst_type(dst, sdst);
- }
- }
-
- } // namespace cuda
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
|