@@ -1156,7 +1156,7 @@ void ConvolutionBackwardData::deduce_layout( | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert_contiguous(filter); | |||
megdnn_assert_contiguous(diff); | |||
megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str()); | |||
megdnn_assert(filter.ndim >= 4_z && filter.ndim <= 7_z, "%s", errmsg().c_str()); | |||
megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str()); | |||
deduce_dtype(filter.dtype, diff.dtype, grad.dtype); | |||
@@ -1193,11 +1193,12 @@ void ConvolutionBackwardData::deduce_layout( | |||
deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i], | |||
cflt.stride[i], cflt.padding[i]); | |||
} | |||
} else if (param().format == Param::Format::NCHW4) { | |||
} else if ( | |||
param().format == Param::Format::NCHW4 || | |||
param().format == Param::Format::NCHW44) { | |||
megdnn_assert( | |||
diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu", | |||
diff.ndim); | |||
megdnn_assert(cflt.group == 1, "%s", errmsg().c_str()); | |||
diff.ndim == 5, | |||
"valid diff ndim for NCHW4 and NCHW44, expected=5, got=%zu", diff.ndim); | |||
megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str()); | |||
grad.ndim = diff.ndim; | |||
grad[0] = diff[0]; | |||
@@ -29,14 +29,19 @@ Relayout* get_relayout_opr() { | |||
MatrixMul* get_matmul_opr(const NCBKernSizeParam& param) { | |||
using ConvCM = param::Convolution::ComputeMode; | |||
using MmCM = param::MatrixMul::ComputeMode; | |||
static CpuOprDelegationStorage<2> storage; | |||
static CpuOprDelegationStorage<3> storage; | |||
if (param.filter_meta.format == param::Convolution::Format::NCHW44) { | |||
MatrixMul::Param p; | |||
p.format = param::MatrixMul::Format::MK4; | |||
return storage.get<MatrixMul, 0>(p); | |||
} | |||
switch (param.compute_mode) { | |||
default: | |||
return storage.get<MatrixMul, 0>({}); | |||
return storage.get<MatrixMul, 1>({}); | |||
case ConvCM::FLOAT32: { | |||
MatrixMul::Param p; | |||
p.compute_mode = MmCM::FLOAT32; | |||
return storage.get<MatrixMul, 1>(p); | |||
return storage.get<MatrixMul, 2>(p); | |||
} | |||
} | |||
} | |||
@@ -58,7 +63,14 @@ WorkspaceBundle get_bundle(const NCBKernSizeParam& param) { | |||
part0 = (IC * FH * FW * IH * IW) * param.grad_type.size(); | |||
} | |||
part2 = (OC * IC * FH * FW) * param.filter_type.size(); | |||
{ | |||
if (param.filter_meta.format == param::Convolution::Format::NCHW44) { | |||
TensorLayout A_, B_, C_; | |||
A_ = TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type); | |||
B_ = TensorLayout({OC / 4, IH * IW}, param.diff_type); | |||
C_ = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type); | |||
auto matmul_algo = get_matmul_opr(param); | |||
part1 = matmul_algo->get_workspace_in_bytes(A_, B_, C_); | |||
} else { | |||
TensorLayout A_, B_, C_; | |||
A_ = TensorLayout({IC * FH * FW, OC}, param.filter_type); | |||
B_ = TensorLayout({OC, IH * IW}, param.diff_type); | |||
@@ -573,4 +585,119 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::is_preferred( | |||
return is_matrix_mul_preferred(param); | |||
} | |||
/* ===================== Matrix mul nchw44 algo ===================== */ | |||
namespace{ | |||
void kern_matmul_nchw44(const NCBKernParam& param) { | |||
bool is_xcorr = !param.filter_meta.should_flip; | |||
UNPACK_CONV_F32_NCB_KERN_SIZES(param); | |||
auto bundle = get_bundle(param); | |||
bundle.set(param.workspace_ptr); | |||
bool is1X1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0); | |||
typedef void (*Func1)(const float*, float*, int, int, int, int, int, int, int); | |||
typedef void (*Func2)( | |||
const float*, float*, int, int, int, int, int, int, int, int, int, int, | |||
int); | |||
Func1 f1 = nullptr; | |||
Func2 f2 = nullptr; | |||
if (is_xcorr) { | |||
f1 = col2img_nchw44<true>; | |||
f2 = col2img_stride_padding_nchw44<true>; | |||
} else { | |||
f1 = col2img_nchw44<false>; | |||
f2 = col2img_stride_padding_nchw44<false>; | |||
} | |||
float* filter = const_cast<float*>(param.filter<float>()); | |||
TensorND A_src, A_dst; | |||
{ | |||
A_src.layout = TensorLayout( | |||
{IC / 4 * FH * FW, OC / 4, 4, 4}, | |||
{ | |||
static_cast<std::ptrdiff_t>(16), | |||
static_cast<std::ptrdiff_t>(IC * FH * FW * 4), | |||
static_cast<std::ptrdiff_t>(1), | |||
static_cast<std::ptrdiff_t>(4), | |||
}, | |||
param.filter_type); | |||
A_src.reset_ptr(static_cast<void*>(filter)); | |||
A_dst.layout = | |||
TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type); | |||
A_dst.reset_ptr(static_cast<void*>(bundle.get(2))); | |||
// TODO Should be removed once armv8 convolution support transpose. | |||
get_relayout_opr()->exec(A_src, A_dst, inplace_cpu_handle().get()); | |||
} | |||
TensorND B_, C_; | |||
for (size_t n = 0; n < N; ++n) { | |||
float*C_src, *C_dst; | |||
float* diff = const_cast<float*>(param.diff<float>() + n * param.inp_bs); | |||
float* grad = param.grad<float>() + n * param.out_bs; | |||
if (is1X1) { | |||
C_src = grad; | |||
} else { | |||
C_src = static_cast<float*>(bundle.get(0)); | |||
} | |||
{ | |||
B_.layout = TensorLayout({OC/4, IH * IW, 4}, param.diff_type); | |||
B_.reset_ptr(static_cast<void*>(diff)); | |||
C_.layout = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type); | |||
C_.reset_ptr(C_src); | |||
Workspace workspace( | |||
static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1)); | |||
auto matmul_opr =get_matmul_opr(param); | |||
matmul_opr->exec(A_dst, B_, C_, workspace); | |||
} | |||
if (!is1X1) { | |||
C_dst = grad; | |||
std::memset(C_dst, 0, param.grad_type.size() * IC * OH * OW); | |||
if (PH == 0 && PW == 0 && SH == 1 && SW == 1) { | |||
f1(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW); | |||
} else { | |||
f2(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW, SH, SW, PH, PW); | |||
} | |||
} | |||
} | |||
} | |||
} // namespace | |||
bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::usable( | |||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
return fm.format == param::Convolution::Format::NCHW44 && | |||
param.diff_type.enumv() == DTypeTrait<dtype::Float32>::enumv && | |||
param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv && | |||
param.grad_type.enumv() == DTypeTrait<dtype::Float32>::enumv && | |||
fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1 && fm.icpg % 4 == 0 && fm.ocpg % 4 == 0; | |||
} | |||
size_t ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::get_workspace( | |||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_fallback_deconv, | |||
midout_iv("AlgoMatrixMulNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: | |||
AlgoMatrixMulNCHW44::dispatch_kern( | |||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | |||
if (param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv) { | |||
MIDOUT_BEGIN(megdnn_fallback_deconv, midout_iv("FLOAT_NCHW44"_hash)) { | |||
return kern_matmul_nchw44; | |||
} | |||
MIDOUT_END(); | |||
} | |||
megdnn_throw("unsupported data type on matrix mul"); | |||
} | |||
bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::is_preferred( | |||
const NCBKernSizeParam& param) const { | |||
return is_matrix_mul_preferred(param); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -198,6 +198,20 @@ public: | |||
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL) | |||
}; | |||
class ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44 final : public AlgoBase { | |||
public: | |||
const char* name() const override { return "DeconvMatmulNchw44"; } | |||
bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) | |||
const override; | |||
size_t get_workspace( | |||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override; | |||
ncb_kern_t dispatch_kern( | |||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL_NCHW44) | |||
}; | |||
} // namespace fallback | |||
} // namespace megdnn | |||
@@ -1,5 +1,6 @@ | |||
#include <cstddef> | |||
#include "src/common/utils.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
namespace { | |||
@@ -61,6 +62,72 @@ void col2img( | |||
} | |||
} | |||
template <bool is_xcorr> | |||
void col2img_stride_padding_nchw44( | |||
const float* __restrict src, float* __restrict dst, const int OH, const int OW, | |||
const int IC, const int IH, const int IW, const int FH, const int FW, | |||
const int SH, const int SW, int PH, int PW) { | |||
size_t i = 0; | |||
rep(ic, IC / 4) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
rep(ih, IH) { | |||
int h = ih * SH - PH + fh2; | |||
rep(iw, IW) { | |||
int w = iw * SW - PW + fw2; | |||
if (h >= 0 && h < OH && w >= 0 && w < OW) { | |||
float* dst_ptr = dst + (ic * OH * OW + h * OW + w) * 4; | |||
GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr); | |||
GI_FLOAT32_t src_data = GiLoadFloat32(src+i); | |||
GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data)); | |||
} | |||
i += 4; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr> | |||
void col2img_nchw44( | |||
const float* __restrict src, float* __restrict dst, const int OH, const int OW, | |||
const int IC, const int IH, const int IW, const int FH, const int FW) { | |||
size_t i = 0; | |||
rep(ic, IC / 4) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2, fw2; | |||
if (is_xcorr) { | |||
fh2 = fh; | |||
fw2 = fw; | |||
} else { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
rep(ih, IH) { | |||
rep(iw, IW) { | |||
float* dst_ptr = dst + ic * OH * OW * 4 + (ih + fh2) * OW * 4 + | |||
iw * 4 + fw2 * 4; | |||
GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr); | |||
GI_FLOAT32_t src_data = GiLoadFloat32(src + i); | |||
GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data)); | |||
i += 4; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} // anonymous namespace | |||
// vim: syntax=cpp.doxygen |
@@ -437,11 +437,13 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||
AlgoNaive algo_naive; | |||
AlgoDirect algo_direct; | |||
AlgoMatrixMul algo_matmul; | |||
AlgoMatrixMulNCHW44 algo_matmul_nchw44; | |||
SmallVector<AlgoBase*> m_all_algos; | |||
AlgoBase::Mapper m_all_algos_map; | |||
public: | |||
AlgoPack() { | |||
m_all_algos.emplace_back(&algo_matmul_nchw44); | |||
m_all_algos.emplace_back(&algo_matmul); | |||
m_all_algos.emplace_back(&algo_direct); | |||
m_all_algos.emplace_back(&algo_naive); | |||
@@ -557,7 +559,8 @@ ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl:: | |||
return v; | |||
}; | |||
size_t spatial_pos; | |||
if (param().format == Param::Format::NCHW) { | |||
if (param().format == Param::Format::NCHW || | |||
param().format == Param::Format::NCHW44) { | |||
spatial_pos = 2; | |||
} else { | |||
megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format"); | |||
@@ -622,7 +625,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) | |||
} else { | |||
megdnn_assert( | |||
p1g.filter_meta.format == Param::Format::NCHW || | |||
p1g.filter_meta.format == Param::Format::NHWC, | |||
p1g.filter_meta.format == Param::Format::NHWC || | |||
p1g.filter_meta.format == Param::Format::NCHW44, | |||
"invalid conv format"); | |||
auto run = [kptr, p1g_orig = p1g, group]() { | |||
auto p1g = p1g_orig; | |||
@@ -640,7 +644,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param) | |||
p1g.filter_type.size(); | |||
p1g.grad_extra_mem_size = | |||
(group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size(); | |||
if (p1g.filter_meta.format == Param::Format::NCHW) { | |||
if (p1g.filter_meta.format == Param::Format::NCHW || | |||
p1g.filter_meta.format == Param::Format::NCHW44) { | |||
istrd *= p1g.isz[0] * p1g.isz[1]; | |||
ostrd *= p1g.osz[0] * p1g.osz[1]; | |||
p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1]; | |||
@@ -392,6 +392,7 @@ protected: | |||
FB_NAIVE = 1 << 0, | |||
FB_DIRECT, | |||
FB_MATMUL, | |||
FB_MATMUL_NCHW44, | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32 = 1 << 8, | |||
@@ -480,6 +481,7 @@ private: | |||
class AlgoNaive; | |||
class AlgoDirect; | |||
class AlgoMatrixMul; | |||
class AlgoMatrixMulNCHW44; | |||
class AlgoPack; | |||
Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; | |||
@@ -463,6 +463,60 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA) { | |||
} | |||
} | |||
TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NCHW44) { | |||
Checker<ConvolutionBackwardData> checker(handle()); | |||
using Param = ConvolutionBackwardData::Param; | |||
Param param; | |||
param.format = Param::Format::NCHW44; | |||
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh, | |||
size_t fw, size_t stride, size_t padding, size_t dilate = 1, | |||
size_t group = 1) { | |||
param.pad_h = param.pad_w = padding; | |||
param.stride_h = param.stride_w = stride; | |||
param.dilate_h = param.dilate_w = dilate; | |||
TensorLayout diff = | |||
TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()}; | |||
TensorLayout grad; | |||
TensorLayout filter; | |||
if (group == 1) { | |||
param.sparse = Param::Sparse::DENSE; | |||
filter = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()}; | |||
} else { | |||
param.sparse = Param::Sparse::GROUP; | |||
filter = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()}; | |||
} | |||
// TensorLayout grad; | |||
{ | |||
auto opr = handle()->create_operator<ConvolutionBackwardData>(); | |||
opr->param() = param; | |||
opr->deduce_layout(filter, diff, grad); | |||
} | |||
checker.set_param(param) | |||
.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()); | |||
checker.exec(TensorLayoutArray{filter, diff, grad}); | |||
}; | |||
for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) { | |||
param.mode = mode; | |||
run(1, 4, 2, 2, 4, 1, 1, 1, 0, 1, 1); | |||
run(1, 4, 2, 2, 4, 3, 3, 1, 0, 1, 1); | |||
run(1, 4, 2, 2, 4, 3, 3, 1, 1, 1, 1); | |||
run(4, 16, 10, 13, 16, 1, 1, 1, 0, 1, 1); | |||
run(4, 16, 10, 13, 16, 3, 3, 1, 0, 1, 1); | |||
run(4, 16, 10, 13, 16, 3, 3, 1, 1, 1, 1); | |||
run(4, 32, 11, 23, 32, 1, 1, 1, 0, 1, 4); | |||
run(4, 16, 11, 23, 8, 3, 3, 1, 0, 1, 4); | |||
run(4, 16, 11, 23, 8, 3, 3, 1, 1, 1, 4); | |||
run(4, 16, 11, 23, 8, 3, 3, 2, 1, 1, 4); | |||
} | |||
} | |||
TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_RECORD) { | |||
TaskRecordChecker<ConvolutionBackwardData> checker(1); | |||
using Param = ConvolutionBackwardData::Param; | |||
@@ -707,4 +761,73 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NAIVE_ALGO) { | |||
} | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(FALLBACK, BENCHMARK_CONVOLUTION_BACKWARD_DATA_NCHW44) { | |||
using Param = ConvolutionBackwardData::Param; | |||
auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh, | |||
size_t fw, size_t stride, size_t padding, size_t dilate = 1, | |||
size_t group = 1) { | |||
Param param; | |||
param.pad_h = param.pad_w = padding; | |||
param.stride_h = param.stride_w = stride; | |||
param.dilate_h = param.dilate_w = dilate; | |||
TensorLayout diff_nchw44 = | |||
TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()}; | |||
TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()}; | |||
TensorLayout grad; | |||
TensorLayout grad_nchw44; | |||
TensorLayout filter; | |||
TensorLayout filter_nchw44; | |||
if (group == 1) { | |||
param.sparse = Param::Sparse::DENSE; | |||
filter_nchw44 = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()}; | |||
filter = {{oc, ic, fh, fw}, dtype::Float32()}; | |||
} else { | |||
param.sparse = Param::Sparse::GROUP; | |||
filter_nchw44 = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()}; | |||
filter = {{group, oc, ic, fh, fw}, dtype::Float32()}; | |||
} | |||
{ | |||
auto opr = handle()->create_operator<ConvolutionBackwardData>(); | |||
opr->param() = param; | |||
opr->deduce_layout(filter, diff, grad); | |||
opr->param().format = Param::Format::NCHW44; | |||
opr->deduce_layout(filter_nchw44, diff_nchw44, grad_nchw44); | |||
} | |||
Benchmarker<ConvolutionBackwardData> benchmarker_fallback(handle()); | |||
size_t RUN = 50; | |||
benchmarker_fallback.set_display(false) | |||
.set_dtype(0, dtype::Float32{}) | |||
.set_dtype(1, dtype::Float32{}) | |||
.set_dtype(2, dtype::Float32{}) | |||
.set_times(RUN); | |||
auto tnchw = | |||
benchmarker_fallback.set_param(param) | |||
.exec(TensorLayoutArray{filter, diff, grad}); | |||
param.format = Param::Format::NCHW44; | |||
auto tnchw44 = | |||
benchmarker_fallback.set_param(param) | |||
.exec(TensorLayoutArray{filter_nchw44, diff_nchw44, grad_nchw44}); | |||
size_t IC = ic; | |||
size_t FH = fh; | |||
size_t FW = fw; | |||
size_t total_flops = IC * diff.total_nr_elems() * FH * FW * 2; | |||
printf("nchw_time: %.3f ms nchw_flops: %.3f Gflops\n", tnchw, | |||
total_flops / (tnchw / RUN * 1e6)); | |||
printf("nchw44_time: %.3f ms nchw44_flops: %.3f Gflops\n", tnchw44, | |||
total_flops / (tnchw44 / RUN * 1e6)); | |||
printf("speedup: %.3f\n", tnchw / tnchw44); | |||
}; | |||
run(1, 16, 14, 14, 16, 3, 3, 1, 1, 1, 1); | |||
run(1, 32, 28, 28, 16, 3, 3, 1, 1, 1, 1); | |||
run(1, 48, 28, 28, 48, 2, 2, 1, 0, 1, 1); | |||
run(1, 32, 26, 26, 32, 3, 3, 1, 0, 1, 1); | |||
run(2, 32, 64, 64, 32, 3, 3, 1, 0, 1, 1); | |||
run(2, 16, 112, 112, 16, 3, 3, 1, 0, 1, 1); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |