|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 |
/**
 * \file dnn/src/common/warp_perspective.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
- #include "megdnn/oprs.h"
-
- #include "src/common/utils.h"
-
- namespace megdnn {
-
- void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src,
- const TensorLayout& mat,
- const TensorLayout& mat_idx,
- const TensorLayout& dst) {
- megdnn_assert_contiguous(mat);
- megdnn_assert_contiguous(src);
- megdnn_assert_contiguous(dst);
- auto errmsg = [&]() {
- return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " +
- megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) +
- ", " + param_msg();
- };
- MEGDNN_MARK_USED_VAR(errmsg);
- if (param().format == param::WarpPerspective::Format::NHWCD4 ||
- param().format == param::WarpPerspective::Format::NCHW4 ||
- param().format == param::WarpPerspective::Format::NCHW64) {
- megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str());
- megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
-
- } else if (param().format ==
- param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
- param().format ==
- param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
- megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
- megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str());
- } else {
- megdnn_assert(param().format == param::WarpPerspective::Format::NHWC ||
- param().format == param::WarpPerspective::Format::NCHW ||
- param().format ==
- param::WarpPerspective::Format::NHWC_NCHW);
- megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str());
- megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str());
- }
- megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str());
- megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str());
- if (mat_idx.ndim) {
- megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1,
- "%s", errmsg().c_str());
- megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str());
- megdnn_assert_contiguous(mat_idx);
- } else {
- megdnn_assert(src.shape[0] == dst.shape[0], "%s", errmsg().c_str());
- }
- megdnn_assert(mat.shape[1] == 3_z, "%s", errmsg().c_str());
- megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str());
-
- if (src.format == dst.format && dst.dtype == src.dtype) {
- if (param().format == param::WarpPerspective::Format::NCHW) {
- megdnn_assert(
- src.dtype.enumv() == DTypeEnum::Float32 ||
- DNN_FLOAT16_SELECT(
- (src.dtype.enumv() == DTypeEnum::Float16 ||
- src.dtype.enumv() == DTypeEnum::BFloat16),
- false) ||
- src.dtype.enumv() == DTypeEnum::Int8 ||
- src.dtype.enumv() == DTypeEnum::Uint8 ||
- (src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
- src.dtype.enumv() == DTypeEnum::Quantized8Asymm) ||
- src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
- src.dtype.enumv() == DTypeEnum::Quantized4Asymm,
- "WarpPerspective NCHW input dtype should be "
- "Float32/Int8/Uint8/QInt8/QUint8/QInt4/QUInt4" DNN_FLOAT16_SELECT(
- "/Float16/BFloat16", "") ".");
- megdnn_assert(
- (src.dtype.category() == DTypeCategory::FLOAT &&
- (src.dtype == mat.dtype ||
- mat.dtype.enumv() == DTypeEnum::Float32)) ||
- ((src.dtype.category() == DTypeCategory::INT ||
- src.dtype.category() ==
- DTypeCategory::QUANTIZED) &&
- mat.dtype.enumv() == DTypeEnum::Float32),
- "The input to WarpPerspective is in NCHW format, in this "
- "case, if the input dtype is floating point, the "
- "transformation matrix should have same dtype as the "
- "input, otherwise, it should be in Float32, %s given.",
- mat.dtype.name());
-
- megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
-
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- } else if (param().format == param::WarpPerspective::Format::NHWC) {
- megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str());
- } else if (param().format == param::WarpPerspective::Format::NCHW4) {
- megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8,
- "src expected QuantizedS8, but got %s",
- src.dtype.name());
- megdnn_assert(mat.dtype == dtype::Float32(),
- "matrix dtype expected float, got %s",
- mat.dtype.name());
- megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4);
- megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
-
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- } else if (param().format == param::WarpPerspective::Format::NCHW64) {
- megdnn_assert((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
- src.dtype.enumv() == DTypeEnum::Quantized4Asymm),
- "src expected QuantizedS4/Quantized4Asymm, but got %s",
- src.dtype.name());
- megdnn_assert(mat.dtype == dtype::Float32(),
- "matrix dtype expected float, got %s",
- mat.dtype.name());
- megdnn_assert(src.shape[4] == 64 && dst.shape[4] == 64);
- megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str());
-
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- } else {
- megdnn_assert(param().format ==
- param::WarpPerspective::Format::NHWCD4);
- megdnn_assert(
- src.dtype == dtype::Float32() ||
- DNN_FLOAT16_SELECT((src.dtype == dtype::Float16() ||
- src.dtype == dtype::BFloat16()),
- false) ||
- src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
- src.dtype.enumv() == DTypeEnum::Quantized8Asymm,
- "WarpPerspective NHWCD4 input dtype should be "
- "Float32" DNN_FLOAT16_SELECT(
- "/Float16/BFloat16",
- "") ",QunatizedS8, Quantized8Asymm.");
- megdnn_assert(
- (src.dtype == mat.dtype || mat.dtype == dtype::Float32()),
- "The input to WarpPerspective is in NHWCD4 format, in this "
- "case, if the input dtype is floating point, the "
- "transformation matrix should have same dtype as the "
- "input, %s given.",
- mat.dtype.name());
- //! number of channels is same
- megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str());
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- }
- } else if (param().format ==
- param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL ||
- param().format ==
- param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) {
- megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
- src.dtype.enumv() == DTypeEnum::Uint8),
- "src expected Quantized8Asymm or Uint8, but got %s",
- src.dtype.name());
- megdnn_assert(mat.dtype == dtype::Float32(),
- "matrix dtype expected float, got %s", mat.dtype.name());
- megdnn_assert(dst.shape[4] == 4);
-
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) {
- megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
- src.dtype.enumv() == DTypeEnum::Uint8),
- "src expected Quantized8Asymm or Uint8, but got %s",
- src.dtype.name());
- megdnn_assert(mat.dtype == dtype::Float32(),
- "matrix dtype expected float, got %s", mat.dtype.name());
- megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str());
-
- megdnn_assert(param().imode ==
- param::WarpPerspective::InterpolationMode::LINEAR);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::TRANSPARENT);
- megdnn_assert(param().bmode !=
- param::WarpPerspective::BorderMode::ISOLATED);
- } else {
- megdnn_assert(param().format == param::WarpPerspective::Format::NCHW);
- megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm ||
- src.dtype.enumv() == DTypeEnum::Uint8) &&
- dst.dtype.enumv() == DTypeEnum::Float32);
- }
- }
-
- std::string WarpPerspectiveBase::param_msg() const {
- std::string res;
- res.append("imode=");
- switch (param().imode) {
- case InterpolationMode::NEAREST:
- res.append("NEAREST");
- break;
- case InterpolationMode::LINEAR:
- res.append("LINEAR");
- break;
- case InterpolationMode::AREA:
- res.append("AREA");
- break;
- case InterpolationMode::CUBIC:
- res.append("CUBIC");
- break;
- case InterpolationMode::LANCZOS4:
- res.append("LANCZOS4");
- break;
- }
- res.append(", bmode=");
- switch (param().bmode) {
- case BorderMode::WRAP:
- res.append("WRAP");
- break;
- case BorderMode::CONSTANT:
- res.append("CONSTANT");
- break;
- case BorderMode::REFLECT:
- res.append("REFLECT");
- break;
- case BorderMode::REFLECT_101:
- res.append("REFLECT_101");
- break;
- case BorderMode::REPLICATE:
- res.append("REPLICATE");
- break;
- case BorderMode::TRANSPARENT:
- res.append("TRANSPARENT");
- break;
- case BorderMode::ISOLATED:
- res.append("ISOLATED");
- break;
- }
- if (param().bmode == BorderMode::CONSTANT) {
- res.append(", " + std::to_string(param().border_val));
- }
- return res;
- }
-
- int WarpPerspectiveBase::get_real_coord(int p, int len) {
- auto bmode = param().bmode;
- if ((unsigned)p < (unsigned)len)
- ;
- else if (bmode == BorderMode::REPLICATE)
- p = p < 0 ? 0 : len - 1;
- else if (bmode == BorderMode::REFLECT || bmode == BorderMode::REFLECT_101) {
- int delta = (bmode == BorderMode::REFLECT_101);
- if (len == 1)
- return 0;
- do {
- if (p < 0)
- p = -p - 1 + delta;
- else
- p = len - 1 - (p - len) - delta;
- } while ((unsigned)p >= (unsigned)len);
- } else if (bmode == BorderMode::WRAP) {
- if (p < 0)
- p -= ((p - len + 1) / len) * len;
- /*
- if( p >= len )
- p %= len;
- */
- while (p >= len) {
- p -= len;
- }
- } else if (bmode == BorderMode::CONSTANT)
- p = -1;
- return p;
- }
-
- void WarpPerspectiveForward::check_exec(const TensorLayout& src,
- const TensorLayout& mat,
- const TensorLayout& mat_idx,
- const TensorLayout& dst,
- size_t workspace_in_bytes) {
- check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes);
- }
-
- void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx(
- const TensorLayout& src, const TensorLayout& mat,
- const TensorLayout& mat_idx, const TensorLayout& dst,
- size_t workspace_in_bytes) {
- check_layout_fwd(src, mat, mat_idx, dst);
- auto required_workspace_in_bytes =
- get_workspace_in_bytes(src, mat, mat_idx, dst);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- if (param().format != Param::Format::NHWC &&
- param().format != Param::Format::NCHW &&
- param().format != Param::Format::NCHW4 &&
- param().format != Param::Format::NHWC_NCHW &&
- param().format != Param::Format::NHWC_NCHW4_IC_SMALL &&
- param().format != Param::Format::NCHW_NCHW4_IC_SMALL &&
- param().format != Param::Format::NCHW64) {
- megdnn_assert(!mat_idx.ndim,
- "mat_idx not supported for current format");
- }
- }
-
- void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat,
- const TensorLayout& mat_idx,
- const TensorLayout& diff,
- const TensorLayout& grad,
- size_t workspace_in_bytes) {
- check_layout_fwd(grad, mat, mat_idx, diff);
- megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
- || grad.dtype == dtype::BFloat16()),
- "Backward WarpPerspective only supports Float32/BFloat16.");
- auto required_workspace_in_bytes =
- get_workspace_in_bytes(mat, mat_idx, diff, grad);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src,
- const TensorLayout& mat,
- const TensorLayout& mat_idx,
- const TensorLayout& diff,
- const TensorLayout& grad,
- size_t workspace_in_bytes) {
- check_layout_fwd(src, mat, mat_idx, diff);
- megdnn_assert_eq_layout(mat, grad);
- megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16(
- || grad.dtype == dtype::BFloat16()),
- "Backward WarpPerspective only supports Float32/BFloat16.");
- auto required_workspace_in_bytes =
- get_workspace_in_bytes(src, mat, mat_idx, diff, grad);
- megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
- }
-
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
|