Browse Source

Merge pull request #484 from wangxiang9603:add-nchw44-deconv

GitOrigin-RevId: 444429a625
release-1.11.1
huangxinda Megvii Engine Team 2 years ago
parent
commit
a07fbf79f7
7 changed files with 351 additions and 12 deletions
  1. +6
    -5
      dnn/src/common/convolution.cpp
  2. +131
    -4
      dnn/src/fallback/convolution/algos.cpp
  3. +14
    -0
      dnn/src/fallback/convolution/algos.h
  4. +67
    -0
      dnn/src/fallback/convolution/col2img_helper.h
  5. +8
    -3
      dnn/src/fallback/convolution/opr_impl.cpp
  6. +2
    -0
      dnn/src/fallback/convolution/opr_impl.h
  7. +123
    -0
      dnn/test/fallback/convolution.cpp

+ 6
- 5
dnn/src/common/convolution.cpp View File

@@ -1186,7 +1186,7 @@ void ConvolutionBackwardData::deduce_layout(
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert_contiguous(filter);
megdnn_assert_contiguous(diff);
megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
megdnn_assert(filter.ndim >= 4_z && filter.ndim <= 7_z, "%s", errmsg().c_str());
megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());

deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
@@ -1223,11 +1223,12 @@ void ConvolutionBackwardData::deduce_layout(
deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
cflt.stride[i], cflt.padding[i]);
}
} else if (param().format == Param::Format::NCHW4) {
} else if (
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44) {
megdnn_assert(
diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
diff.ndim);
megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
diff.ndim == 5,
"valid diff ndim for NCHW4 and NCHW44, expected=5, got=%zu", diff.ndim);
megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
grad.ndim = diff.ndim;
grad[0] = diff[0];


+ 131
- 4
dnn/src/fallback/convolution/algos.cpp View File

@@ -29,14 +29,19 @@ Relayout* get_relayout_opr() {
MatrixMul* get_matmul_opr(const NCBKernSizeParam& param) {
using ConvCM = param::Convolution::ComputeMode;
using MmCM = param::MatrixMul::ComputeMode;
static CpuOprDelegationStorage<2> storage;
static CpuOprDelegationStorage<3> storage;
if (param.filter_meta.format == param::Convolution::Format::NCHW44) {
MatrixMul::Param p;
p.format = param::MatrixMul::Format::MK4;
return storage.get<MatrixMul, 0>(p);
}
switch (param.compute_mode) {
default:
return storage.get<MatrixMul, 0>({});
return storage.get<MatrixMul, 1>({});
case ConvCM::FLOAT32: {
MatrixMul::Param p;
p.compute_mode = MmCM::FLOAT32;
return storage.get<MatrixMul, 1>(p);
return storage.get<MatrixMul, 2>(p);
}
}
}
@@ -58,7 +63,14 @@ WorkspaceBundle get_bundle(const NCBKernSizeParam& param) {
part0 = (IC * FH * FW * IH * IW) * param.grad_type.size();
}
part2 = (OC * IC * FH * FW) * param.filter_type.size();
{
if (param.filter_meta.format == param::Convolution::Format::NCHW44) {
TensorLayout A_, B_, C_;
A_ = TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type);
B_ = TensorLayout({OC / 4, IH * IW}, param.diff_type);
C_ = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type);
auto matmul_algo = get_matmul_opr(param);
part1 = matmul_algo->get_workspace_in_bytes(A_, B_, C_);
} else {
TensorLayout A_, B_, C_;
A_ = TensorLayout({IC * FH * FW, OC}, param.filter_type);
B_ = TensorLayout({OC, IH * IW}, param.diff_type);
@@ -573,4 +585,119 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::is_preferred(
return is_matrix_mul_preferred(param);
}

/* ===================== Matrix mul nchw44 algo ===================== */
namespace{
void kern_matmul_nchw44(const NCBKernParam& param) {
bool is_xcorr = !param.filter_meta.should_flip;
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
auto bundle = get_bundle(param);
bundle.set(param.workspace_ptr);
bool is1X1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);

typedef void (*Func1)(const float*, float*, int, int, int, int, int, int, int);
typedef void (*Func2)(
const float*, float*, int, int, int, int, int, int, int, int, int, int,
int);
Func1 f1 = nullptr;
Func2 f2 = nullptr;
if (is_xcorr) {
f1 = col2img_nchw44<true>;
f2 = col2img_stride_padding_nchw44<true>;
} else {
f1 = col2img_nchw44<false>;
f2 = col2img_stride_padding_nchw44<false>;
}
float* filter = const_cast<float*>(param.filter<float>());
TensorND A_src, A_dst;
{
A_src.layout = TensorLayout(
{IC / 4 * FH * FW, OC / 4, 4, 4},
{
static_cast<std::ptrdiff_t>(16),
static_cast<std::ptrdiff_t>(IC * FH * FW * 4),
static_cast<std::ptrdiff_t>(1),
static_cast<std::ptrdiff_t>(4),
},
param.filter_type);
A_src.reset_ptr(static_cast<void*>(filter));
A_dst.layout =
TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type);
A_dst.reset_ptr(static_cast<void*>(bundle.get(2)));
// TODO Should be removed once armv8 convolution support transpose.
get_relayout_opr()->exec(A_src, A_dst, inplace_cpu_handle().get());
}
TensorND B_, C_;
for (size_t n = 0; n < N; ++n) {
float*C_src, *C_dst;
float* diff = const_cast<float*>(param.diff<float>() + n * param.inp_bs);
float* grad = param.grad<float>() + n * param.out_bs;
if (is1X1) {
C_src = grad;
} else {
C_src = static_cast<float*>(bundle.get(0));
}
{
B_.layout = TensorLayout({OC/4, IH * IW, 4}, param.diff_type);
B_.reset_ptr(static_cast<void*>(diff));
C_.layout = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type);
C_.reset_ptr(C_src);
Workspace workspace(
static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1));
auto matmul_opr =get_matmul_opr(param);
matmul_opr->exec(A_dst, B_, C_, workspace);
}

if (!is1X1) {
C_dst = grad;
std::memset(C_dst, 0, param.grad_type.size() * IC * OH * OW);
if (PH == 0 && PW == 0 && SH == 1 && SW == 1) {
f1(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW);
} else {
f2(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW, SH, SW, PH, PW);
}
}
}
}
} // namespace

bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::usable(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
return fm.format == param::Convolution::Format::NCHW44 &&
param.diff_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
param.grad_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.icpg % 4 == 0 && fm.ocpg % 4 == 0;
}

size_t ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_fallback_deconv,
midout_iv("AlgoMatrixMulNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
AlgoMatrixMulNCHW44::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
if (param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv) {
MIDOUT_BEGIN(megdnn_fallback_deconv, midout_iv("FLOAT_NCHW44"_hash)) {
return kern_matmul_nchw44;
}
MIDOUT_END();
}

megdnn_throw("unsupported data type on matrix mul");
}

bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::is_preferred(
const NCBKernSizeParam& param) const {
return is_matrix_mul_preferred(param);
}

// vim: syntax=cpp.doxygen

+ 14
- 0
dnn/src/fallback/convolution/algos.h View File

@@ -198,6 +198,20 @@ public:
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL)
};

class ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44 final : public AlgoBase {
public:
const char* name() const override { return "DeconvMatmulNchw44"; }
bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param)
const override;
size_t get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override;
ncb_kern_t dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override;
bool is_preferred(const NCBKernSizeParam& param) const override;
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL_NCHW44)
};

} // namespace fallback
} // namespace megdnn



+ 67
- 0
dnn/src/fallback/convolution/col2img_helper.h View File

@@ -1,5 +1,6 @@
#include <cstddef>
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

namespace {

@@ -61,6 +62,72 @@ void col2img(
}
}

template <bool is_xcorr>
void col2img_stride_padding_nchw44(
const float* __restrict src, float* __restrict dst, const int OH, const int OW,
const int IC, const int IH, const int IW, const int FH, const int FW,
const int SH, const int SW, int PH, int PW) {
size_t i = 0;
rep(ic, IC / 4) {
rep(fh, FH) {
rep(fw, FW) {
int fh2, fw2;
if (is_xcorr) {
fh2 = fh;
fw2 = fw;
} else {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
rep(ih, IH) {
int h = ih * SH - PH + fh2;
rep(iw, IW) {
int w = iw * SW - PW + fw2;
if (h >= 0 && h < OH && w >= 0 && w < OW) {
float* dst_ptr = dst + (ic * OH * OW + h * OW + w) * 4;
GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr);
GI_FLOAT32_t src_data = GiLoadFloat32(src+i);
GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data));
}
i += 4;
}
}
}
}
}
}

template <bool is_xcorr>
void col2img_nchw44(
const float* __restrict src, float* __restrict dst, const int OH, const int OW,
const int IC, const int IH, const int IW, const int FH, const int FW) {
size_t i = 0;
rep(ic, IC / 4) {
rep(fh, FH) {
rep(fw, FW) {
int fh2, fw2;
if (is_xcorr) {
fh2 = fh;
fw2 = fw;
} else {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
rep(ih, IH) {
rep(iw, IW) {
float* dst_ptr = dst + ic * OH * OW * 4 + (ih + fh2) * OW * 4 +
iw * 4 + fw2 * 4;
GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr);
GI_FLOAT32_t src_data = GiLoadFloat32(src + i);
GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data));
i += 4;
}
}
}
}
}
}

} // anonymous namespace

// vim: syntax=cpp.doxygen

+ 8
- 3
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -437,11 +437,13 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
AlgoNaive algo_naive;
AlgoDirect algo_direct;
AlgoMatrixMul algo_matmul;
AlgoMatrixMulNCHW44 algo_matmul_nchw44;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;

public:
AlgoPack() {
m_all_algos.emplace_back(&algo_matmul_nchw44);
m_all_algos.emplace_back(&algo_matmul);
m_all_algos.emplace_back(&algo_direct);
m_all_algos.emplace_back(&algo_naive);
@@ -557,7 +559,8 @@ ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::
return v;
};
size_t spatial_pos;
if (param().format == Param::Format::NCHW) {
if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW44) {
spatial_pos = 2;
} else {
megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
@@ -622,7 +625,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param)
} else {
megdnn_assert(
p1g.filter_meta.format == Param::Format::NCHW ||
p1g.filter_meta.format == Param::Format::NHWC,
p1g.filter_meta.format == Param::Format::NHWC ||
p1g.filter_meta.format == Param::Format::NCHW44,
"invalid conv format");
auto run = [kptr, p1g_orig = p1g, group]() {
auto p1g = p1g_orig;
@@ -640,7 +644,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param)
p1g.filter_type.size();
p1g.grad_extra_mem_size =
(group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
if (p1g.filter_meta.format == Param::Format::NCHW) {
if (p1g.filter_meta.format == Param::Format::NCHW ||
p1g.filter_meta.format == Param::Format::NCHW44) {
istrd *= p1g.isz[0] * p1g.isz[1];
ostrd *= p1g.osz[0] * p1g.osz[1];
p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];


+ 2
- 0
dnn/src/fallback/convolution/opr_impl.h View File

@@ -392,6 +392,7 @@ protected:
FB_NAIVE = 1 << 0,
FB_DIRECT,
FB_MATMUL,
FB_MATMUL_NCHW44,

#if MEGDNN_AARCH64 || MEGDNN_ARMV7
ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32 = 1 << 8,
@@ -480,6 +481,7 @@ private:
class AlgoNaive;
class AlgoDirect;
class AlgoMatrixMul;
class AlgoMatrixMulNCHW44;
class AlgoPack;
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;



+ 123
- 0
dnn/test/fallback/convolution.cpp View File

@@ -463,6 +463,60 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA) {
}
}

TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NCHW44) {
Checker<ConvolutionBackwardData> checker(handle());
using Param = ConvolutionBackwardData::Param;

Param param;
param.format = Param::Format::NCHW44;
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
size_t fw, size_t stride, size_t padding, size_t dilate = 1,
size_t group = 1) {
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;

TensorLayout diff =
TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()};
TensorLayout grad;
TensorLayout filter;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
}
// TensorLayout grad;
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
}
checker.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32());
checker.exec(TensorLayoutArray{filter, diff, grad});
};

for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
param.mode = mode;
run(1, 4, 2, 2, 4, 1, 1, 1, 0, 1, 1);
run(1, 4, 2, 2, 4, 3, 3, 1, 0, 1, 1);
run(1, 4, 2, 2, 4, 3, 3, 1, 1, 1, 1);

run(4, 16, 10, 13, 16, 1, 1, 1, 0, 1, 1);
run(4, 16, 10, 13, 16, 3, 3, 1, 0, 1, 1);
run(4, 16, 10, 13, 16, 3, 3, 1, 1, 1, 1);

run(4, 32, 11, 23, 32, 1, 1, 1, 0, 1, 4);

run(4, 16, 11, 23, 8, 3, 3, 1, 0, 1, 4);
run(4, 16, 11, 23, 8, 3, 3, 1, 1, 1, 4);
run(4, 16, 11, 23, 8, 3, 3, 2, 1, 1, 4);
}
}

TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_RECORD) {
TaskRecordChecker<ConvolutionBackwardData> checker(1);
using Param = ConvolutionBackwardData::Param;
@@ -707,4 +761,73 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NAIVE_ALGO) {
}
}

#if MEGDNN_WITH_BENCHMARK

TEST_F(FALLBACK, BENCHMARK_CONVOLUTION_BACKWARD_DATA_NCHW44) {
using Param = ConvolutionBackwardData::Param;
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
size_t fw, size_t stride, size_t padding, size_t dilate = 1,
size_t group = 1) {
Param param;
param.pad_h = param.pad_w = padding;
param.stride_h = param.stride_w = stride;
param.dilate_h = param.dilate_w = dilate;

TensorLayout diff_nchw44 =
TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()};
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
TensorLayout grad;
TensorLayout grad_nchw44;
TensorLayout filter;
TensorLayout filter_nchw44;
if (group == 1) {
param.sparse = Param::Sparse::DENSE;
filter_nchw44 = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
filter = {{oc, ic, fh, fw}, dtype::Float32()};
} else {
param.sparse = Param::Sparse::GROUP;
filter_nchw44 = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
}
{
auto opr = handle()->create_operator<ConvolutionBackwardData>();
opr->param() = param;
opr->deduce_layout(filter, diff, grad);
opr->param().format = Param::Format::NCHW44;
opr->deduce_layout(filter_nchw44, diff_nchw44, grad_nchw44);
}
Benchmarker<ConvolutionBackwardData> benchmarker_fallback(handle());
size_t RUN = 50;
benchmarker_fallback.set_display(false)
.set_dtype(0, dtype::Float32{})
.set_dtype(1, dtype::Float32{})
.set_dtype(2, dtype::Float32{})
.set_times(RUN);

auto tnchw =
benchmarker_fallback.set_param(param)
.exec(TensorLayoutArray{filter, diff, grad});
param.format = Param::Format::NCHW44;
auto tnchw44 =
benchmarker_fallback.set_param(param)
.exec(TensorLayoutArray{filter_nchw44, diff_nchw44, grad_nchw44});
size_t IC = ic;
size_t FH = fh;
size_t FW = fw;
size_t total_flops = IC * diff.total_nr_elems() * FH * FW * 2;
printf("nchw_time: %.3f ms nchw_flops: %.3f Gflops\n", tnchw,
total_flops / (tnchw / RUN * 1e6));
printf("nchw44_time: %.3f ms nchw44_flops: %.3f Gflops\n", tnchw44,
total_flops / (tnchw44 / RUN * 1e6));
printf("speedup: %.3f\n", tnchw / tnchw44);
};
run(1, 16, 14, 14, 16, 3, 3, 1, 1, 1, 1);
run(1, 32, 28, 28, 16, 3, 3, 1, 1, 1, 1);
run(1, 48, 28, 28, 48, 2, 2, 1, 0, 1, 1);
run(1, 32, 26, 26, 32, 3, 3, 1, 0, 1, 1);
run(2, 32, 64, 64, 32, 3, 3, 1, 0, 1, 1);
run(2, 16, 112, 112, 16, 3, 3, 1, 0, 1, 1);
}
#endif

// vim: syntax=cpp.doxygen

Loading…
Cancel
Save