Browse Source

feat(mgb): add arm resize nchwxx and naive nearest interp

GitOrigin-RevId: d5fbd59a30
release-1.6
Megvii Engine Team 3 years ago
parent
commit
bc9cfc277a
11 changed files with 736 additions and 370 deletions
  1. +5
    -1
      dnn/include/megdnn/oprs/cv.h
  2. +194
    -4
      dnn/src/arm_common/resize/opr_impl.cpp
  3. +12
    -1
      dnn/src/arm_common/resize/opr_impl.h
  4. +46
    -13
      dnn/src/common/resize.cpp
  5. +52
    -55
      dnn/src/fallback/resize/opr_impl.cpp
  6. +116
    -101
      dnn/src/naive/resize/opr_impl.cpp
  7. +14
    -9
      dnn/src/naive/resize/opr_impl.h
  8. +41
    -15
      dnn/test/arm_common/resize.cpp
  9. +52
    -7
      dnn/test/common/resize.h
  10. +200
    -162
      src/gopt/impl/tensor_reformat.cpp
  11. +4
    -2
      src/opr/impl/imgproc.cpp

+ 5
- 1
dnn/include/megdnn/oprs/cv.h View File

@@ -197,7 +197,11 @@ public:

protected:
//! get origin coord
std::pair<float, int> get_origin_coord(float scale, int size, int idx, bool cubic=false);
std::pair<float, int> get_cubic_coord(float scale, int idx);

std::tuple<float, int, float, int> get_nearest_linear_coord(
InterpolationMode imode, float scale, int size, int idx);

//! get nearest index in src
int get_nearest_src(float scale, int size, int idx);



+ 194
- 4
dnn/src/arm_common/resize/opr_impl.cpp View File

@@ -6,12 +6,14 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/arm_common/resize/opr_impl.h"
#include "src/arm_common/handle.h"
#include "src/arm_common/resize/resize_cv.h"
#include "src/arm_common/simd_macro/marm_neon.h"

using namespace megdnn;
using namespace arm_common;
@@ -19,9 +21,58 @@ using namespace arm_common;
void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) {

if (param().format == param::Resize::Format::NCHW44 ||
param().format == param::Resize::Format::NCHW88) {
bool is_contiguous =
src.layout.is_contiguous() && dst.layout.is_contiguous();
bool dtype_same = src.layout.dtype == dst.layout.dtype;
bool nchw44_enable = param().format == param::Resize::Format::NCHW44 &&
src.layout.dtype == dtype::Float32();
bool nchw88_enable =
param().format == param::Resize::Format::NCHW88 &&
DNN_FLOAT16_SELECT(src.layout.dtype == dtype::Float16(), false);
bool interp_supported =
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST ||
param().imode == param::Resize::InterpolationMode::INTER_LINEAR;
bool is_upsample2 =
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST &&
src.layout.shape[2] * 2 == dst.layout.shape[2] &&
src.layout.shape[3] * 2 == dst.layout.shape[3];
bool need_fallback = !is_contiguous || !dtype_same ||
!interp_supported ||
(!nchw44_enable && !nchw88_enable);

if (need_fallback) {
fallback::ResizeImpl::exec(src, dst, workspace);
} else if (nchw44_enable) {
auto kern_param = KernParam<float>::from_tensors(
param().format, param().imode, src, dst, workspace);
if (is_upsample2) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
kern_nearest_upsample2_pack_simd_width(src, dst));
} else {
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw44_fp32(kern_param));
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} else if (nchw88_enable) {
auto kern_param = KernParam<dt_float16>::from_tensors(
param().format, param().imode, src, dst, workspace);
if (is_upsample2) {
MEGDNN_DISPATCH_CPU_KERN_OPR(
kern_nearest_upsample2_pack_simd_width(src, dst));
} else {
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw88_fp16(kern_param));
}
#endif
} else {
fallback::ResizeImpl::exec(src, dst, workspace);
}
} else if (param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) {
fallback::ResizeImpl::exec(src, dst, workspace);
} else {
megdnn_assert(param().format == param::Resize::Format::NHWC,
@@ -30,4 +81,143 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
}
}

template <typename ctype>
void ResizeImpl::kern_nchw44_fp32(const KernParam<ctype>& kern_param) {
UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;

for (size_t n = 0; n < N; ++n) {
for (size_t c = 0; c < C / 4; ++c) {
for (size_t oh = 0; oh < OH; ++oh) {
for (size_t ow = 0; ow < OW; ++ow) {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
kern_param.imode, scale_w, IW, ow);

#define SRC_ADDRESS(ih, iw) \
(sptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 4)
#define DST_ADDRESS(oh, ow) \
(dptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 4)
float32x4_t r0 = vld1q_f32(SRC_ADDRESS(ih0, iw0));
float32_t a0 = ah0 * aw0;
float32x4_t r1 = vld1q_f32(SRC_ADDRESS(ih0, iw1));
float32_t a1 = ah0 * aw1;
float32x4_t r2 = vld1q_f32(SRC_ADDRESS(ih1, iw0));
float32_t a2 = ah1 * aw0;
float32x4_t r3 = vld1q_f32(SRC_ADDRESS(ih1, iw1));
float32_t a3 = ah1 * aw1;

r0 = vmulq_n_f32(r0, a0);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0 = vfmaq_n_f32(r0, r1, a1);
r0 = vfmaq_n_f32(r0, r2, a2);
r0 = vfmaq_n_f32(r0, r3, a3);
#else
r0 = vaddq_f32(r0, vmulq_n_f32(r1, a1));
r0 = vaddq_f32(r0, vmulq_n_f32(r2, a2));
r0 = vaddq_f32(r0, vmulq_n_f32(r3, a3));
#endif

vst1q_f32(DST_ADDRESS(oh, ow), r0);
#undef SRC_ADDRESS
#undef DST_ADDRESS
}
}
}
}
}

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename ctype>
void ResizeImpl::kern_nchw88_fp16(const KernParam<ctype>& kern_param) {
UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;
const float16_t* src_ptr = reinterpret_cast<float16_t*>(sptr);
float16_t* dst_ptr = reinterpret_cast<float16_t*>(dptr);

for (size_t n = 0; n < N; ++n) {
for (size_t c = 0; c < C / 8; ++c) {
for (size_t oh = 0; oh < OH; ++oh) {
for (size_t ow = 0; ow < OW; ++ow) {
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
kern_param.imode, scale_w, IW, ow);

#define SRC_ADDRESS(ih, iw) \
(src_ptr + n * C * IH * IW + (c * IH * IW + ih * IW + iw) * 8)
#define DST_ADDRESS(oh, ow) \
(dst_ptr + n * C * OH * OW + (c * OH * OW + oh * OW + ow) * 8)
float16x8_t r0 = vld1q_f16(SRC_ADDRESS(ih0, iw0));
float32_t a0 = ah0 * aw0;
float16x8_t r1 = vld1q_f16(SRC_ADDRESS(ih0, iw1));
float32_t a1 = ah0 * aw1;
float16x8_t r2 = vld1q_f16(SRC_ADDRESS(ih1, iw0));
float32_t a2 = ah1 * aw0;
float16x8_t r3 = vld1q_f16(SRC_ADDRESS(ih1, iw1));
float32_t a3 = ah1 * aw1;

r0 = vmulq_n_f16(r0, a0);
#if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
r0 = vfmaq_n_f16(r0, r1, a1);
r0 = vfmaq_n_f16(r0, r2, a2);
r0 = vfmaq_n_f16(r0, r3, a3);
#else
r0 = vaddq_f16(r0, vmulq_n_f16(r1, a1));
r0 = vaddq_f16(r0, vmulq_n_f16(r2, a2));
r0 = vaddq_f16(r0, vmulq_n_f16(r3, a3));
#endif

vst1q_f16(DST_ADDRESS(oh, ow), r0);
#undef SRC_ADDRESS
#undef DST_ADDRESS
}
}
}
}
}
#endif

void ResizeImpl::kern_nearest_upsample2_pack_simd_width(
_megdnn_tensor_in src, _megdnn_tensor_out dst) {
const uint8_t* src_ptr = reinterpret_cast<uint8_t*>(src.raw_ptr);
uint8_t* dst_ptr = reinterpret_cast<uint8_t*>(dst.raw_ptr);

size_t S = 2;
size_t N = src.layout.shape[0];
size_t IC = src.layout.shape[1];
size_t IH = src.layout.shape[2];
size_t IW = src.layout.shape[3];
size_t OH = dst.layout.shape[2];
size_t OW = dst.layout.shape[3];

for (size_t i = 0; i < N * IC; ++i) {
for (size_t ih = 0; ih < IH; ++ih) {
for (size_t iw = 0; iw < IW; ++iw) {
size_t oh = ih * S;
size_t ow = iw * S;
uint8x16_t r0 = vld1q_u8(src_ptr + i * IH * IW * 16 +
ih * IW * 16 + iw * 16);

for (size_t fh = 0; fh < S; ++fh) {
for (size_t fw = 0; fw < S; ++fw) {
vst1q_u8(dst_ptr + i * OH * OW * 16 +
(oh + fh) * OW * 16 + (ow + fw) * 16,
r0);
}
}
}
}
}
}

// vim: syntax=cpp.doxygen

+ 12
- 1
dnn/src/arm_common/resize/opr_impl.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
@@ -25,6 +26,16 @@ public:
const TensorLayout&) override {
return 0;
}

private:
template <typename ctype>
void kern_nchw44_fp32(const KernParam<ctype>& kern_param);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <typename ctype>
void kern_nchw88_fp16(const KernParam<ctype>& kern_param);
#endif
void kern_nearest_upsample2_pack_simd_width(_megdnn_tensor_in src,
_megdnn_tensor_out dst);
};

} // namespace arm_common


+ 46
- 13
dnn/src/common/resize.cpp View File

@@ -40,11 +40,29 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src,
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
megdnn_assert(src.shape[4] == 4);
megdnn_assert(dst.shape[4] == 4);
} else if (param().format == Param::Format::NCHW44) {
megdnn_assert(src.ndim == 5);
megdnn_assert(src.shape[4] == 4);
megdnn_assert(dst.shape[4] == 4);
megdnn_assert(param().imode ==
param::Resize::InterpolationMode::INTER_LINEAR ||
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST);
} else if (param().format == Param::Format::NCHW88) {
megdnn_assert(src.ndim == 5);
megdnn_assert(src.shape[4] == 8);
megdnn_assert(dst.shape[4] == 8);
megdnn_assert(param().imode ==
param::Resize::InterpolationMode::INTER_LINEAR ||
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST);
} else {
megdnn_assert(param().format == Param::Format::NHWCD4,
"invalid resize tensor format");
megdnn_assert(param().imode ==
param::Resize::InterpolationMode::INTER_LINEAR);
param::Resize::InterpolationMode::INTER_LINEAR ||
param().imode ==
param::Resize::InterpolationMode::INTER_NEAREST);
megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str());
}
}
@@ -67,24 +85,39 @@ void ResizeBackward::check_exec(const TensorLayout& diff,
"Backward resize only supports Float32 and NCHW.");
}

std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
int idx, bool cubic) {
//! copy from resize_cv.cpp
std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx;
if (!cubic) {
if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}
}
return {alpha, origin_idx};
}

std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
InterpolationMode imode, float scale, int size, int idx) {
if (size == 1) {
return std::make_tuple(1.0f, 0, 0.0f, 0);
}

float alpha = (idx + 0.5f) / scale - 0.5f;
int origin_idx = static_cast<int>(floor(alpha));
alpha -= origin_idx;

if (imode == InterpolationMode::INTER_NEAREST) {
origin_idx = get_nearest_src(scale, size, idx);
alpha = 0;
}

if (origin_idx < 0) {
origin_idx = 0;
alpha = 0;
} else if (origin_idx + 1 >= size) {
origin_idx = size - 2;
alpha = 1;
}

return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1);
}

int ResizeBase::get_nearest_src(float scale, int size, int idx) {
return std::min(static_cast<int>(idx / scale), size - 1);
}


+ 52
- 55
dnn/src/fallback/resize/opr_impl.cpp View File

@@ -6,13 +6,14 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/fallback/resize/opr_impl.h"
#include <vector>
#include "src/fallback/handle.h"
#include "src/common/rounding_converter.cuh"
#include "src/fallback/handle.h"

using namespace megdnn;
using namespace fallback;
@@ -30,37 +31,36 @@ void ResizeImpl::kern_fallback(const KernParam<ctype>& kern_param) {
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;

auto build_table = [this](float scale, int isize,
int osize) -> std::vector<std::pair<float, int>> {
std::vector<std::pair<float, int>> table;
rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); }
auto build_table = [this](InterpolationMode imode, float scale, int isize,
int osize) {
std::vector<std::tuple<float, int, float, int>> table;
rep(i, osize) {
table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
}
return table;
};

auto table_h = build_table(scale_h, IH, OH);
auto table_w = build_table(scale_w, IW, OW);
auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
auto table_w = build_table(kern_param.imode, scale_w, IW, OW);

rep(n, N) {
rep(c, static_cast<int>(C)) {
rep(oh, OH) {
auto coord_h = table_h[oh];
float alphah = coord_h.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
float ah0, ah1, aw0, aw1;
int ih0, ih1, iw0, iw1;
std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
rep(ow, OW) {
auto coord_w = table_w[ow];
float alphaw = coord_w.first;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
alphaw * (1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
alphaw * alphah);
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 *
aw0 +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 *
aw1 +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 *
aw0 +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 *
aw1);
}
}
}
@@ -76,35 +76,31 @@ void ResizeImpl::kern_fallback_nhwc(const KernParam<ctype>& kern_param) {
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;

auto build_table = [this](float scale, int isize,
int osize) -> std::vector<std::pair<float, int>> {
std::vector<std::pair<float, int>> table;
rep(i, osize) { table.push_back(get_origin_coord(scale, isize, i)); }
auto build_table = [this](InterpolationMode imode, float scale, int isize,
int osize) {
std::vector<std::tuple<float, int, float, int>> table;
rep(i, osize) {
table.push_back(get_nearest_linear_coord(imode, scale, isize, i));
}
return table;
};
auto table_h = build_table(scale_h, IH, OH);
auto table_w = build_table(scale_w, IW, OW);
auto table_h = build_table(kern_param.imode, scale_h, IH, OH);
auto table_w = build_table(kern_param.imode, scale_w, IW, OW);

rep(n, N) {
rep(oh, OH) {
auto coord_h = table_h[oh];
float alphah = coord_h.first;
int ih0 = coord_h.second;
int ih1 = ih0 + 1;
float ah0, ah1, aw0, aw1;
int ih0, ih1, iw0, iw1;
std::tie(ah0, ih0, ah1, ih1) = table_h[oh];
rep(ow, OW) {
auto coord_w = table_w[ow];
float alphaw = coord_w.first;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
std::tie(aw0, iw0, aw1, iw1) = table_w[ow];
rep(c, C) {
dptr[(oh * OW + ow) * C + c] = output_converter(
sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) *
(1.0f - alphah) +
sptr[(ih0 * IW + iw1) * C + c] * alphaw *
(1.0f - alphah) +
sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) *
alphah +
sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah);
sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 +
sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 +
sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 +
sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1);
}
}
}
@@ -117,6 +113,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
if (param().format == param::Resize::Format::NCHW4 ||
param().format == param::Resize::Format::NCHW44 ||
param().format == param::Resize::Format::NCHW88 ||
(param().format == param::Resize::Format::NCHW &&
param().imode != param::Resize::InterpolationMode::INTER_LINEAR)) {
naive::ResizeImpl::exec(src, dst, workspace);
@@ -125,12 +123,12 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
if ((param().format == param::Resize::Format::NCHW ||
(src.layout[3] != 1 && src.layout[3] != 3)) ||
(param().imode == param::Resize::InterpolationMode::LINEAR)) {
#define cb(dt, ct) \
case DTypeTrait<dt>::enumv: { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, dst, \
workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam)); \
return; \
#define cb(dt, ct) \
case DTypeTrait<dt>::enumv: { \
auto kparam = KernParam<ct>::from_tensors( \
param().format, param().imode, src, dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_fallback(kparam)); \
return; \
}

switch (src.layout.dtype.enumv()) {
@@ -141,10 +139,9 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
cb(dtype::Uint8, uint8_t);
cb(dtype::Quantized8Asymm, uint8_t);
default:
megdnn_throw(
ssprintf("Unsupported input DType in Resize: %s",
src.layout.dtype.name())
.c_str());
megdnn_throw(ssprintf("Unsupported input DType in Resize: %s",
src.layout.dtype.name())
.c_str());
return;
}



+ 116
- 101
dnn/src/naive/resize/opr_impl.cpp View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/naive/resize/opr_impl.h"
@@ -27,10 +28,11 @@ using namespace resize;

template <typename ctype>
ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
Format format, _megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
Format format, InterpolationMode imode, _megdnn_tensor_in src,
_megdnn_tensor_out dst, _megdnn_workspace workspace) {
KernParam<ctype> ret;
ret.format = format;
ret.imode = imode;
ret.n = src.layout.shape[0];
if (format == Format::NCHW) {
ret.c = src.layout.shape[1];
@@ -54,6 +56,18 @@ ResizeImpl::KernParam<ctype> ResizeImpl::KernParam<ctype>::from_tensors(
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else if (format == Format::NCHW44) {
ret.c = src.layout.shape[1] * 4;
ret.ih = src.layout.shape[2];
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else if (format == Format::NCHW88) {
ret.c = src.layout.shape[1] * 8;
ret.ih = src.layout.shape[2];
ret.iw = src.layout.shape[3];
ret.oh = dst.layout.shape[2];
ret.ow = dst.layout.shape[3];
} else {
megdnn_assert(format == Format::NHWCD4);
ret.c = src.layout.shape[2] * 4;
@@ -115,33 +129,30 @@ void ResizeImpl::kern_nchw(const KernParam<ctype>& kern_param,
break;
}
case InterpolationMode::INTER_LINEAR: {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);

float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
kern_param.imode, scale_w, IW, ow);

rep(c, static_cast<int>(C)) {
dptr[c * OH * OW + oh * OW + ow] = output_converter(
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] *
alphaw * (1.0f - alphah) +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] *
(1.0f - alphaw) * alphah +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] *
alphaw * alphah);
sptr[c * S_IC + ih0 * S_IH + iw0 * S_IW] * ah0 *
aw0 +
sptr[c * S_IC + ih0 * S_IH + iw1 * S_IW] * ah0 *
aw1 +
sptr[c * S_IC + ih1 * S_IH + iw0 * S_IW] * ah1 *
aw0 +
sptr[c * S_IC + ih1 * S_IH + iw1 * S_IW] * ah1 *
aw1);
}
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
auto coord_h = get_cubic_coord(scale_h, oh);
auto coord_w = get_cubic_coord(scale_w, ow);

float alphah = coord_h.first;
float alphaw = coord_w.first;
@@ -193,7 +204,19 @@ void ResizeImpl::kern_naive(const KernParam<ctype>& kern_param) {
return;
} else if (kern_param.format == Format::NCHW4) {
MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(2)) {
kern_naive_nchw4(kern_param);
kern_naive_nchwx<ctype, 4>(kern_param);
}
MIDOUT_END();
return;
} else if (kern_param.format == Format::NCHW44) {
MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(3)) {
kern_naive_nchwx<ctype, 4>(kern_param);
}
MIDOUT_END();
return;
} else if (kern_param.format == Format::NCHW88) {
MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(4)) {
kern_naive_nchwx<ctype, 8>(kern_param);
}
MIDOUT_END();
return;
@@ -209,25 +232,20 @@ void ResizeImpl::kern_naive_nhwc(const KernParam<ctype>& kern_param) {

rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

float alphah = coord_h.first;
float alphaw = coord_w.first;
std::tie(ah0, ih0, ah1, ih1) =
get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) =
get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);

int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
dptr[(oh * OW + ow) * C + c] = output_converter(
sptr[(ih0 * IW + iw0) * C + c] * (1.0f - alphaw) *
(1.0f - alphah) +
sptr[(ih0 * IW + iw1) * C + c] * alphaw *
(1.0f - alphah) +
sptr[(ih1 * IW + iw0) * C + c] * (1.0f - alphaw) *
alphah +
sptr[(ih1 * IW + iw1) * C + c] * alphaw * alphah);
sptr[(ih0 * IW + iw0) * C + c] * ah0 * aw0 +
sptr[(ih0 * IW + iw1) * C + c] * ah0 * aw1 +
sptr[(ih1 * IW + iw0) * C + c] * ah1 * aw0 +
sptr[(ih1 * IW + iw1) * C + c] * ah1 * aw1);
}
}
sptr += C * IH * IW;
@@ -251,26 +269,20 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam<ctype>& kern_param) {

rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

float alphah = coord_h.first;
float alphaw = coord_w.first;
std::tie(ah0, ih0, ah1, ih1) =
get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) =
get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);

int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
dptr[get_tensor_addr(oh, ow, c, OW, C)] = output_converter(
sptr[get_tensor_addr(ih0, iw0, c, IW, C)] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * alphaw *
(1.0f - alphah) +
sptr[get_tensor_addr(ih1, iw0, c, IW, C)] *
(1.0f - alphaw) * alphah +
sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * alphaw *
alphah);
sptr[get_tensor_addr(ih0, iw0, c, IW, C)] * ah0 * aw0 +
sptr[get_tensor_addr(ih0, iw1, c, IW, C)] * ah0 * aw1 +
sptr[get_tensor_addr(ih1, iw0, c, IW, C)] * ah1 * aw0 +
sptr[get_tensor_addr(ih1, iw1, c, IW, C)] * ah1 * aw1);
}
}
sptr += IH * (C / 4) * IW * 4;
@@ -278,41 +290,46 @@ void ResizeImpl::kern_naive_nhwcd4(const KernParam<ctype>& kern_param) {
}
}

template <typename ctype>
void ResizeImpl::kern_naive_nchw4(const KernParam<ctype>& kern_param) {
template <typename ctype, size_t pack_size>
void ResizeImpl::kern_naive_nchwx(const KernParam<ctype>& kern_param) {
UNPACK_RESIZE_FWD_KERN_PARAM(kern_param);
rounding::RoundingConverter<ctype> output_converter;
float scale_h = static_cast<float>(OH) / IH;
float scale_w = static_cast<float>(OW) / IW;

megdnn_assert(pack_size == 4 || pack_size == 8);
size_t log_pack_size = 2;
if (pack_size == 8) {
log_pack_size = 3;
}

auto get_tensor_addr = [&](size_t h, size_t w, size_t c, size_t H, size_t W,
size_t C) -> size_t {
megdnn_assert((C & 0x3) == 0);
return (((c >> 2) * H * W + h * W + w) << 2) + (c & 0b11);
megdnn_assert((C & (pack_size - 1)) == 0);
return (((c >> log_pack_size) * H * W + h * W + w) << log_pack_size) +
(c & (pack_size - 1));
};

rep(n, N) {
rep(oh, OH) rep(ow, OW) {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

float alphah = coord_h.first;
float alphaw = coord_w.first;
std::tie(ah0, ih0, ah1, ih1) =
get_nearest_linear_coord(kern_param.imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) =
get_nearest_linear_coord(kern_param.imode, scale_w, IW, ow);

int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
rep(c, C) {
dptr[get_tensor_addr(oh, ow, c, OH, OW, C)] = output_converter(
sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] *
(1.0f - alphaw) * (1.0f - alphah) +
sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * alphaw *
(1.0f - alphah) +
sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] *
(1.0f - alphaw) * alphah +
sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * alphaw *
alphah);
sptr[get_tensor_addr(ih0, iw0, c, IH, IW, C)] * ah0 *
aw0 +
sptr[get_tensor_addr(ih0, iw1, c, IH, IW, C)] * ah0 *
aw1 +
sptr[get_tensor_addr(ih1, iw0, c, IH, IW, C)] * ah1 *
aw0 +
sptr[get_tensor_addr(ih1, iw1, c, IH, IW, C)] * ah1 *
aw1);
}
}
sptr += IH * IW * C;
@@ -327,8 +344,8 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_nchw, midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
auto kparam = KernParam<ct>::from_tensors( \
param().format, param().imode, src, dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_nchw(kparam, param().imode)); \
} \
MIDOUT_END(); \
@@ -356,15 +373,15 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
if (((src.layout[3] != 1 && src.layout[3] != 3) ||
!is_nhwc_contig_wc(src.layout)) ||
(param().imode == param::Resize::InterpolationMode::LINEAR)) {
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors(param().format, src, \
dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \
} \
MIDOUT_END(); \
return; \
#define cb(dt, ct, _midout_iv) \
case DTypeTrait<dt>::enumv: { \
MIDOUT_BEGIN(megdnn_naive_resize_layout, midout_iv(_midout_iv)) { \
auto kparam = KernParam<ct>::from_tensors( \
param().format, param().imode, src, dst, workspace); \
MEGDNN_DISPATCH_CPU_KERN_OPR(kern_naive(kparam)); \
} \
MIDOUT_END(); \
return; \
}

switch (src.layout.dtype.enumv()) {
@@ -409,27 +426,24 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
rep(oh, OH) rep(ow, OW) {
switch (param().imode) {
case InterpolationMode::INTER_LINEAR: {
auto coord_h = get_origin_coord(scale_h, IH, oh);
auto coord_w = get_origin_coord(scale_w, IW, ow);

float alphah = coord_h.first;
float alphaw = coord_w.first;
int ih0, ih1, iw0, iw1;
float ah0, ah1, aw0, aw1;

int ih0 = coord_h.second;
int ih1 = ih0 + 1;
int iw0 = coord_w.second;
int iw1 = iw0 + 1;
std::tie(ah0, ih0, ah1, ih1) = get_nearest_linear_coord(
param().imode, scale_h, IH, oh);
std::tie(aw0, iw0, aw1, iw1) = get_nearest_linear_coord(
param().imode, scale_w, IW, ow);

rep(c, C) {
float hidden = hptr[c * OH * OW + oh * OW + ow];
sptr[c * IH * IW + ih0 * IW + iw0] +=
(1.0f - alphaw) * (1.0f - alphah) * hidden;
ah0 * aw0 * hidden;
sptr[c * IH * IW + ih1 * IW + iw0] +=
(1.0f - alphaw) * alphah * hidden;
ah1 * aw0 * hidden;
sptr[c * IH * IW + ih0 * IW + iw1] +=
alphaw * (1.0f - alphah) * hidden;
ah0 * aw1 * hidden;
sptr[c * IH * IW + ih1 * IW + iw1] +=
alphaw * alphah * hidden;
ah1 * aw1 * hidden;
}
break;
}
@@ -443,8 +457,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
break;
}
case InterpolationMode::INTER_CUBIC: {
auto coord_h = get_origin_coord(scale_h, IH, oh, true);
auto coord_w = get_origin_coord(scale_w, IW, ow, true);
auto coord_h = get_cubic_coord(scale_h, oh);
auto coord_w = get_cubic_coord(scale_w, ow);

float alphah = coord_h.first;
float alphaw = coord_w.first;
@@ -460,7 +474,8 @@ void ResizeBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
rep(kh, ksize) {
int h = saturate<int, int>(ih0 + kh, 0, IH - 1);
rep(kw, ksize) {
int w = saturate<int, int>(iw0 + kw, 0, IW - 1);
int w = saturate<int, int>(iw0 + kw, 0,
IW - 1);
sptr[c * IH * IW + h * IW + w] +=
hptr[c * OH * OW + oh * OW + ow] *
h_coeff[kh] * w_coeff[kw];


+ 14
- 9
dnn/src/naive/resize/opr_impl.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once

@@ -19,15 +20,18 @@ namespace naive {
class ResizeImpl : public Resize {
public:
using Format = Param::Format;
using InterpolationMode = Param::InterpolationMode;
template <typename ctype>
struct KernParam {
Format format;
InterpolationMode imode;
size_t n, c, ih, iw, oh, ow;
ptrdiff_t s_in, s_ic, s_ih, s_iw;
ctype *sptr, *dptr;
Workspace workspace;

static KernParam from_tensors(Format format, _megdnn_tensor_in src,
static KernParam from_tensors(Format format, InterpolationMode imode,
_megdnn_tensor_in src,
_megdnn_tensor_out dst,
_megdnn_workspace workspace);
};
@@ -41,6 +45,7 @@ public:
const TensorLayout&) override {
return 0;
}

private:
// ctype: C type of input data type.
template <typename ctype>
@@ -55,8 +60,8 @@ private:
template <typename ctype>
void kern_naive_nhwcd4(const KernParam<ctype>& kern_param);

template <typename ctype>
void kern_naive_nchw4(const KernParam<ctype>& kern_param);
template <typename ctype, size_t pack_size>
void kern_naive_nchwx(const KernParam<ctype>& kern_param);

}; // class ResizeImpl

@@ -65,15 +70,15 @@ private:
ctype* __restrict sptr = p.sptr; \
ctype* __restrict dptr = p.dptr;

#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p) \
UNPACK_RESIZE_FWD_KERN_PARAM(p) \
#define UNPACK_RESIZE_FWD_KERN_PARAM_WITH_STRIDE(p) \
UNPACK_RESIZE_FWD_KERN_PARAM(p) \
auto S_IN = p.s_in, S_IC = p.s_ic, S_IH = p.s_ih, S_IW = p.s_iw;

class ResizeBackwardImpl: public ResizeBackward {
class ResizeBackwardImpl : public ResizeBackward {
public:
using ResizeBackward::ResizeBackward;
void exec(_megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&) override {
return 0;


+ 41
- 15
dnn/test/arm_common/resize.cpp View File

@@ -6,40 +6,66 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/arm_common/fixture.h"
#include "test/common/resize.h"
#include "test/arm_common/fixture.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

TEST_F(ARM_COMMON, RESIZE_CV)
{
TEST_F(ARM_COMMON, RESIZE_CV) {
using namespace resize;
std::vector<TestArg> args = get_cv_args();
Checker<Resize> checker(handle());

for (auto &&arg: args) {
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_epsilon(1 + 1e-3)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Uint8())
.execs({arg.src, arg.dst});
.set_epsilon(1 + 1e-3)
.set_dtype(0, dtype::Uint8())
.set_dtype(1, dtype::Uint8())
.execs({arg.src, arg.dst});
}

for (auto &&arg: args) {
for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.execs({arg.src, arg.dst});
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.execs({arg.src, arg.dst});
}
}

TEST_F(ARM_COMMON, RESIZE_NCHW44) {
using namespace resize;
std::vector<TestArg> args = get_nchw44_args();
Checker<Resize> checker(handle());

for (auto&& arg : args) {
checker.set_param(arg.param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.execs({arg.src, arg.dst});
}
}

TEST_F(ARM_COMMON, RESIZE_NCHW88) {
using namespace resize;
std::vector<TestArg> args = get_nchw88_args();
Checker<Resize> checker(handle());

for (auto&& arg : args) {
checker.set_param(arg.param)
.set_epsilon(0.01)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.execs({arg.src, arg.dst});
}
}

} // namespace test
} // namespace megdnn
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen


+ 52
- 7
dnn/test/common/resize.h View File

@@ -6,12 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/opr_param_defs.h"
#include "megdnn/basic_types.h"
#include <iostream>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"

#include "./rng.h"
namespace megdnn {
@@ -68,13 +69,15 @@ static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) {
std::vector<TestArg> args;
set_nchw_args(args);

if(imode == IMode::INTER_LINEAR) {
//! test NHWC with ch != 1 or ch != 3
if (imode == IMode::INTER_LINEAR) {
//! test NHWC with ch != 1 or ch != 3
param::Resize param;
param.format = param::Resize::Format::NHWC;
param.imode = imode;
args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 4, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 6, 4}, TensorShape{2, 2, 3, 4});
args.emplace_back(param, TensorShape{2, 2, 3, 4},
TensorShape{2, 4, 6, 4});
args.emplace_back(param, TensorShape{2, 4, 6, 4},
TensorShape{2, 2, 3, 4});
}
return args;
}
@@ -108,6 +111,48 @@ static inline std::vector<TestArg> get_nchw4_args() {
return args;
}

static inline std::vector<TestArg> get_nchw44_args() {
std::vector<TestArg> args;

param::Resize param;
param.format = param::Resize::Format::NCHW44;
param.imode = param::Resize::InterpolationMode::LINEAR;
rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
args.emplace_back(
param,
TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul},
TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul});

param.imode = param::Resize::InterpolationMode::NEAREST;
rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
args.emplace_back(
param,
TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 4ul},
TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 4ul});
return args;
}

static inline std::vector<TestArg> get_nchw88_args() {
std::vector<TestArg> args;

param::Resize param;
param.format = param::Resize::Format::NCHW88;
param.imode = param::Resize::InterpolationMode::LINEAR;
rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
args.emplace_back(
param,
TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul},
TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul});

param.imode = param::Resize::InterpolationMode::NEAREST;
rep(n, 4ul) rep(c, 4ul) rep(ih, 4ul) rep(iw, 4ul) rep(oh, 4ul) rep(ow, 4ul)
args.emplace_back(
param,
TensorShape{n + 1ul, c + 1ul, ih + 1ul, iw + 1ul, 8ul},
TensorShape{n + 1ul, c + 1ul, oh + 1ul, ow + 1ul, 8ul});
return args;
}

static inline std::vector<TestArg> get_cv_args() {
std::vector<TestArg> args;



+ 200
- 162
src/gopt/impl/tensor_reformat.cpp View File

@@ -68,87 +68,90 @@ using namespace gopt;
* oprs should not get involved in any actual computing.
*/
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
cg::SingleCNOperatorNodeBase) // {
cg::SingleCNOperatorNodeBase) // {
public:
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose
///< channel size less than 4
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout
//!< to nchw4 layout whose
//! channel size less than 4
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to
//!< nchw88 layout
WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout
//!< to nchw88 layout
//!< the weight layout of input is nchw output is nchw88, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
WEIGHT_HYBIRD_NCHW_NCHW88,
WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44
//!< layout
WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to
//!< nchw44 layout
WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout
//!< to nchw44 layout
//!< the weight layout of input is nchw output is nchw44, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
WEIGHT_HYBIRD_NCHW_NCHW44,
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout dense
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout group
NCHW32_TO_NCHW, //! <from nchw32 layout to nchw layout
NCHW32_TO_NCHW64, //! <from nchw32 layout to nchw64 layout
NCHW64_TO_NCHW, //! <from nchw64 layout to nchw layout
NCHW64_TO_NCHW4, //! <from nchw64 layout to nchw4 layout
NCHW64_TO_NCHW32, //! <from nchw64 layout to nchw32 layout
NCHW_TO_NCHW64, //! <from nchw layout to nchw64 layout
NCHW_TO_NCHW32, //! <from nchw layout to nchw64 layout
NCHW4_TO_NCHW64, //! <from nchw4 layout to nchw64 layout
NCHW_TO_NHWC, //! <NHWC related layout transformation
NCHW4_TO_NHWC,
NCHW32_TO_NHWC,
NCHW64_TO_NHWC,
NHWC_TO_NCHW,
NHWC_TO_NCHW4,
NHWC_TO_NCHW32,
NHWC_TO_NCHW64,
};
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose
///< channel size less than 4
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout
//!< to nchw4 layout whose
//! channel size less than 4
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to
//!< nchw88 layout
WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout
//!< to nchw88 layout
//!< the weight layout of input is nchw output is nchw88, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
WEIGHT_HYBIRD_NCHW_NCHW88,
WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44
//!< layout
WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to
//!< nchw44 layout
WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout
//!< to nchw44 layout
//!< the weight layout of input is nchw output is nchw44, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
WEIGHT_HYBIRD_NCHW_NCHW44,
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout dense
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout group
NCHW32_TO_NCHW, //! <from nchw32 layout to nchw layout
NCHW32_TO_NCHW64, //! <from nchw32 layout to nchw64 layout
NCHW64_TO_NCHW, //! <from nchw64 layout to nchw layout
NCHW64_TO_NCHW4, //! <from nchw64 layout to nchw4 layout
NCHW64_TO_NCHW32, //! <from nchw64 layout to nchw32 layout
NCHW_TO_NCHW64, //! <from nchw layout to nchw64 layout
NCHW_TO_NCHW32, //! <from nchw layout to nchw64 layout
NCHW4_TO_NCHW64, //! <from nchw4 layout to nchw64 layout
NCHW_TO_NHWC, //! <NHWC related layout transformation
NCHW4_TO_NHWC,
NCHW32_TO_NHWC,
NCHW64_TO_NHWC,
NHWC_TO_NCHW,
NHWC_TO_NCHW4,
NHWC_TO_NCHW32,
NHWC_TO_NCHW64,
};

RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);

/*!
* \param src_var the input var
* \param layout_type tensor layout transform type of this relayout
* placeholder as described in LayoutType
*/
static SymbolVar make(VarNode* src_var, LayoutType layout_type);
/*!
* \param src_var the input var
* \param layout_type tensor layout transform type of this relayout
* placeholder as described in LayoutType
*/
static SymbolVar make(VarNode* src_var, LayoutType layout_type);

LayoutType layout_type() const { return m_layout_type; }
LayoutType layout_type() const {
return m_layout_type;
}

private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const LayoutType m_layout_type;
};
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const LayoutType m_layout_type;
}
;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);

TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder(
@@ -1023,8 +1026,7 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
@@ -1036,7 +1038,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
@@ -1048,7 +1051,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
@@ -1865,8 +1869,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};

auto replace_deconv_opr = [trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
@@ -1881,7 +1885,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
opr->config());
}
VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0];
auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter);
auto deconv_mode =
trans_nchw4(deconv_opr.param().sparse, deconv_filter);
// src: NCHW --> NCWH4
if (deconv_src->shape().ndim != 5) {
mgb_assert(deconv_src->shape().ndim == 4);
@@ -2028,10 +2033,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// bias: NCHW --> NCHW4 when bias_dtype is not Float32
@@ -2047,10 +2052,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// z_inp: NCHW --> NCHW4 when bias_dtype is not Float32
@@ -2066,10 +2071,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
@@ -2210,8 +2215,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
auto&& replace_func = ret->m_opr_replace_func;
//! supportted nchw4
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
replace_deconv_opr;
replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
@@ -2348,6 +2352,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::Convolution::Format::NCHW88;
megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88;
megdnn::param::Resize::Format resize_format =
megdnn::param::Resize::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88";

if (pack_c_size == 4) {
@@ -2360,6 +2366,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::Convolution::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
resize_format = megdnn::param::Resize::Format::NCHW44;
convter_pass_name = "conv_format_nchw44";
}
auto test_trans_nchwxx =
@@ -2634,6 +2641,43 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return new_opr;
}
};

auto replace_resize_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& resize_opr = opr->cast_final_safe<opr::ResizeForward>();
mgb_throw_if(
resize_opr.param().format !=
megdnn::param::Resize::Format::NCHW &&
resize_opr.param().format !=
megdnn::param::Resize::Format::NHWC,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWxx");

VarNode* inp = new_inp[0];
if (resize_opr.param().format == megdnn::param::Resize::Format::NHWC) {
auto temp_inp = new_inp;
if (inp->shape().ndim == 5) {
auto new_var = RelayoutPlaceholder::make(inp, src_to_nchw_mode);
temp_inp[0] = new_var.node();
}
return serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
} else {
auto temp_inp = new_inp;
if (inp->shape().ndim == 5) {
auto new_param = resize_opr.param();
new_param.format = resize_format;
auto new_resize_opr = opr::ResizeForward::make(
new_inp[0], new_inp[1], new_param, opr->config());
return new_resize_opr.node()->owner_opr();
} else {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
}
};

//! When input change and all input can convert to nchwxx, this opr will run
//! in nchwxx mode, else it will run in nchw mode, for example concat and
//! elemwise opr
@@ -2704,6 +2748,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr;
@@ -2718,7 +2763,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
@@ -3236,26 +3280,27 @@ public:
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
cg::SingleCNOperatorNodeBase) // {
public:
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);

static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);

TensorFormat inp_format() const {
return m_inp_format;
}
TensorFormat inp_format() const {
return m_inp_format;
}

TensorFormat out_format() const {
return m_out_format;
}
TensorFormat out_format() const {
return m_out_format;
}

private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
const TensorFormat m_inp_format;
const TensorFormat m_out_format;
};
void init_output_static_infer_desc() override;
void scn_do_execute() override;
const TensorFormat m_inp_format;
const TensorFormat m_out_format;
}
;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);

@@ -3910,8 +3955,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
opr_set.insert(opr);

// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape->input(0)->owner_opr());
auto shuffle =
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
@@ -3981,10 +4026,9 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype});
rewriter.replace_var(
opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NHWC)"));
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NHWC)"));
return true;
};

@@ -4036,8 +4080,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW32;
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW32;
if (!is_s8nchw32)
return false;
if (conv_bias->input().size() != 3)
@@ -4078,9 +4122,8 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
&rewriter](OperatorNodeBase* opr) {
if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
!try_conv_reformat_nchw42nchw32(opr) &&
!try_conv_reformat_nchw42nhwc(opr)
&& !try_conv_reformat_nchw322nchw4(opr)
) {
!try_conv_reformat_nchw42nhwc(opr) &&
!try_conv_reformat_nchw322nchw4(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
@@ -4497,7 +4540,7 @@ void PaddingChannelPass::apply(OptState& opt) const {

/* ================ EnableNCHW64Pass =============== */
VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
auto iter = m_opr_format_map.find(new_var->owner_opr());
mgb_assert(iter != m_opr_format_map.end(),
@@ -4532,8 +4575,7 @@ VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}

std::unique_ptr<EnableNCHW64Pass>
EnableNCHW64Pass::make_nchw64_converter() {
std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
MIDOUT_B("EnableNCHW64Pass::make")
auto ret = std::make_unique<EnableNCHW64Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
@@ -4618,15 +4660,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size()==new_inp.size());
mgb_assert(opr->input().size() == new_inp.size());
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
mgb_assert(opr->output().size() > 0);
bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3) {
auto dtype_expect = dst_float ? DTypeEnum::Float32
: DTypeEnum::QuantizedS32;
auto dtype_expect =
dst_float ? DTypeEnum::Float32 : DTypeEnum::QuantizedS32;
check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
}
if (opr->input().size() >= 4) {
@@ -4677,12 +4719,13 @@ EnableNCHW64Pass::make_nchw64_converter() {
for (size_t i = 0; i < inps.size(); ++i) {
// do not format bias and z when dst_float is true
bool skip = dst_float && i >= 2;
if (!skip) inps[i] = process(i);
if (!skip)
inps[i] = process(i);
}
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret = make_new_conv(
inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
auto ret =
make_new_conv(inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
if (!dst_float)
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret;
@@ -4692,7 +4735,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size()==new_inp.size());
mgb_assert(opr->input().size() == new_inp.size());
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -4754,18 +4797,17 @@ EnableNCHW64Pass::make_nchw64_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size()==new_inp.size());
mgb_assert(opr->input().size() == new_inp.size());
bool check_dtype =
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
check_dtype &=
new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
@@ -4818,18 +4860,17 @@ EnableNCHW64Pass::make_nchw64_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size()==new_inp.size());
mgb_assert(opr->input().size() == new_inp.size());
bool check_dtype =
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
check_dtype &=
new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
@@ -4842,8 +4883,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC);
inps[i], RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC);
return ovar.node();
} else {
const auto& fmt = iter->second;
@@ -4973,7 +5013,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW4);
}
auto param = deconv.param();
param.format = Format::NCHW4;
auto new_deconv = opr::ConvolutionBackwardData::make(
@@ -4990,7 +5030,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
break;
}
}
mgb_assert(!shape_changed,
mgb_assert(!shape_changed,
"EnableNCHW64Pass won't change format of output tensor "
"of non quantized deconv operator(name:%s)",
opr->cname());
@@ -5000,8 +5040,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
};

// replace rule for elemwise like opr
auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto replace_elemwise_like_opr = [&format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
ThinHashMap<Format, size_t> format_size;
bool same_format = true;
@@ -5073,7 +5114,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
cur = Format::NCHW;
}
if (cur != max_format) {
inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]);
inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]);
}
}
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
@@ -5131,8 +5172,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
SymbolVar new_warp;
if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
inps[0], inps[1], inps[2], param,
warp.config());
inps[0], inps[1], inps[2], param, warp.config());
} else {
mgb_assert(inps.size() == 4);
new_warp = opr::WarpPerspectiveForward::make(
@@ -5179,14 +5219,13 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW4);
}
auto param = warp.param();
param.format = Format::NCHW4;
SymbolVar new_warp;
if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
inps[0], inps[1], inps[2], param,
warp.config());
inps[0], inps[1], inps[2], param, warp.config());
} else {
mgb_assert(inps.size() == 4);
new_warp = opr::WarpPerspectiveForward::make(
@@ -5204,7 +5243,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
break;
}
}
mgb_assert(!shape_changed,
mgb_assert(!shape_changed,
"EnableNCHW64Pass won't change format of output tensor "
"of non quantized warp perspective operator(name:%s)",
opr->cname());
@@ -5212,9 +5251,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
opr->config());
}
};
auto replace_pooling_opr = [&format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto replace_pooling_opr = [&format_map](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
@@ -5300,7 +5338,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
mgb_assert(cur == Format::NCHW4);
}
Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4;
auto param = pooling.param();
param.format = out_format;
auto new_pool =
@@ -5336,7 +5374,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
auto inps = new_inp;
for (size_t i = 0; i < opr->input().size(); ++i) {
auto iter = format_map.find(new_inp[i]->owner_opr());
auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
auto fmt = iter != format_map.end() ? iter->second : Format::NCHW;
if (iter != format_map.end()) {
switch (fmt) {
case Format::NHWC:


+ 4
- 2
src/opr/impl/imgproc.cpp View File

@@ -10,9 +10,9 @@
* implied.
*/

#include "megbrain/opr/imgproc.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"

@@ -340,7 +340,9 @@ void ResizeForward::outshape_by_symvar_do_get_output_shape(
//! The index of height, e.g.,[b, h, w, c], the height_idx = 1
size_t height_idx = 0;
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4) {
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW88) {
height_idx = 2;
} else {
height_idx = 1;


Loading…
Cancel
Save