/** * \file dnn/src/cuda/warp_perspective/forward.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #include "src/cuda/warp_perspective/opr_impl.h" #include "src/cuda/warp_perspective/warp_perspective_cv.cuh" #include "src/cuda/utils.h" #include "src/cuda/warp_perspective/common.h" #include "src/cuda/warp_perspective/helper.h" #include "src/common/cv/common.h" #include "src/common/warp_common.h" namespace megdnn { namespace cuda { namespace { inline void deduce_reformat_layout(std::unique_ptr& relayout, const TensorLayout& src_layout, TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc = 0, const int group = 1) { if (src_layout.ndim > 0) { RelayoutFormat::Param trans_param; trans_param.mode = mode; trans_param.oc = oc; trans_param.group = group; relayout->param() = trans_param; relayout->deduce_layout(src_layout, dst_layout); } else { dst_layout = src_layout; } } void get_inner_layout(const TensorLayout& src, const TensorLayout& dst, TensorLayout& inner_src, TensorLayout& inner_dst, Handle* handle, WarpPerspectiveForwardImpl::Param::Format format) { if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 || src.dtype.enumv() == DTypeEnum::Quantized4Asymm) && dst.dtype.enumv() == src.dtype.enumv() && format == param::WarpPerspective::Format::NCHW) { auto relayout_opr = handle->create_operator(); deduce_reformat_layout(relayout_opr, src, inner_src, RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); deduce_reformat_layout(relayout_opr, dst, inner_dst, RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); } else { megdnn_assert(0, "not support"); } } } // namespace namespace warp_perspective { void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in dst, float border_val, BorderMode bmode, InterpolationMode imode, _megdnn_workspace workspace, cudaStream_t stream) { megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, "unsupported src channel"); megdnn_assert(src.layout.dtype != dtype::Float32() || src.layout.dtype != dtype::Uint8(), "unsupported src dtype"); if (imode == InterpolationMode::INTER_AREA) { imode = InterpolationMode::INTER_LINEAR; } using namespace megcv; const float* trans_ptr = mat.ptr(); double* workspace_ptr = workspace.ptr(); for (size_t i = 0; i < src.layout.shape[0]; ++i) { if (dst.layout.dtype == dtype::Float32()) { Mat src_mat = TensorND2Mat(src, i); Mat dst_mat = TensorND2Mat(dst, i); if (src_mat.channels() == 1) { warp_perspective_cv_proxy( src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(), dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr, border_val, workspace_ptr, stream); } else { warp_perspective_cv_proxy( src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(), dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr, border_val, workspace_ptr, stream); } } else if (dst.layout.dtype == dtype::Uint8()) { Mat src_mat = TensorND2Mat(src, i); Mat dst_mat = TensorND2Mat(dst, i); if (src_mat.channels() == 1) { warp_perspective_cv_proxy( src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(), dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr, static_cast(border_val), workspace_ptr, stream); } else { warp_perspective_cv_proxy( src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(), dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(), bmode, imode, trans_ptr, static_cast(border_val), workspace_ptr, stream); } } else { megdnn_throw("Unsupported datatype of WarpPerspective optr."); } trans_ptr += 3 * 3; workspace_ptr += 3 * 3; } } } // namespace warp_perspective WorkspaceBundle WarpPerspectiveForwardImpl::get_workspace_bundle( void* ptr, const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& dst) const { MEGDNN_MARK_USED_VAR(mat_idx); SmallVector sizes; TensorLayout fsrc = src; TensorLayout fmat = mat; TensorLayout fdst = dst; if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 || src.dtype.enumv() == DTypeEnum::Quantized4Asymm) && param().format == param::WarpPerspective::Format::NCHW) { get_inner_layout(src, dst, fsrc, fdst, handle(), param().format); sizes.push_back(fsrc.span().dist_byte()); sizes.push_back(fdst.span().dist_byte()); } else { auto get_workspace = [&sizes](TensorLayout& layout) { if (layout.dtype == dtype::BFloat16()) { layout.dtype = dtype::Float32(); sizes.push_back(layout.span().dist_byte()); } }; get_workspace(fsrc); get_workspace(fmat); get_workspace(fdst); } if (param().format == param::WarpPerspective::Format::NHWC) { //! use double for the workspace dtype as float may cause //! accuracy problems sizes.push_back(mat.total_nr_elems() * sizeof(double)); } return {ptr, std::move(sizes)}; } void WarpPerspectiveForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in smat, _megdnn_tensor_in smat_idx, _megdnn_tensor_out sdst, _megdnn_workspace sworkspace) { check_exec_allow_nhwc_mat_idx(ssrc.layout, smat.layout, smat_idx.layout, sdst.layout, sworkspace.size); TensorND src = ssrc; TensorND mat = smat; TensorND mat_idx = smat_idx; TensorND dst = sdst; Param::Format inner_format = param().format; auto bundle = get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, smat.layout, smat_idx.layout, sdst.layout); auto ctypecvt = CompTypeCvter( concrete_handle(this->handle()), &bundle); if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { ctypecvt.src_to_comp_type(ssrc, src) .src_to_comp_type(smat, mat) .src_to_comp_type(sdst, dst); } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 || ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) && param().format == Param::Format::NCHW) { auto handle_ptr = handle(); get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout, handle_ptr, param().format); src.raw_ptr = bundle.get(0); dst.raw_ptr = bundle.get(1); auto relayout_opr = handle_ptr->create_operator(); RelayoutFormat::Param trans_param; trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64; relayout_opr->param() = trans_param; relayout_opr->exec(ssrc, src, {}); inner_format = Param::Format::NCHW64; } { auto stream = cuda_stream(this->handle()); bool is_nhwc = inner_format == param::WarpPerspective::Format::NHWC; if (is_nhwc && param().imode != Param::InterpolationMode::LINEAR) { // use opencv impl only for nhwc and non-linear interp megdnn_assert(!mat_idx.raw_ptr, "mat_idx is not supported in NHWC case with " "non-linear interpolation"); warp_perspective::warp_perspective_cv_exec( src, mat, dst, param().border_val, warp_perspective::get_bmode(param().bmode), warp_perspective::get_imode(param().imode), ctypecvt.workspace(), stream); } else { megdnn_assert(warp::is_dnn_available(src.layout, mat.layout, dst.layout, param().imode, inner_format)); size_t C, IH, IW, OH, OW; if (is_nhwc) { C = src.layout.shape[3]; IH = src.layout.shape[1]; IW = src.layout.shape[2]; OH = dst.layout.shape[1]; OW = dst.layout.shape[2]; } else if (inner_format == Param::Format::NCHW4) { C = src.layout.shape[1] * 4; IH = src.layout.shape[2]; IW = src.layout.shape[3]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; } else if (inner_format == Param::Format::NHWC_NCHW) { C = src.layout.shape[3]; IH = src.layout.shape[1]; IW = src.layout.shape[2]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; } else if (inner_format == Param::Format::NHWC_NCHW4_IC_SMALL) { C = src.layout.shape[3]; IH = src.layout.shape[1]; IW = src.layout.shape[2]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; megdnn_assert( (C == 1) || (C == 3), "NHWC_NCHW4_IC_SMALL only support C == 1 or C == 3"); } else if (inner_format == Param::Format::NCHW_NCHW4_IC_SMALL) { C = src.layout.shape[1]; IH = src.layout.shape[2]; IW = src.layout.shape[3]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; megdnn_assert( (C == 1) || (C == 3), "NCHW_NCHW4_IC_SMALL only support C == 1 or C == 3"); } else if (inner_format == Param::Format::NCHW64) { C = src.layout.shape[1] * 64; IH = src.layout.shape[2]; IW = src.layout.shape[3]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; } else { megdnn_assert( inner_format == param::WarpPerspective::Format::NCHW, "invalid warp_perspective format"); C = src.layout.shape[1]; IH = src.layout.shape[2]; IW = src.layout.shape[3]; OH = dst.layout.shape[2]; OW = dst.layout.shape[3]; } megdnn_assert(param().imode == Param::InterpolationMode::LINEAR, "unsupported interpolation mode for NCHW format"); auto bval = param().border_val; auto bmode = warp_perspective::get_bmode(param().bmode); if (src.layout.dtype == dst.layout.dtype) { if (src.layout.dtype == dtype::Float32{}) { warp_perspective::forward_proxy( is_nhwc, src.ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()), m_error_tracker, stream); } else if (DNN_FLOAT16_SELECT( src.layout.dtype == dtype::Float16(), false)) { #ifndef MEGDNN_DISABLE_FLOAT16 warp_perspective::forward_proxy( is_nhwc, src.ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, static_cast(bval), bmode, async_error_info(handle()), m_error_tracker, stream); #endif } else if (src.layout.dtype == dtype::Uint8()) { warp_perspective::forward_proxy( is_nhwc, src.ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()), m_error_tracker, stream); } else if (src.layout.dtype == dtype::Int8()) { megdnn_assert(!is_nhwc, "WarpPerspective on CUDA does not support " "NHWC + Int8"); warp_perspective::forward_proxy( false, src.ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval /* implicit float -> int8 conversion, should be safe */ , bmode, async_error_info(handle()), m_error_tracker, stream); } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { megdnn_assert(param().format == Param::Format::NCHW4, "WarpPerspective on CUDA supports NCHW4 + " "QuantizedS8 only"); warp_perspective::forward_proxy_nchw4( src.compatible_ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.compatible_ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval, bmode, async_error_info(handle()), m_error_tracker, stream); } else if (src.layout.dtype.enumv() == DTypeEnum::QuantizedS4) { megdnn_assert( param().format == Param::Format::NCHW64 || param().format == Param::Format::NCHW, "WarpPerspective on CUDA supports NCHW64 or NCHW+ " "QuantizedS4"); bval = roundf(bval); bval = fmin(fmax(-8.f, bval), 7.f); warp_perspective::forward_proxy_nchw64( src.compatible_ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.compatible_ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, static_cast(bval), bmode, async_error_info(handle()), m_error_tracker, stream); if (param().format == Param::Format::NCHW) { auto relayout_opr = handle()->create_operator(); RelayoutFormat::Param trans_param; trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW; trans_param.oc = sdst.layout[1]; relayout_opr->param() = trans_param; relayout_opr->exec(dst, sdst, {}); } } else if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { megdnn_assert( param().format == Param::Format::NCHW64 || param().format == Param::Format::NCHW, "WarpPerspective on CUDA supports NCHW64 or NCHW+ " "Quantized4Asymm"); bval = roundf(bval); bval = fmin(fmax(0, bval), 15); warp_perspective::forward_proxy_nchw64( src.compatible_ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.compatible_ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, static_cast(bval), bmode, async_error_info(handle()), m_error_tracker, stream); if (param().format == Param::Format::NCHW) { auto relayout_opr = handle()->create_operator(); RelayoutFormat::Param trans_param; trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW; trans_param.oc = sdst.layout[1]; relayout_opr->param() = trans_param; relayout_opr->exec(dst, sdst, {}); } } } else if ((src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm || src.layout.dtype.enumv() == DTypeEnum::Uint8)) { uint8_t zero_point = 0; float scale = 1.f; if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { zero_point = src.layout.dtype.param() .zero_point; scale = src.layout.dtype.param() .scale; } else if (src.layout.dtype.enumv() == DTypeEnum::Uint8 && dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { zero_point = 128; scale = 1.f; } DTypeParamImpl src_dtype_param(scale, zero_point); if ((dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8 && dst.layout.dtype.param().scale == scale) && ((param().format == Param::Format::NCHW_NCHW4_IC_SMALL) || (param().format == Param::Format::NHWC_NCHW4_IC_SMALL))) { bool is_nhwc_ic_small = (param().format == Param::Format::NHWC_NCHW4_IC_SMALL); warp_perspective:: forward_proxy_quint8_dimshuffle_typecvt_nchw4< dt_quint8, dt_uint8, dt_int8>( is_nhwc_ic_small, src.compatible_ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.compatible_ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval, src_dtype_param, bmode, async_error_info(handle()), m_error_tracker, stream); } else { megdnn_assert( ((dst.layout.dtype.enumv() == DTypeEnum::Float32) && ((param().format == Param::Format::NCHW) || (param().format == Param::Format::NHWC_NCHW))), "invalid format for Quantized8Asymm input"); bool is_nhwc = (param().format == Param::Format::NHWC_NCHW); warp_perspective:: forward_proxy_quint8_dimshuffle_typecvt_nchw< dt_quint8, dt_uint8, dt_float32>( is_nhwc, src.compatible_ptr(), mat.ptr(), mat_idx.raw_ptr ? mat_idx.ptr() : nullptr, dst.compatible_ptr(), src.layout[0], mat.layout[0], C, IH, IW, OH, OW, bval, src_dtype_param, bmode, async_error_info(handle()), m_error_tracker, stream); } } else { megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name())); } } } if (ssrc.layout.dtype.enumv() == DTypeTrait::enumv) { ctypecvt.comp_to_dst_type(dst, sdst); } } } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen