Browse Source

feat(dnn/arm_common): add 9x9s1s2 dot chanwise kernel

GitOrigin-RevId: a28a97fcb5
release-1.10
Megvii Engine Team 3 years ago
parent
commit
03f78547f7
26 changed files with 2190 additions and 16 deletions
  1. +4
    -1
      dnn/scripts/opr_param_defs.py
  2. +31
    -0
      dnn/src/arm_common/conv_bias/int8/algos.h
  3. +270
    -0
      dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp
  4. +425
    -0
      dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp
  5. +27
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h
  6. +40
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h
  7. +221
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp
  8. +250
    -0
      dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp
  9. +4
    -0
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  10. +2
    -0
      dnn/src/arm_common/conv_bias/opr_impl.h
  11. +82
    -1
      dnn/src/arm_common/matrix_mul/algos.cpp
  12. +62
    -1
      dnn/src/arm_common/matrix_mul/algos.h
  13. +394
    -0
      dnn/src/arm_common/matrix_mul/int8/gemv.cpp
  14. +8
    -0
      dnn/src/arm_common/matrix_mul/int8/gemv.h
  15. +4
    -0
      dnn/src/arm_common/matrix_mul/opr_impl.cpp
  16. +3
    -1
      dnn/src/arm_common/matrix_mul/opr_impl.h
  17. +8
    -1
      dnn/src/arm_common/simd_macro/marm_neon.h
  18. +21
    -0
      dnn/src/common/matrix_mul.cpp
  19. +3
    -1
      dnn/src/fallback/conv_bias/opr_impl.cpp
  20. +3
    -0
      dnn/src/fallback/conv_bias/opr_impl.h
  21. +5
    -5
      dnn/src/fallback/matrix_mul/opr_impl.cpp
  22. +3
    -0
      dnn/src/fallback/matrix_mul/opr_impl.h
  23. +32
    -0
      dnn/src/naive/matrix_mul/matrix_mul_helper.h
  24. +78
    -4
      dnn/test/arm_common/conv_bias.cpp
  25. +103
    -1
      dnn/test/arm_common/conv_bias_multi_thread.cpp
  26. +107
    -0
      dnn/test/arm_common/matrix_mul.cpp

+ 4
- 1
dnn/scripts/opr_param_defs.py View File

@@ -557,7 +557,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod:'
'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'),
Doc('N32K4_DOT = 4', 'Split 32 from N and 4 from K, better for neon gevm dotprod:'
'N/32, K/4, 32(n), 4(k)')
)
)

(pdef('SVD').


+ 31
- 0
dnn/src/arm_common/conv_bias/int8/algos.h View File

@@ -127,6 +127,37 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8)
};

class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; }
bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
const override;

size_t get_workspace(const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8)
};

class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; }
bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
const override;

size_t get_workspace(const NCBKernSizeParam&) const override;
virtual SmallVector<NCBKern> dispatch_kerns(
const NCBKernSizeParam& param) const override;
ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8)
};
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }


+ 270
- 0
dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp View File

@@ -0,0 +1,270 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include <arm_neon.h>
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/common/unroll_macro.h"
#if MGB_ENABLE_DOT
#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel)

using namespace megdnn;
using namespace arm_common;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

namespace {

class DirectConvRunner {
public:
DirectConvRunner(size_t flt_size, size_t stride) {
if (flt_size == 9 && stride == 1) {
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16;
} else {
megdnn_assert(flt_size == 9 && stride == 2);
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16;
}
}
size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
auto FW = fm.spatial[1];
return round_up((size_t)FW, m_block_k);
}

size_t get_round_iw(const ConvBiasImpl::NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
size_t SW = fm.stride[1];
size_t OW = param.osz[1];
size_t round_ow = round_up(OW, m_block_ow);
size_t round_fw = get_round_fw(param);
size_t pad_iw = round_ow * SW - SW + round_fw;
return round_up(pad_iw, m_align_iw);
}

size_t get_round_ih(const ConvBiasImpl::NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
size_t SH = fm.stride[0];
size_t OH = param.osz[0];
auto FH = fm.spatial[0];
size_t round_oh = round_up(OH, m_block_oh);
return round_oh * SH - SH + FH;
}

WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];

size_t round_filter = get_round_fw(param) * FH;
size_t round_ih = get_round_ih(param);
size_t round_iw = get_round_iw(param);
size_t pad_src = round_iw * round_ih;
return {nullptr, {pad_src, round_filter}};
}

WorkspaceBundle get_total_bundle(
const ConvBiasImpl::NCBKernSizeParam& param) const {
auto sub_bundle = get_sub_bundle(param);
auto sub_bundle_size = sub_bundle.total_size_in_bytes();
size_t nr_threads = param.nr_threads;
SmallVector<size_t> sizes_in_bytes;
for (size_t i = 0; i < nr_threads; ++i) {
sizes_in_bytes.push_back(sub_bundle_size);
}
WorkspaceBundle total_bundle(nullptr, sizes_in_bytes);
return total_bundle;
}

void run(
const int8_t* pad_src_ptr, const int8_t* round_filter_ptr, int32_t bias,
int8_t* dst_ptr, size_t OH, size_t OW, size_t pad_iw, float scale,
int8_t relu_val) const {
const size_t ow_end = OW / m_block_ow * m_block_ow;
const size_t ow_remain = OW - ow_end;
const size_t oh_end = OH / m_block_oh * m_block_oh;
const size_t oh_remain = OH - oh_end;
int8_t cache[4 * 16];
for (size_t oh = 0; oh < oh_end; oh += m_block_oh) {
for (size_t ow = 0; ow < ow_end; ow += m_block_ow) {
m_func(pad_src_ptr, round_filter_ptr, bias, dst_ptr, oh, ow, OH, OW,
pad_iw, scale, relu_val);
}
if (ow_remain > 0) {
m_func(pad_src_ptr, round_filter_ptr, bias,
&cache[0] - (oh * m_block_ow + ow_end), oh, ow_end, OH,
m_block_ow, pad_iw, scale, relu_val);
for (size_t i = 0; i < m_block_oh; ++i) {
for (size_t j = 0; j < ow_remain; ++j) {
dst_ptr[(i + oh) * OW + (j + ow_end)] = cache[i * 16 + j];
}
}
}
}
if (oh_remain > 0) {
for (size_t ow = 0; ow < ow_end; ow += m_block_ow) {
m_func(pad_src_ptr, round_filter_ptr, bias,
&cache[0] - (oh_end * m_block_ow + ow), oh_end, ow, OH,
m_block_ow, pad_iw, scale, relu_val);
for (size_t i = 0; i < oh_remain; ++i) {
for (size_t j = 0; j < m_block_ow; ++j) {
dst_ptr[(i + oh_end) * OW + (j + ow)] = cache[i * 16 + j];
}
}
}
if (ow_remain > 0) {
m_func(pad_src_ptr, round_filter_ptr, bias,
&cache[0] - (oh_end * m_block_ow + ow_end), oh_end, ow_end, OH,
m_block_ow, pad_iw, scale, relu_val);
for (size_t i = 0; i < oh_remain; ++i) {
for (size_t j = 0; j < ow_remain; ++j) {
dst_ptr[(i + oh_end) * OW + (j + ow_end)] = cache[i * 16 + j];
}
}
}
}
}

private:
std::function<void(
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst,
size_t oh, size_t ow, size_t OH, size_t OW, size_t pad_iw,
const float scale, int8_t relu_val)>
m_func;
size_t m_block_oh{4};
size_t m_block_ow{16};
size_t m_block_k{4};
size_t m_align_iw{16};
};

void do_conv(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index, const DirectConvRunner& runner) {
auto&& fm = kern_param.filter_meta;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];

float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
float scale_dst_div = 1.f / scale_dst;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0));
int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1));
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
void* dst = kern_param.dst<void>(batch_id, group_id);
size_t pad_iw = runner.get_round_iw(kern_param);

memset(pad_src_ptr, 0, bundle.get_size(0));
rep(ih, IH) {
std::memcpy(
pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW,
sizeof(int8_t) * IW);
}

memset(round_filter_ptr, 0, bundle.get_size(1));
size_t round_fw = runner.get_round_fw(kern_param);
for (size_t fh = 0; fh < FH; ++fh) {
std::memcpy(round_filter_ptr + fh * round_fw, fptr + fh * FW, FW);
}

int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128;
int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr;

int8_t* dst_ptr = (int8_t*)dst;
runner.run(
pad_src_ptr, round_filter_ptr, bias_val, dst_ptr, OH, OW, pad_iw,
scale_bias * scale_dst_div, relu_val);
}

} // namespace

bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
if (!cpuinfo_has_arm_neon_dot()) {
return false;
}
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto FW = fm.spatial[1];
auto SH = fm.stride[0];
auto SW = fm.stride[1];
auto noline = param.nonlineMode;
auto bias_mode = param.bias_mode;
bool avaible =
//! src and filter are qint8, dst is qint8 or qint32
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8)) &&
fm.format == param::Convolution::Format::NCHW && !fm.should_flip &&
(noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) &&
(bias_mode == BiasMode::NO_BIAS ||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 &&
fm.ocpg == 1;
return avaible;
}

size_t ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel,
midout_iv("AlgoDotS8DirectChanWiseLarge::get_workspace"_hash)) {
auto&& fm = param.filter_meta;
DirectConvRunner runner(fm.spatial[0], fm.stride[0]);
auto total_bundle = runner.get_total_bundle(param);
return total_bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::
dispatch_kerns(const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel,
midout_iv("AlgoDotS8DirectChanWiseLarge::dispatch_kerns"_hash)) {
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
auto&& fm = param.filter_meta;
DirectConvRunner runner(fm.spatial[0], fm.stride[0]);
WorkspaceBundle wbundle = runner.get_sub_bundle(param);
WorkspaceBundle total_bundle = runner.get_total_bundle(param);

auto exec_one_group = [wbundle, total_bundle, runner](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
WorkspaceBundle temp_total_bundle = total_bundle;
temp_total_bundle.set(kern_param.workspace_ptr);
WorkspaceBundle temp_bundle = wbundle;
temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id));
do_conv(temp_bundle, kern_param, ncb_index, runner);
};
size_t N = param.n;
size_t group = fm.group;
ret_kerns.push_back({exec_one_group, {N, group}});
return ret_kerns;
}
MIDOUT_END();
return {};
}

#endif

+ 425
- 0
dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp View File

@@ -0,0 +1,425 @@
/**
* \file dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"

#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel)

using namespace megdnn;
using namespace arm_common;

using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

namespace {
constexpr size_t block_n = 32;
constexpr size_t block_k = 4;

WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto&& fm = param.filter_meta;
auto OH = param.osz[0];
auto OW = param.osz[1];
size_t IH = param.isz[0];
size_t IW = param.isz[1];
auto FH = fm.spatial[0];
auto FW = fm.spatial[1];
size_t PH = param.filter_meta.padding[0];
size_t PW = param.filter_meta.padding[1];

size_t round_ohw = round_up((size_t)OH * OW, block_n);
size_t round_filter = round_up((size_t)FW, block_k) * FH;
size_t pad_src = (IW + PW * 2) * (IH + PH * 2);
return {nullptr,
{pad_src, round_filter, round_ohw * round_filter,
round_ohw * sizeof(int32_t)}};
}

WorkspaceBundle get_total_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
auto sub_bundle = get_sub_bundle(param);
auto sub_bundle_size = sub_bundle.total_size_in_bytes();
size_t nr_threads = param.nr_threads;
SmallVector<size_t> sizes_in_bytes;
for (size_t i = 0; i < nr_threads; ++i) {
sizes_in_bytes.push_back(sub_bundle_size);
}
WorkspaceBundle total_bundle(nullptr, sizes_in_bytes);
return total_bundle;
}

template <size_t flt_size, size_t stride>
void im2col(
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
size_t round_filter) {
constexpr size_t FH = flt_size;
constexpr size_t FW = flt_size;
constexpr size_t SH = stride;
constexpr size_t SW = stride;
constexpr size_t FW_ROUND = (FW + 3) / 4 * 4;
int bn = 0;
int ni = 0;
for (size_t oh = 0; oh < OH; ++oh)
for (size_t ow = 0; ow < OW; ++ow) {
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
int bk = 0;
int ki = 0;
for (size_t fh = 0; fh < FH; ++fh)
for (size_t fw = 0; fw < FW_ROUND; ++fw) {
dst[bn * block_n * round_filter + bk * block_n * block_k +
ni * block_k + ki] = src_n[fh * pad_iw + fw];
++ki;
if (ki == block_k) {
ki = 0;
++bk;
}
}
++ni;
if (ni == block_n) {
ni = 0;
++bn;
}
}
}

template <>
void im2col<9, 1>(
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
size_t round_filter) {
constexpr size_t FH = 9;
constexpr size_t SH = 1;
constexpr size_t SW = 1;
constexpr size_t k_block_stride = block_k * block_n;

constexpr size_t ow_block = 16;
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
6, 7, 8, 9, 7, 8, 9, 10};
static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12,
10, 11, 12, 13, 11, 12, 13, 14};

uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);

int bn = 0;
int ni = 0;
for (size_t oh = 0; oh < OH; ++oh)

for (size_t ow = 0; ow < OW;) {
if (ow + ow_block <= OW && ni + ow_block <= block_n) {
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
for (size_t fh = 0; fh < FH; ++fh) {
int8x16_t read_w[2];
read_w[0] = vld1q_s8(src_n);
read_w[1] = vld1q_s8(src_n + 16);

int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
int8x16_t ncdef_0 =
vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);

int8x16_t n0123_1 = n4567_0;
int8x16_t n4567_1 = n89ab_0;
int8x16_t n89ab_1 = ncdef_0;
int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);

int8x16_t n0123_2 = n89ab_0;
int8x16_t n4567_2 = ncdef_0;
int8x16_t n89ab_2 = ncdef_1;
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);

vst1q_s8(dst_n + 0 * 16, n0123_0);
vst1q_s8(dst_n + 1 * 16, n4567_0);
vst1q_s8(dst_n + 2 * 16, n89ab_0);
vst1q_s8(dst_n + 3 * 16, ncdef_0);

vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1);
vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1);
vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1);
vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1);

vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2);
vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2);
vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2);
vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2);

dst_n += 3 * k_block_stride;
src_n += pad_iw;
}
ni += ow_block;
ow += ow_block;
if (ni == block_n) {
ni = 0;
++bn;
}
} else {
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
for (size_t fh = 0; fh < FH; ++fh) {
int8x16_t read_w[0];
read_w[0] = vld1q_s8(src_n);
vst1q_lane_s32(dst_n, read_w[0], 0);
vst1q_lane_s32(dst_n + 1 * k_block_stride, read_w[0], 1);
vst1q_lane_s32(dst_n + 2 * k_block_stride, read_w[0], 2);
dst_n += 3 * k_block_stride;
src_n += pad_iw;
}
++ni;
++ow;
if (ni == block_n) {
ni = 0;
++bn;
}
}
}
}

template <>
void im2col<9, 2>(
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
size_t round_filter) {
constexpr size_t FH = 9;
constexpr size_t SH = 2;
constexpr size_t SW = 2;
constexpr size_t k_block_stride = block_k * block_n;

constexpr size_t ow_block = 16;
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
4, 5, 6, 7, 6, 7, 8, 9};
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9,
8, 9, 10, 11, 10, 11, 12, 13};

uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);

int bn = 0;
int ni = 0;
for (size_t oh = 0; oh < OH; ++oh)
for (size_t ow = 0; ow < OW;) {
if (ow + ow_block <= OW && ni + ow_block <= block_n) {
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
for (size_t fh = 0; fh < FH; ++fh) {
int8x16_t read_w[3];
read_w[0] = vld1q_s8(src_n);
read_w[1] = vld1q_s8(src_n + 16);
read_w[2] = vld1q_s8(src_n + 32);
int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);

int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);

int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);

int8x16_t n0123_2 = n4567_0;
int8x16_t n4567_2 = n89ab_0;
int8x16_t n89ab_2 = ncdef_0;
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);

vst1q_s8(dst_n + 0 * 16, n0123_0);
vst1q_s8(dst_n + 1 * 16, n4567_0);
vst1q_s8(dst_n + 2 * 16, n89ab_0);
vst1q_s8(dst_n + 3 * 16, ncdef_0);

vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1);
vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1);
vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1);
vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1);

vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2);
vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2);
vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2);
vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2);

dst_n += 3 * k_block_stride;
src_n += pad_iw;
}
ni += ow_block;
ow += ow_block;
if (ni == block_n) {
ni = 0;
++bn;
}
} else {
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
for (size_t fh = 0; fh < FH; ++fh) {
int8x16_t read_w[0];
read_w[0] = vld1q_s8(src_n);
vst1q_lane_s32(dst_n, read_w[0], 0);
vst1q_lane_s32(dst_n + 1 * k_block_stride, read_w[0], 1);
vst1q_lane_s32(dst_n + 2 * k_block_stride, read_w[0], 2);
dst_n += 3 * k_block_stride;
src_n += pad_iw;
}
++ni;
++ow;
if (ni == block_n) {
ni = 0;
++bn;
}
}
}
}

void do_conv(
const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) {
auto&& fm = kern_param.filter_meta;
size_t PH = kern_param.filter_meta.padding[0];
size_t PW = kern_param.filter_meta.padding[1];
size_t OH = kern_param.osz[0];
size_t OW = kern_param.osz[1];
size_t IH = kern_param.isz[0];
size_t IW = kern_param.isz[1];
size_t FH = fm.spatial[0];
size_t FW = fm.spatial[1];
size_t SH = fm.stride[0];

float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
float scale_dst_div = 1.f / scale_dst;

size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0));
int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1));
int8_t* im2col_ptr = static_cast<int8_t*>(bundle.get(2));
int32_t* i32_ptr = static_cast<int32_t*>(bundle.get(3));
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
void* dst = kern_param.dst<void>(batch_id, group_id);

size_t round_filter = round_up(FW, block_k) * FH;
size_t pad_iw = IW + 2 * PW;

memset(pad_src_ptr, 0, bundle.get_size(0));
rep(ih, IH) {
std::memcpy(
pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW,
sizeof(int8_t) * IW);
}

memset(round_filter_ptr, 0, bundle.get_size(1));
size_t fh_stride = round_up(FW, block_k);
for (size_t fh = 0; fh < FH; ++fh) {
std::memcpy(round_filter_ptr + fh * fh_stride, fptr + fh * FW, FW);
}

memset(im2col_ptr, 0, bundle.get_size(2));
if (SH == 1) {
im2col<9, 1>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter);
} else {
im2col<9, 2>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter);
}

gevm_naive_n32k4_dot(
round_filter_ptr, im2col_ptr, i32_ptr, 1, OH * OW, round_filter, 0, 0, 0);

int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr;
int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128;
int8_t* dst_ptr = (int8_t*)dst;
for (size_t i = 0; i < OH * OW; ++i) {
//! optimize by tbl
int val = roundf(scale_bias * scale_dst_div * (i32_ptr[i] + bias_val));
val = val < -128 ? -128 : val;
val = val > 127 ? 127 : val;
val = val > relu_val ? val : relu_val;
dst_ptr[i] = val;
}
}

} // namespace

bool ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
if (!cpuinfo_has_arm_neon_dot()) {
return false;
}
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto FW = fm.spatial[1];
auto SH = fm.stride[0];
auto SW = fm.stride[1];
auto noline = param.nonlineMode;
auto bias_mode = param.bias_mode;
bool avaible =
//! src and filter are qint8, dst is qint8 or qint32
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8)) &&
fm.format == param::Convolution::Format::NCHW && !fm.should_flip &&
(noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) &&
(bias_mode == BiasMode::NO_BIAS ||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 &&
fm.ocpg == 1;
return avaible;
}

size_t ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel,
midout_iv("AlgoDotS8Im2colChanWiseLarge::get_workspace"_hash)) {
auto bundle = get_total_bundle(param);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::
dispatch_kerns(const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel,
midout_iv("AlgoDotS8Im2colChanWiseLarge::dispatch_kerns"_hash)) {
SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
auto fm = param.filter_meta;
size_t N = param.n;
size_t group = fm.group;
WorkspaceBundle wbundle = get_sub_bundle(param);
WorkspaceBundle total_bundle = get_total_bundle(param);
auto exec_one_group = [wbundle, total_bundle](
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index) mutable {
WorkspaceBundle temp_total_bundle = total_bundle;
temp_total_bundle.set(kern_param.workspace_ptr);
WorkspaceBundle temp_bundle = wbundle;
temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id));
do_conv(temp_bundle, kern_param, ncb_index);
};
ret_kerns.push_back({exec_one_group, {N, group}});
return ret_kerns;
}
MIDOUT_END();
return {};
}
#endif

+ 27
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h View File

@@ -0,0 +1,27 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT

void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16(
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
int8_t relu_val);

void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
int8_t relu_val);

#endif

+ 40
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h View File

@@ -0,0 +1,40 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"

static inline void quant_store_s8(
float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3, int8_t* ptr,
int8x16_t relu_reg) {
int32x4_t i0 = vcvtaq_s32_f32(v0);
int32x4_t i1 = vcvtaq_s32_f32(v1);
int32x4_t i2 = vcvtaq_s32_f32(v2);
int32x4_t i3 = vcvtaq_s32_f32(v3);

int16x4_t i16_0 = vqmovn_s32(i0);
int16x4_t i16_1 = vqmovn_s32(i1);
int16x4_t i16_2 = vqmovn_s32(i2);
int16x4_t i16_3 = vqmovn_s32(i3);

int8x8_t i8_0 = vqmovn_s16(vcombine_s16(i16_0, i16_1));
int8x8_t i8_1 = vqmovn_s16(vcombine_s16(i16_2, i16_3));
int8x16_t rst = vcombine_s8(i8_0, i8_1);

rst = vmaxq_s8(rst, relu_reg);

vst1q_s8(ptr, rst);
}

#endif

+ 221
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp View File

@@ -0,0 +1,221 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s1.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"

#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"

MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16(
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
int8_t relu_val) {
//! 4x16
const size_t SH = 1;
const size_t SW = 1;

static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
2, 3, 4, 5, 3, 4, 5, 6};
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
6, 7, 8, 9, 7, 8, 9, 10};
static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12,
10, 11, 12, 13, 11, 12, 13, 14};

uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);

const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;

//! init
int32x4_t c[4][4];
#define cb(step) \
c[step][0] = vdupq_n_s32(bias); \
c[step][1] = vdupq_n_s32(bias); \
c[step][2] = vdupq_n_s32(bias); \
c[step][3] = vdupq_n_s32(bias);

UNROLL_CALL_RAW(4, cb);
#undef cb
int8x16_t flt[4];
flt[0] = vld1q_s8(weight + 0 * 16);
flt[1] = vld1q_s8(weight + 1 * 16);
flt[2] = vld1q_s8(weight + 2 * 16);
flt[3] = vld1q_s8(weight + 3 * 16);
//! row 0

int8x16_t read_w[2];
read_w[0] = vld1q_s8(src_n + 0 * pad_iw);
read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16);
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);
int8x16_t n0123_1 = n4567_0;
int8x16_t n4567_1 = n89ab_0;
int8x16_t n89ab_1 = ncdef_0;
int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);
int8x16_t n0123_2 = n89ab_0;
int8x16_t n4567_2 = ncdef_0;
int8x16_t n89ab_2 = ncdef_1;
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
#define CAL_C(oh, flt_start) \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4);

CAL_C(0, 0);

//! row 1
#define LOAD_SRC(row_id) \
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \
n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \
n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \
ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \
n0123_1 = n4567_0; \
n4567_1 = n89ab_0; \
n89ab_1 = ncdef_0; \
ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \
n0123_2 = n89ab_0; \
n4567_2 = ncdef_0; \
n89ab_2 = ncdef_1; \
ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);

LOAD_SRC(1);
CAL_C(0, 3);
CAL_C(1, 0);

//! row 2
LOAD_SRC(2);
CAL_C(0, 3 * 2);
CAL_C(1, 3 * 1);
CAL_C(2, 3 * 0);

//! row 3
LOAD_SRC(3);
CAL_C(0, 3 * 3);
CAL_C(1, 3 * 2);
CAL_C(2, 3 * 1);
CAL_C(3, 3 * 0);

//! row 4
LOAD_SRC(4);
CAL_C(0, 3 * 4);
CAL_C(1, 3 * 3);
CAL_C(2, 3 * 2);
CAL_C(3, 3 * 1);
//! update flt 4 -> 0
flt[0] = vld1q_s8(weight + 4 * 16);

//! row 5
LOAD_SRC(5);
CAL_C(0, 3 * 5);
CAL_C(1, 3 * 4);
CAL_C(2, 3 * 3);
CAL_C(3, 3 * 2);
//! update flt 5 -> 1
flt[1] = vld1q_s8(weight + 5 * 16);

//! row 6
LOAD_SRC(6);
CAL_C(0, 3 * 6);
CAL_C(1, 3 * 5);
CAL_C(2, 3 * 4);
CAL_C(3, 3 * 3);
//! update flt 6 -> 2
flt[2] = vld1q_s8(weight + 6 * 16);

//! row 7
LOAD_SRC(7);
CAL_C(0, 3 * 7);
CAL_C(1, 3 * 6);
CAL_C(2, 3 * 5);
CAL_C(3, 3 * 4);

//! row 8
LOAD_SRC(8);
CAL_C(0, 3 * 8);
CAL_C(1, 3 * 7);
CAL_C(2, 3 * 6);
CAL_C(3, 3 * 5);

//! row 9
LOAD_SRC(9);
CAL_C(1, 3 * 8);
CAL_C(2, 3 * 7);
CAL_C(3, 3 * 6);

//! row 10
LOAD_SRC(10);
CAL_C(2, 3 * 8);
CAL_C(3, 3 * 7);

//! row 11
LOAD_SRC(11);
CAL_C(3, 3 * 8);

float32x4_t dst_reg[4][4];
#define cb(step) \
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);

UNROLL_CALL_RAW(4, cb);
#undef cb

#define cb(step) \
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);

UNROLL_CALL_RAW(4, cb);
#undef cb
int8_t* dst_store = dst + oh * OW + ow;
int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step) \
quant_store_s8( \
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
dst_store + step * OW, relu_reg);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif

+ 250
- 0
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp View File

@@ -0,0 +1,250 @@
/**
* \file
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"

#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"

MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
int8_t relu_val) {
//! 4x16
const size_t SH = 2;
const size_t SW = 2;

static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
4, 5, 6, 7, 6, 7, 8, 9};
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9,
8, 9, 10, 11, 10, 11, 12, 13};

uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);

const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;

//! init
int32x4_t c[4][4];
#define cb(step) \
c[step][0] = vdupq_n_s32(bias); \
c[step][1] = vdupq_n_s32(bias); \
c[step][2] = vdupq_n_s32(bias); \
c[step][3] = vdupq_n_s32(bias);

UNROLL_CALL_RAW(4, cb);
#undef cb

constexpr int flt_reg = 7;
constexpr int flt_per_reg = 4;
int8x16_t flt[7];
flt[0] = vld1q_s8(weight + 0 * 16);
flt[1] = vld1q_s8(weight + 1 * 16);
flt[2] = vld1q_s8(weight + 2 * 16);
flt[3] = vld1q_s8(weight + 3 * 16);
flt[4] = vld1q_s8(weight + 4 * 16);
flt[5] = vld1q_s8(weight + 5 * 16);
flt[6] = vld1q_s8(weight + 6 * 16);

#define CAL_C(oh, flt_start) \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg);

#define LOAD_SRC(row_id) \
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \
read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \
ext_8 = vextq_s8(read_w[0], read_w[1], 8); \
ext_24 = vextq_s8(read_w[1], read_w[2], 8); \
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \
n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \
n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \
ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \
n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \
n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \
n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \
ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \
n0123_2 = n4567_0; \
n4567_2 = n89ab_0; \
n89ab_2 = ncdef_0; \
ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);

//! row 0
int8x16_t read_w[3];
read_w[0] = vld1q_s8(src_n);
read_w[1] = vld1q_s8(src_n + 16);
read_w[2] = vld1q_s8(src_n + 32);
int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);

int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);

int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);

int8x16_t n0123_2 = n4567_0;
int8x16_t n4567_2 = n89ab_0;
int8x16_t n89ab_2 = ncdef_0;
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);

CAL_C(0, 0);

//! row 1
LOAD_SRC(1);
CAL_C(0, 3 * 1);

//! row 2
LOAD_SRC(2);
CAL_C(0, 3 * 2);
CAL_C(1, 3 * 0);

//! row 3
LOAD_SRC(3);
CAL_C(0, 3 * 3);
CAL_C(1, 3 * 1);

//! row 4
LOAD_SRC(4);
CAL_C(0, 3 * 4);
CAL_C(1, 3 * 2);
CAL_C(2, 3 * 0);

//! row 5
LOAD_SRC(5);
CAL_C(0, 3 * 5);
CAL_C(1, 3 * 3);
CAL_C(2, 3 * 1);

//! row 6
LOAD_SRC(6);
CAL_C(0, 3 * 6);
CAL_C(1, 3 * 4);
CAL_C(2, 3 * 2);
CAL_C(3, 3 * 0);

//! row 7
LOAD_SRC(7);
CAL_C(0, 3 * 7);
CAL_C(1, 3 * 5);
CAL_C(2, 3 * 3);
CAL_C(3, 3 * 1);

//! row 8
LOAD_SRC(8);
CAL_C(0, 3 * 8);
CAL_C(1, 3 * 6);
CAL_C(2, 3 * 4);
CAL_C(3, 3 * 2);

//! row 9
LOAD_SRC(9);
CAL_C(1, 3 * 7);
CAL_C(2, 3 * 5);
CAL_C(3, 3 * 3);

//! row 10
LOAD_SRC(10);
CAL_C(1, 3 * 8);
CAL_C(2, 3 * 6);
CAL_C(3, 3 * 4);

//! row 11
LOAD_SRC(11);
CAL_C(2, 3 * 7);
CAL_C(3, 3 * 5);

//! row 12
LOAD_SRC(12);
CAL_C(2, 3 * 8);
CAL_C(3, 3 * 6);

//! row 13
LOAD_SRC(13);
CAL_C(3, 3 * 7);

//! row 14
LOAD_SRC(14);
CAL_C(3, 3 * 8);

float32x4_t dst_reg[4][4];
#define cb(step) \
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);

UNROLL_CALL_RAW(4, cb);
#undef cb

#define cb(step) \
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);

UNROLL_CALL_RAW(4, cb);
#undef cb
int8_t* dst_store = dst + oh * OW + ow;
int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step) \
quant_store_s8( \
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
dst_store + step * OW, relu_reg);

UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif

+ 4
- 0
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -54,6 +54,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {

AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
AlgoDotS8Im2colChanWiseLarge ds8_im2col_large_chanwise;
AlgoDotS8DirectChanWiseLarge ds8_direct_large_chanwise;
#endif

AlgoI8x8x16Direct i8x8x16_direct;
@@ -75,6 +77,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
#if MGB_ENABLE_DOT
m_direct_algos.emplace_back(&ds8_direct_large_chanwise);
m_direct_algos.emplace_back(&ds8_im2col_large_chanwise);
m_direct_algos.emplace_back(&ds8_direct_stride1);
m_direct_algos.emplace_back(&ds8_direct_stride2);
m_direct_algos.emplace_back(&du8_direct_stride1);


+ 2
- 0
dnn/src/arm_common/conv_bias/opr_impl.h View File

@@ -51,6 +51,8 @@ private:
#endif
#if MGB_ENABLE_DOT
class AlgoDotS8DirectNCHWNCHW44;
class AlgoDotS8DirectChanWiseLarge;
class AlgoDotS8Im2colChanWiseLarge;
class AlgoDotS8DirectStride1;
class AlgoDotS8DirectStride2;
class AlgoDotU8DirectStride1;


+ 82
- 1
dnn/src/arm_common/matrix_mul/algos.cpp View File

@@ -143,8 +143,89 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
const KernSizeParam&) const {
return int8x8x32_gemv_mk4_kern;
}

#if MGB_ENABLE_DOT
namespace {
void int8x8x32_gevm_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gevm_naive_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x32GevmDot::usable(
const KernSizeParam& kern_size_param) const {
if (!cpuinfo_has_arm_neon_dot()) {
return false;
}
auto M = kern_size_param.M;
bool is_dtype_ok =
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);

return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT && is_dtype_ok &&
!kern_size_param.trA && !kern_size_param.trB && M == 1;
}

bool MatrixMulImpl::AlgoInt8x8x32GevmDot::preferred(
const KernSizeParam& kern_size_param) const {
return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmDot::get_kern(
const KernSizeParam&) const {
return int8x8x32_gevm_dot_kern;
}

namespace {
void int8x8x32_gevm_n32k4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gevm_naive_n32k4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::usable(
const KernSizeParam& kern_size_param) const {
if (!cpuinfo_has_arm_neon_dot()) {
return false;
}
auto M = kern_size_param.M;

bool is_dtype_ok =
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::N32K4_DOT &&
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && M == 1;
}

bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::preferred(
const KernSizeParam& kern_size_param) const {
return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::get_kern(
const KernSizeParam&) const {
return int8x8x32_gevm_n32k4_dot_kern;
}

/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {


+ 62
- 1
dnn/src/arm_common/matrix_mul/algos.h View File

@@ -49,8 +49,69 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
};

#if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoInt8x8x32GevmDot : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_DOT"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, DEFAULT)
WorkspaceBundle get_bundle(const KernSizeParam&) const override {
return WorkspaceBundle{nullptr, {}};
}
kern_naked_t get_kern_naked(const KernSizeParam&) const override {
megdnn_assert(0, "naked kern no impl");
}
void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride)
const override {
megdnn_assert(0, "pack_A no impl");
}
void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax)
const override {
megdnn_assert(0, "pack_B no impl");
}
InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; };
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_DOT)
};

class MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, N32K4_DOT)
WorkspaceBundle get_bundle(const KernSizeParam&) const override {
return WorkspaceBundle{nullptr, {}};
}
kern_naked_t get_kern_naked(const KernSizeParam&) const override {
megdnn_assert(0, "naked kern no impl");
}
void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride)
const override {
megdnn_assert(0, "pack_A no impl");
}
void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax)
const override {
megdnn_assert(0, "pack_B no impl");
}
InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; };
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT)
};

class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
AlgoAttribute attribute() const override {


+ 394
- 0
dnn/src/arm_common/matrix_mul/int8/gemv.cpp View File

@@ -2,6 +2,7 @@

#include "megdnn/oprs.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"

#include "midout.h"
@@ -430,5 +431,398 @@ void arm_common::gemv_like_mk4_dot(
MIDOUT_END();
}
#endif
#if MGB_ENABLE_DOT
namespace {
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void gevm_naive_dot_impl(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
bool load_c) {
constexpr size_t n_block = 32;
const size_t n_end = N / n_block * n_block;
const size_t n_remain = N - n_end;

constexpr size_t k_block = 4;
constexpr size_t k_block_x2 = k_block * 2;
const size_t k_end = (K / k_block_x2) * k_block_x2;
const size_t k_remain = K - k_end;
for (size_t n = 0; n < n_end; n += n_block) {
if (K < k_block_x2) {
if (!load_c) {
for (size_t i = 0; i < n_block; ++i) {
C[n + i] = 0;
}
}
for (size_t k = 0; k < K; ++k) {
for (size_t i = 0; i < n_block; ++i) {
C[n + i] += A[k] * B[k * Bstride + n + i];
}
}
continue;
}
int32x4_t c[8];
if (load_c) {
#define cb(step) c[step] = vld1q_s32(C + n + step * 4);
UNROLL_CALL_RAW(8, cb);
#undef cb
} else {
#define cb(step) c[step] = vdupq_n_s32(0);
UNROLL_CALL_RAW(8, cb);
#undef cb
}
int8x16_t a[2];
a[0] = vld1q_dup_s32(A);
int8x16_t b[2][8];
#define cb(step) \
b[0][step * 2 + 0] = vld1q_s8(B + (0 + step) * Bstride + n); \
b[0][step * 2 + 1] = vld1q_s8(B + (0 + step) * Bstride + n + 16);
UNROLL_CALL_RAW(4, cb);
#undef cb
size_t k_buffer_end = k_end - k_block_x2;
for (size_t k = 0; k < k_buffer_end; k += k_block_x2) {
//! double buffer main
#define cb(step) \
b[1][step * 2 + 0] = vld1q_s8(B + (k + step + k_block) * Bstride + n); \
b[1][step * 2 + 1] = vld1q_s8(B + (k + step + k_block) * Bstride + n + 16);
UNROLL_CALL_RAW(4, cb);
#undef cb
a[1] = vld1q_dup_s32(A + k + k_block);

int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]);
int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]);
int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]);
int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]);
int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]);
int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]);
int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]);
int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]);
c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]);
c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]);
c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]);
c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]);
c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]);
c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]);
c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]);
c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]);
#define cb(step) \
b[0][step * 2 + 0] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n); \
b[0][step * 2 + 1] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n + 16);
UNROLL_CALL_RAW(4, cb);
#undef cb
a[0] = vld1q_dup_s32(A + k + k_block_x2);

ab0 = vzipq_s8(b[1][0], b[1][2]);
cd0 = vzipq_s8(b[1][4], b[1][6]);
ab1 = vzipq_s8(b[1][1], b[1][3]);
cd1 = vzipq_s8(b[1][5], b[1][7]);

abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]);
abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]);
abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]);
abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]);
c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]);
c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]);
c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]);
c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]);
c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]);
c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]);
c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]);
c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]);
}
//! double buffer remain
#define cb(step) \
b[1][step * 2 + 0] = vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n); \
b[1][step * 2 + 1] = \
vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n + 16);
UNROLL_CALL_RAW(4, cb);
#undef cb
a[1] = vld1q_dup_s32(A + k_buffer_end + k_block);

int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]);
int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]);
int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]);
int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]);
int16x8x2_t abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]);
int16x8x2_t abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]);
int16x8x2_t abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]);
int16x8x2_t abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]);
c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]);
c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]);
c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]);
c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]);
c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]);
c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]);
c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]);
c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]);

ab0 = vzipq_s8(b[1][0], b[1][2]);
cd0 = vzipq_s8(b[1][4], b[1][6]);
ab1 = vzipq_s8(b[1][1], b[1][3]);
cd1 = vzipq_s8(b[1][5], b[1][7]);
abcd0 = vzipq_s16(ab0.val[0], cd0.val[0]);
abcd1 = vzipq_s16(ab0.val[1], cd0.val[1]);
abcd2 = vzipq_s16(ab1.val[0], cd1.val[0]);
abcd3 = vzipq_s16(ab1.val[1], cd1.val[1]);
c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]);
c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]);
c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]);
c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]);
c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]);
c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]);
c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]);
c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]);

vst1q_s32(C + n + 0 * 4, c[0]);
vst1q_s32(C + n + 1 * 4, c[1]);
vst1q_s32(C + n + 2 * 4, c[2]);
vst1q_s32(C + n + 3 * 4, c[3]);
vst1q_s32(C + n + 4 * 4, c[4]);
vst1q_s32(C + n + 5 * 4, c[5]);
vst1q_s32(C + n + 6 * 4, c[6]);
vst1q_s32(C + n + 7 * 4, c[7]);
if (k_remain > 0) {
for (size_t k = k_end; k < K; ++k) {
for (size_t i = 0; i < n_block; ++i) {
C[n + i] += A[k] * B[k * Bstride + n + i];
}
}
}
}

if (n_remain > 0) {
for (size_t n = n_end; n < N; ++n) {
if (!load_c) {
C[n] = 0;
}
for (size_t k = 0; k < K; ++k) {
C[n] += A[k] * B[k * Bstride + n];
}
}
}
}
#if MEGDNN_ARMV7
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void gevm_naive_dot_n32k4_impl(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
bool load_c) {
//! input must be N/32, k/4, 32, 4
//! TODO: add prefetch
//! TODO: add double buffer
constexpr size_t n_block = 32;
constexpr size_t k_block = 4;
for (size_t n = 0; n < N; n += n_block) {
int32x4_t c[n_block / 4];
#define cb(step) c[step] = vdupq_n_s32(0);
UNROLL_CALL_RAW(8, cb);
#undef cb
const int8_t* b_base = B + n * K;
for (size_t k = 0; k < K; k += k_block) {
int8x16_t a[1];
int8x16_t b[1][8];
#define cb(step) b[0][step] = vld1q_s8(b_base + k * 32 + 16 * step);
UNROLL_CALL_RAW(8, cb);
#undef cb
a[0] = vld1q_dup_s32(A + k);

c[0] = vdotq_s32(c[0], b[0][0], a[0]);
c[1] = vdotq_s32(c[1], b[0][1], a[0]);
c[2] = vdotq_s32(c[2], b[0][2], a[0]);
c[3] = vdotq_s32(c[3], b[0][3], a[0]);
c[4] = vdotq_s32(c[4], b[0][4], a[0]);
c[5] = vdotq_s32(c[5], b[0][5], a[0]);
c[6] = vdotq_s32(c[6], b[0][6], a[0]);
c[7] = vdotq_s32(c[7], b[0][7], a[0]);
}
vst1q_s32(C + n + 0 * 4, c[0]);
vst1q_s32(C + n + 1 * 4, c[1]);
vst1q_s32(C + n + 2 * 4, c[2]);
vst1q_s32(C + n + 3 * 4, c[3]);
vst1q_s32(C + n + 4 * 4, c[4]);
vst1q_s32(C + n + 5 * 4, c[5]);
vst1q_s32(C + n + 6 * 4, c[6]);
vst1q_s32(C + n + 7 * 4, c[7]);
}
}
#else
MEGDNN_ATTRIBUTE_TARGET("dotprod")
inline void n32k4_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t K) {
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;
//! C q0-q7
//! A q8-q9
//! B q10-q25
asm volatile(
// load accumulator C
"1:\n"
"eor v0.16b, v0.16b, v0.16b\n"
"eor v1.16b, v1.16b, v1.16b\n"
"eor v2.16b, v2.16b, v2.16b\n"
"eor v3.16b, v3.16b, v3.16b\n"
"eor v4.16b, v4.16b, v4.16b\n"
"eor v5.16b, v5.16b, v5.16b\n"
"eor v6.16b, v6.16b, v6.16b\n"
"eor v7.16b, v7.16b, v7.16b\n"

"ld1r {v8.4s}, [%[a_ptr]]\n"
"ld1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[b_ptr]], 64\n"
"ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[b_ptr]], 64\n"
"add %[a_ptr], %[a_ptr], #4\n"

"cmp %w[k], #0\n"
"beq 4f\n"

"2: \n"
// Loop proper
"3:\n"
"ld1r {v9.4s}, [%[a_ptr]]\n"
"sdot v0.4s, v10.16b, v8.16b\n"
"ldr q18, [%[b_ptr], #0]\n"
"sdot v1.4s, v11.16b, v8.16b\n"
"ldr q19, [%[b_ptr], #16]\n"
"sdot v2.4s, v12.16b, v8.16b\n"
"ldr q20, [%[b_ptr], #32]\n"
"add %[a_ptr], %[a_ptr], #4\n"
"sdot v3.4s, v13.16b, v8.16b\n"
"ldr q21, [%[b_ptr], #48]\n"
"sdot v4.4s, v14.16b, v8.16b\n"
"ldr q22, [%[b_ptr], #64]\n"
"sdot v5.4s, v15.16b, v8.16b\n"
"ldr q23, [%[b_ptr], #80]\n"
"sdot v6.4s, v16.16b, v8.16b\n"
"ldr q24, [%[b_ptr], #96]\n"
"sdot v7.4s, v17.16b, v8.16b\n"
"ldr q25, [%[b_ptr], #112]\n"

"ld1r {v8.4s}, [%[a_ptr]]\n"
"sdot v0.4s, v18.16b, v9.16b\n"
"ldr q10, [%[b_ptr], #128]\n"
"sdot v1.4s, v19.16b, v9.16b\n"
"ldr q11, [%[b_ptr], #144]\n"
"sdot v2.4s, v20.16b, v9.16b\n"
"ldr q12, [%[b_ptr], #160]\n"
"sdot v3.4s, v21.16b, v9.16b\n"
"ldr q13, [%[b_ptr], #176]\n"
"sdot v4.4s, v22.16b, v9.16b\n"
"ldr q14, [%[b_ptr], #192]\n"
"sdot v5.4s, v23.16b, v9.16b\n"
"ldr q15, [%[b_ptr], #208]\n"
"sdot v6.4s, v24.16b, v9.16b\n"
"ldr q16, [%[b_ptr], #224]\n"
"sdot v7.4s, v25.16b, v9.16b\n"
"ldr q17, [%[b_ptr], #240]\n"

"add %[a_ptr], %[a_ptr], #4\n"
"add %[b_ptr], %[b_ptr], #256\n"

"subs %w[k], %w[k], #1\n"
"bne 3b\n"

"4:\n"
"cmp %w[oddk], #1\n"
"beq 5f\n"
// Even tail

"ld1r {v9.4s}, [%[a_ptr]]\n"
"sdot v0.4s, v10.16b, v8.16b\n"
"ldr q18, [%[b_ptr], #0]\n"
"sdot v1.4s, v11.16b, v8.16b\n"
"ldr q19, [%[b_ptr], #16]\n"
"sdot v2.4s, v12.16b, v8.16b\n"
"ldr q20, [%[b_ptr], #32]\n"
"sdot v3.4s, v13.16b, v8.16b\n"
"ldr q21, [%[b_ptr], #48]\n"
"sdot v4.4s, v14.16b, v8.16b\n"
"ldr q22, [%[b_ptr], #64]\n"
"sdot v5.4s, v15.16b, v8.16b\n"
"ldr q23, [%[b_ptr], #80]\n"
"sdot v6.4s, v16.16b, v8.16b\n"
"ldr q24, [%[b_ptr], #96]\n"
"sdot v7.4s, v17.16b, v8.16b\n"
"ldr q25, [%[b_ptr], #112]\n"

"sdot v0.4s, v18.16b, v9.16b\n"
"sdot v1.4s, v19.16b, v9.16b\n"
"sdot v2.4s, v20.16b, v9.16b\n"
"sdot v3.4s, v21.16b, v9.16b\n"
"sdot v4.4s, v22.16b, v9.16b\n"
"sdot v5.4s, v23.16b, v9.16b\n"
"sdot v6.4s, v24.16b, v9.16b\n"
"sdot v7.4s, v25.16b, v9.16b\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n"
"st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n"
"b 6f\n"

"5:\n"
// Odd tail
"sdot v0.4s, v10.16b, v8.16b\n"
"sdot v1.4s, v11.16b, v8.16b\n"
"sdot v2.4s, v12.16b, v8.16b\n"
"sdot v3.4s, v13.16b, v8.16b\n"
"sdot v4.4s, v14.16b, v8.16b\n"
"sdot v5.4s, v15.16b, v8.16b\n"
"sdot v6.4s, v16.16b, v8.16b\n"
"sdot v7.4s, v17.16b, v8.16b\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n"
"st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n"
"6:\n"

: [a_ptr] "+r"(A), [b_ptr] "+r"(B), [k] "+r"(K), [c_ptr] "+r"(C),
[oddk] "+r"(oddk)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
"v22", "v23", "v24", "v25", "cc", "memory");
}
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void gevm_naive_dot_n32k4_impl(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
bool load_c) {
//! input must be N/32, k/4, 32, 4
//! TODO: add prefetch
//! TODO: add double buffer
constexpr size_t n_block = 32;
for (size_t n = 0; n < N; n += n_block) {
n32k4_dot(A, B + n * K, C + n, K / 4);
}
}
#endif
} // namespace

void arm_common::gevm_naive_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot"_hash)) {
size_t cache_size = 256 * 1024;
size_t k_group = N * K / cache_size;
constexpr size_t k_align = 8;
if (k_group >= 2) {
size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align;
for (size_t k = 0; k < K; k += k_per_group) {
size_t real_k = std::min(K - k, k_per_group);
gevm_naive_dot_impl(
A + k, B + k * Bstride, C, M, N, real_k, Astride, Bstride,
Cstride, k != 0);
}
} else {
gevm_naive_dot_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false);
}
}
MIDOUT_END();
}

void arm_common::gevm_naive_n32k4_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot_nk4"_hash)) {
gevm_naive_dot_n32k4_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false);
}
MIDOUT_END();
}
#endif
// vim: syntax=cpp.doxygen

+ 8
- 0
dnn/src/arm_common/matrix_mul/int8/gemv.h View File

@@ -22,6 +22,14 @@ void gemv_like_mk4(
void gemv_like_mk4_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);

void gevm_naive_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);

void gevm_naive_n32k4_dot(
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);
#endif

} // namespace arm_common


+ 4
- 0
dnn/src/arm_common/matrix_mul/opr_impl.cpp View File

@@ -14,6 +14,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
#if MGB_ENABLE_DOT
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
AlgoInt8x8x32GevmDot int8x8x32_gevm_dot;
AlgoInt8x8x32GevmN32K4Dot int8x8x32_gevm_n32k4_dot;
#endif
AlgoGevm gevm;

@@ -28,6 +30,8 @@ public:
#endif
#if MGB_ENABLE_DOT
m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
m_all_algos.emplace_back(&int8x8x32_gevm_dot);
m_all_algos.emplace_back(&int8x8x32_gevm_n32k4_dot);
#endif
m_all_algos.emplace_back(&int8x8x32_gemv);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4);


+ 3
- 1
dnn/src/arm_common/matrix_mul/opr_impl.h View File

@@ -31,7 +31,9 @@ protected:
class AlgoF16Gemv;
#endif
#if MGB_ENABLE_DOT
class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT
class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT
class AlgoInt8x8x32GevmDot; // Arm_common Int8x8x32 Gevm NCHW DOT
class AlgoInt8x8x32GevmN32K4Dot; // Arm_common Int8x8x32 Gevm NK4
#endif
class AlgoInt8x8x16; // Arm_common Int 8x8x16
class AlgoPack;


+ 8
- 1
dnn/src/arm_common/simd_macro/marm_neon.h View File

@@ -469,7 +469,7 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) {
#endif

#if MEGDNN_ARMV7
__ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) {
__ai int8x16_t vqtbl1q_s8(int8x16_t a, uint8x16_t idx) {
int8x8_t src_low = vget_low_s8(a);
int8x8_t src_high = vget_high_s8(a);
return vcombine_s8(
@@ -726,6 +726,13 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) {
asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :);
return a;
}
#if __ARM_ARCH < 8
__ai int32x4_t vcvtaq_s32_f32(float32x4_t val) {
float32x4_t vinc0 = vbslq_f32(
vcgeq_f32(val, vdupq_n_f32(0.f)), vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f));
return vcvtq_s32_f32(vaddq_f32(val, vinc0));
}
#endif
#if MGB_ENABLE_DOT
#undef __ARM_FEATURE_DOTPROD
#endif


+ 21
- 0
dnn/src/common/matrix_mul.cpp View File

@@ -61,6 +61,15 @@ void MatrixMulForward::deduce_layout(
"(transposed) B is (%zu,%zu)",
A0, A1, B0, B1);
C = TensorLayout(TensorShape({A0, B1}), C.dtype);
} else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
A0 = A.shape[0];
A1 = A.shape[1];
B0 = B.shape[0];
B1 = B.shape[1];
megdnn_assert(!m_param.transposeA && !m_param.transposeB);
megdnn_assert(A0 == 1 && A1 % 4 == 0);
megdnn_assert(B.ndim == 4);
C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype);
} else {
auto do_deduce = [&](size_t pack_size) {
megdnn_assert(
@@ -132,6 +141,18 @@ void MatrixMulForward::check_exec(
megdnn_assert(A0 == C0, "%s", errmsg().c_str());
megdnn_assert(B1 == C1, "%s", errmsg().c_str());
megdnn_assert(A1 == B0, "%s", errmsg().c_str());
} else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
size_t A0 = A.shape[0];
size_t A1 = A.shape[1];
size_t B2 = B.shape[2];
size_t B3 = B.shape[3];
megdnn_assert(!m_param.transposeA && !m_param.transposeB);
megdnn_assert(A0 == 1 && A1 % 4 == 0);
megdnn_assert(B.ndim == 4);
megdnn_assert(B2 == 32 && B3 == 4);
megdnn_assert_contiguous(A);
megdnn_assert_contiguous(B);
megdnn_assert_contiguous(C);
} else {
megdnn_assert_eq_size_t(A.ndim, 4_z);
megdnn_assert_eq_size_t(B.ndim, 3_z);


+ 3
- 1
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -442,9 +442,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
}
BiasMode bias_mode;
//! dst only channel BIAS is viewed as BROADCAST_CHANNEL_BIAS
bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1;
if (bias.ndim == 0) {
bias_mode = BiasMode::NO_BIAS;
} else if (bias.eq_shape(dst)) {
} else if (bias.eq_shape(dst) && !dst_only_c) {
bias_mode = BiasMode::BIAS;
} else {
//! just check the ndim, the detail shape check is in check_exec


+ 3
- 0
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -258,6 +258,9 @@ public:
ARM_COMMON_CHANWISE_STRD1_NCHW44_S8,
ARM_COMMON_CHANWISE_STRD2_NCHW44_S8,
ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8,
//! LARGE for large filter
ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8,
ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8,
ARM_COMMON_DIRECT_STRD1_DOT_S8,
ARM_COMMON_DIRECT_STRD2_DOT_S8,
ARM_COMMON_DIRECT_NCHW44_DOT_S8,


+ 5
- 5
dnn/src/fallback/matrix_mul/opr_impl.cpp View File

@@ -195,11 +195,11 @@ MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param(
kern_size_param.trB = param().transposeB;
kern_size_param.compute_mode = param().compute_mode;
kern_size_param.format = param().format;
size_t pack_size = MatrixMulForward::pack_size(param().format);
kern_size_param.K *= pack_size;
kern_size_param.M *= pack_size;
if (param().format != Param::Format::N32K4_DOT) {
size_t pack_size = MatrixMulForward::pack_size(param().format);
kern_size_param.K *= pack_size;
kern_size_param.M *= pack_size;
}
return kern_size_param;
}



+ 3
- 0
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -122,6 +122,8 @@ public:
ARM_COMMON_INT8X8X32_GEMV,
ARM_COMMON_INT8X8X32_GEMV_MK4,
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
ARM_COMMON_INT8X8X32_GEVM_DOT,
ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT,
ARM_COMMON_F16_GEMV,
ARM_COMMON_GEVM,
#if MEGDNN_AARCH64
@@ -175,6 +177,7 @@ public:
enum class AlgoSet : uint32_t {
ALGO_TYPE_GEMM = 0,
ALGO_TYPE_GEMV = 1,
ALGO_TYPE_GEVM = 2,
};

enum class PackMode : uint32_t {


+ 32
- 0
dnn/src/naive/matrix_mul/matrix_mul_helper.h View File

@@ -105,6 +105,34 @@ void run_matrix_mul_mk4_dot_tpl(
template <
typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype>
void run_matrix_mul_n32k4_dot_tpl(
const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K,
size_t LDA, size_t, size_t, const DType& A_type, const DType& B_type) {
Getter<itype, comp_type> getterA(A_type), getterB(B_type);
megdnn_assert(!transA && !transB);
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; n += 32) {
comp_type res[32] = {comp_type(0)};
for (size_t k = 0; k < K; k += 4) {
for (size_t i = 0; i < 32; i++) {
comp_type av, bv;
for (size_t j = 0; j < 4; j++) {
av = getterA(A[m * LDA + k + j]);
bv = getterA(B[n * K + k * 32 + i * 4 + j]);
res[i] += av * bv;
}
}
}
for (size_t i = 0; i < 32; i++) {
C[n + i] = res[i];
}
}
}
}

template <
typename itype, typename otype, bool transA, bool transB,
typename comp_type = otype>
void run_matrix_mul_mk8_tpl(
const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, const DType& A_type, const DType& B_type) {
@@ -251,6 +279,10 @@ void dispatch_ta_tb(
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \
} else if (format == param::MatrixMul::Format::N32K4_DOT) { \
return run_matrix_mul_n32k4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \
} else if (format == param::MatrixMul::Format::MK8) { \
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \


+ 78
- 4
dnn/test/arm_common/conv_bias.cpp View File

@@ -160,7 +160,7 @@ static void benchmark_convbias(
.set_display(false);
}
auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
#if MGB_ENBALE_DOT
#if MGB_ENABLE_DOT
if (!is_fp32) {
nchw44_algo_regx = ".*DOT.*";
}
@@ -1626,7 +1626,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {

#endif

#if MGB_ENBALE_DOT
#if MGB_ENABLE_DOT
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
// have to remove preferred restrict in usable func before run the benchmark
@@ -2011,6 +2011,80 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
}
}

TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
using namespace conv_bias;

std::vector<TestArg> args;
auto run = [&](size_t group, size_t w, size_t h, size_t kernel, size_t stride,
NonlineMode nonline_mode) {
size_t p = kernel / 2;
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
param.format = param::ConvBias::Format::NCHW;
param.sparse = ConvBiasForward::Param::Sparse::GROUP;

//! channel bias
args.emplace_back(
param, TensorShape{1, group, h, w},
TensorShape{group, 1, 1, kernel, kernel}, TensorShape{1, group, 1, 1});
};

run(64, 64, 64, 9, 1, NonlineMode::RELU);
run(64, 40, 40, 9, 2, NonlineMode::RELU);
run(64, 20, 20, 9, 1, NonlineMode::RELU);

constexpr size_t RUN = 120;
Benchmarker<ConvBias> benchmark0(handle());
benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, dtype::QuantizedS8(60.25f));
benchmark0.set_display(false);
benchmark0.set_times(RUN);
benchmark0.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"ARMDOTS8_DIRECT_CHANWISE_LARGE"));

Benchmarker<ConvBias> benchmark1(handle());
benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, dtype::QuantizedS8(60.25f));
benchmark1.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"ARMDOTS8_IM2COL_CHANWISE_LARGE"));
benchmark1.set_display(false);
benchmark1.set_times(RUN);

for (auto&& arg : args) {
TensorLayout dst_layout;
auto opr = handle()->create_operator<ConvBias>();
opr->param() = arg.param;
opr->deduce_layout(
{arg.src, dtype::Int8()}, {arg.filter, dtype::Int8()},
{arg.bias, dtype::Int32()}, {}, dst_layout);
//! dst.nr_elems * FH * FW * 2
float computations =
dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6;

auto used0 = benchmark0.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) /
RUN;
auto used1 = benchmark1.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) /
RUN;

printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops "
"speedup: %f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0,
computations / used0, used1, computations / used1, used1 / used0);
}
}

#endif
#endif

@@ -2194,7 +2268,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) {
dtype::QuantizedS8 stype(2.5f);
dtype::QuantizedS32 dtype(6.25f);
#if MEGDNN_AARCH64
#if MGB_ENBALE_DOT
#if MGB_ENABLE_DOT
benchmark_conv1x1(
"AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype);
#else
@@ -2212,7 +2286,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) {
dtype::QuantizedS32 dtype(1.2 * 1.2);

#if MEGDNN_AARCH64
#if MGB_ENBALE_DOT
#if MGB_ENABLE_DOT
benchmark_conv1x1(
"AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype);
#else


+ 103
- 1
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -136,6 +136,84 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
return args;
}

std::vector<conv_bias::TestArg> get_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
bool no_full_bias, bool support_relu) {
using namespace conv_bias;
using Param = param::ConvBias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;

auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
size_t stride, NLMode nlmode, bool pad) {
Param param;
param.stride_h = stride;
param.stride_w = stride;
if (pad) {
param.pad_h = kernel / 2;
param.pad_w = kernel / 2;
} else {
param.pad_h = 0;
param.pad_w = 0;
}
param.nonlineMode = nlmode;
param.format = param::ConvBias::Format::NCHW;
param.sparse = param::ConvBias::Sparse::GROUP;

args.emplace_back(
param, TensorShape{n, group, h, w},
TensorShape{group, 1, 1, kernel, kernel}, TensorShape{});
if (!no_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w},
TensorShape{group, 1, 1, kernel, kernel},
TensorShape{1, group, 1, 1});
}
if (!no_full_bias) {
args.emplace_back(
param, TensorShape{n, group, h, w},
TensorShape{group, 1, 1, kernel, kernel},
TensorShape{
n, group, (h + 2 * param.pad_w - kernel) / stride + 1,
(w + 2 * param.pad_w - kernel) / stride + 1});
}
};

std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
if (!no_nonlinemode) {
nonlinemode.emplace_back(NLMode::RELU);
nonlinemode.emplace_back(NLMode::H_SWISH);
} else if (support_relu) {
nonlinemode.emplace_back(NLMode::RELU);
}

for (size_t n : {1, 2}) {
for (auto nlmode : nonlinemode) {
for (bool pad : {true}) {
for (size_t group : {1, 3, 7}) {
for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) {
for (size_t kern : kernel) {
pack(n, group, size, size, kern, stride, nlmode, pad);
}
}
}
}
for (bool pad : {false}) {
for (size_t group : {7}) {
for (size_t size : {37}) {
for (size_t kern : kernel) {
if (size < kern)
continue;
pack(n, group, size, size, kern, stride, nlmode, pad);
}
}
}
}
}
}
return args;
}

std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
bool no_full_bias) {
@@ -226,7 +304,7 @@ void checker_conv_bias_qint8x8x8(
.set_rng(1, &rng)
.set_rng(2, &rng);
for (auto&& arg : args) {
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}});
}
}
void checker_conv_bias_qint8x8x32(
@@ -532,6 +610,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {

/****************************dot qint8 direct*************************/
#if MGB_ENABLE_DOT
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 1, false, true, true, true), handle(),
"ARMDOTS8_DIRECT_CHANWISE_LARGE");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 2, false, true, true, true), handle(),
"ARMDOTS8_DIRECT_CHANWISE_LARGE");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) {
checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 1, false, true, true, true), handle(),
"ARMDOTS8_IM2COL_CHANWISE_LARGE");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) {
checker_conv_bias_qint8x8x8(
get_channel_wise_args({9}, 2, false, true, true, true), handle(),
"ARMDOTS8_IM2COL_CHANWISE_LARGE");
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
auto args = get_nchw44_conv_bias_args(
{2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);


+ 107
- 0
dnn/test/arm_common/matrix_mul.cpp View File

@@ -219,6 +219,113 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
run(M, K, 1);
}

TEST_F(ARM_COMMON, QINT8x8x32_GEVM_DOT) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_DOT");

checker.set_before_exec_callback(algo_ck);

std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
Param param;
param.format = Param::Format::DEFAULT;
param.transposeA = false;
param.transposeB = false;

auto run = [&](size_t M, size_t N, size_t K) {
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{K, N};
checker.set_param(param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.execs({A, B, {}});
};
run(1, 32, 4);
for (int n = 7; n < 43; n += 3) {
for (int k = 1; k < 33; k += 3) {
run(1, n, k);
}
}
}

TEST_F(ARM_COMMON, QINT8x8x32_GEVM_N32K4_DOT) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT");
checker.set_before_exec_callback(algo_ck);

std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
Param param;
param.format = Param::Format::N32K4_DOT;
param.transposeA = false;
param.transposeB = false;

auto run = [&](size_t M, size_t N, size_t K) {
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{N / 32, K / 4, 32, 4};
checker.set_param(param)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.execs({A, B, {}});
};
run(1, 32, 4);
for (int n = 32; n < 65; n += 32) {
for (int k = 4; k < 39; k += 4) {
run(1, n, k);
}
}
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_QINT8x8x32_GEVM_N32K4_DOT) {
using Param = MatrixMul::Param;
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT");

std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
Param param;
param.format = Param::Format::N32K4_DOT;
param.transposeA = false;
param.transposeB = false;

constexpr size_t RUNS = 2000;

Benchmarker<MatrixMul> benchmarker_int(handle());
benchmarker_int.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int32{})
.set_param(param)
.set_before_exec_callback(algo_ck)
.set_display(false);
Benchmarker<MatrixMul> benchmarker_float(handle());
benchmarker_float.set_display(false).set_times(RUNS);

auto bench = [&](size_t M, size_t N, size_t K) {
auto int_used =
benchmarker_int.exec({{M, K}, {N / 32, K / 4, 32, 4}, {}}) / RUNS;
auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
float computations = 2.f * M * K * N * 1e-6;
float through_put = (M * K + N * K + M * N) * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
"%f Gflops speedup: %f, through put %f G\n",
M, K, N, float_used, computations / float_used, int_used,
computations / int_used, float_used / int_used, through_put / int_used);
};

bench(1, 256, 512);
bench(1, 256, 1024);
bench(1, 512, 512);
bench(1, 512, 1024);
}
#endif

#endif

TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {


Loading…
Cancel
Save