GitOrigin-RevId: a28a97fcb5
release-1.10
@@ -557,7 +557,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'), | |||
Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod: ' | |||
'(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). If transposeA, the ' | |||
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))')) | |||
'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'), | |||
Doc('N32K4_DOT = 4', 'Split 32 from N and 4 from K, better for neon gevm dotprod: ' | |||
'the layout is (N/32, K/4, 32(n), 4(k))') | |||
) | |||
) | |||
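For orientation, a hedged sketch (not part of the patch) of where a logical element B(k, n) lands in the new N32K4_DOT layout, with K rounded up to round_k (a multiple of 4); this matches the im2col packing added further down. The helper name is hypothetical.

#include <cstddef>

// Offset of logical element B(k, n) in an N32K4_DOT-packed buffer,
// layout (N/32, K/4, 32(n), 4(k)).
inline size_t n32k4_dot_offset(size_t k, size_t n, size_t round_k) {
    const size_t block_n = 32, block_k = 4;
    return (n / block_n) * round_k * block_n     // which 32-column slab
         + (k / block_k) * block_n * block_k     // which 4-deep k-group
         + (n % block_n) * block_k               // column inside the slab
         + (k % block_k);                        // depth inside the k-group
}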
(pdef('SVD'). | |||
@@ -127,6 +127,37 @@ public: | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; } | |||
bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) | |||
const override; | |||
size_t get_workspace(const NCBKernSizeParam&) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; } | |||
bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) | |||
const override; | |||
size_t get_workspace(const NCBKernSizeParam&) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8) | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
@@ -0,0 +1,270 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include <arm_neon.h> | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||
#include "src/common/unroll_macro.h" | |||
#if MGB_ENABLE_DOT | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; | |||
namespace { | |||
class DirectConvRunner { | |||
public: | |||
DirectConvRunner(size_t flt_size, size_t stride) { | |||
if (flt_size == 9 && stride == 1) { | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; | |||
} else { | |||
megdnn_assert(flt_size == 9 && stride == 2); | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; | |||
} | |||
} | |||
size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
auto FW = fm.spatial[1]; | |||
return round_up((size_t)FW, m_block_k); | |||
} | |||
size_t get_round_iw(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
size_t SW = fm.stride[1]; | |||
size_t OW = param.osz[1]; | |||
size_t round_ow = round_up(OW, m_block_ow); | |||
size_t round_fw = get_round_fw(param); | |||
size_t pad_iw = round_ow * SW - SW + round_fw; | |||
return round_up(pad_iw, m_align_iw); | |||
} | |||
size_t get_round_ih(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
size_t SH = fm.stride[0]; | |||
size_t OH = param.osz[0]; | |||
auto FH = fm.spatial[0]; | |||
size_t round_oh = round_up(OH, m_block_oh); | |||
return round_oh * SH - SH + FH; | |||
} | |||
WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
size_t round_filter = get_round_fw(param) * FH; | |||
size_t round_ih = get_round_ih(param); | |||
size_t round_iw = get_round_iw(param); | |||
size_t pad_src = round_iw * round_ih; | |||
return {nullptr, {pad_src, round_filter}}; | |||
} | |||
WorkspaceBundle get_total_bundle( | |||
const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
auto sub_bundle = get_sub_bundle(param); | |||
auto sub_bundle_size = sub_bundle.total_size_in_bytes(); | |||
size_t nr_threads = param.nr_threads; | |||
SmallVector<size_t> sizes_in_bytes; | |||
for (size_t i = 0; i < nr_threads; ++i) { | |||
sizes_in_bytes.push_back(sub_bundle_size); | |||
} | |||
WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); | |||
return total_bundle; | |||
} | |||
void run( | |||
const int8_t* pad_src_ptr, const int8_t* round_filter_ptr, int32_t bias, | |||
int8_t* dst_ptr, size_t OH, size_t OW, size_t pad_iw, float scale, | |||
int8_t relu_val) const { | |||
const size_t ow_end = OW / m_block_ow * m_block_ow; | |||
const size_t ow_remain = OW - ow_end; | |||
const size_t oh_end = OH / m_block_oh * m_block_oh; | |||
const size_t oh_remain = OH - oh_end; | |||
int8_t cache[4 * 16]; | |||
for (size_t oh = 0; oh < oh_end; oh += m_block_oh) { | |||
for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { | |||
m_func(pad_src_ptr, round_filter_ptr, bias, dst_ptr, oh, ow, OH, OW, | |||
pad_iw, scale, relu_val); | |||
} | |||
if (ow_remain > 0) { | |||
//! the border tile is computed into `cache`: dst is biased backwards so | |||
//! the kernel's internal `dst + oh * OW + ow` write lands at cache[0] | |||
m_func(pad_src_ptr, round_filter_ptr, bias, | |||
&cache[0] - (oh * m_block_ow + ow_end), oh, ow_end, OH, | |||
m_block_ow, pad_iw, scale, relu_val); | |||
for (size_t i = 0; i < m_block_oh; ++i) { | |||
for (size_t j = 0; j < ow_remain; ++j) { | |||
dst_ptr[(i + oh) * OW + (j + ow_end)] = cache[i * 16 + j]; | |||
} | |||
} | |||
} | |||
} | |||
if (oh_remain > 0) { | |||
for (size_t ow = 0; ow < ow_end; ow += m_block_ow) { | |||
m_func(pad_src_ptr, round_filter_ptr, bias, | |||
&cache[0] - (oh_end * m_block_ow + ow), oh_end, ow, OH, | |||
m_block_ow, pad_iw, scale, relu_val); | |||
for (size_t i = 0; i < oh_remain; ++i) { | |||
for (size_t j = 0; j < m_block_ow; ++j) { | |||
dst_ptr[(i + oh_end) * OW + (j + ow)] = cache[i * 16 + j]; | |||
} | |||
} | |||
} | |||
if (ow_remain > 0) { | |||
m_func(pad_src_ptr, round_filter_ptr, bias, | |||
&cache[0] - (oh_end * m_block_ow + ow_end), oh_end, ow_end, OH, | |||
m_block_ow, pad_iw, scale, relu_val); | |||
for (size_t i = 0; i < oh_remain; ++i) { | |||
for (size_t j = 0; j < ow_remain; ++j) { | |||
dst_ptr[(i + oh_end) * OW + (j + ow_end)] = cache[i * 16 + j]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
private: | |||
std::function<void( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, | |||
size_t oh, size_t ow, size_t OH, size_t OW, size_t pad_iw, | |||
const float scale, int8_t relu_val)> | |||
m_func; | |||
size_t m_block_oh{4}; | |||
size_t m_block_ow{16}; | |||
size_t m_block_k{4}; | |||
size_t m_align_iw{16}; | |||
}; | |||
void do_conv( | |||
const WorkspaceBundle& bundle, const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index, const DirectConvRunner& runner) { | |||
auto&& fm = kern_param.filter_meta; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
float scale_dst_div = 1.f / scale_dst; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0)); | |||
int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1)); | |||
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id); | |||
const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id); | |||
const int8_t* fptr = kern_param.filter<dt_int8>(group_id); | |||
void* dst = kern_param.dst<void>(batch_id, group_id); | |||
size_t pad_iw = runner.get_round_iw(kern_param); | |||
memset(pad_src_ptr, 0, bundle.get_size(0)); | |||
rep(ih, IH) { | |||
std::memcpy( | |||
pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, | |||
sizeof(int8_t) * IW); | |||
} | |||
memset(round_filter_ptr, 0, bundle.get_size(1)); | |||
size_t round_fw = runner.get_round_fw(kern_param); | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
std::memcpy(round_filter_ptr + fh * round_fw, fptr + fh * FW, FW); | |||
} | |||
int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128; | |||
int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; | |||
int8_t* dst_ptr = (int8_t*)dst; | |||
runner.run( | |||
pad_src_ptr, round_filter_ptr, bias_val, dst_ptr, OH, OW, pad_iw, | |||
scale_bias * scale_dst_div, relu_val); | |||
} | |||
} // namespace | |||
bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
if (!cpuinfo_has_arm_neon_dot()) { | |||
return false; | |||
} | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto FW = fm.spatial[1]; | |||
auto SH = fm.stride[0]; | |||
auto SW = fm.stride[1]; | |||
auto noline = param.nonlineMode; | |||
auto bias_mode = param.bias_mode; | |||
bool available = | |||
//! src and filter are qint8, dst is qint8 | |||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && | |||
fm.format == param::Convolution::Format::NCHW && !fm.should_flip && | |||
(noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && | |||
(bias_mode == BiasMode::NO_BIAS || | |||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && | |||
fm.ocpg == 1; | |||
return available; | |||
} | |||
size_t ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, | |||
midout_iv("AlgoDotS8DirectChanWiseLarge::get_workspace"_hash)) { | |||
auto&& fm = param.filter_meta; | |||
DirectConvRunner runner(fm.spatial[0], fm.stride[0]); | |||
auto total_bundle = runner.get_total_bundle(param); | |||
return total_bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8DirectChanWiseLarge:: | |||
dispatch_kerns(const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel, | |||
midout_iv("AlgoDotS8DirectChanWiseLarge::dispatch_kerns"_hash)) { | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
auto&& fm = param.filter_meta; | |||
DirectConvRunner runner(fm.spatial[0], fm.stride[0]); | |||
WorkspaceBundle wbundle = runner.get_sub_bundle(param); | |||
WorkspaceBundle total_bundle = runner.get_total_bundle(param); | |||
auto exec_one_group = [wbundle, total_bundle, runner]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
WorkspaceBundle temp_total_bundle = total_bundle; | |||
temp_total_bundle.set(kern_param.workspace_ptr); | |||
WorkspaceBundle temp_bundle = wbundle; | |||
temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); | |||
do_conv(temp_bundle, kern_param, ncb_index, runner); | |||
}; | |||
size_t N = param.n; | |||
size_t group = fm.group; | |||
ret_kerns.push_back({exec_one_group, {N, group}}); | |||
return ret_kerns; | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
#endif |
@@ -0,0 +1,425 @@ | |||
/** | |||
* \file dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/matrix_mul/int8/gemv.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; | |||
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; | |||
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; | |||
namespace { | |||
constexpr size_t block_n = 32; | |||
constexpr size_t block_k = 4; | |||
WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
auto&& fm = param.filter_meta; | |||
auto OH = param.osz[0]; | |||
auto OW = param.osz[1]; | |||
size_t IH = param.isz[0]; | |||
size_t IW = param.isz[1]; | |||
auto FH = fm.spatial[0]; | |||
auto FW = fm.spatial[1]; | |||
size_t PH = param.filter_meta.padding[0]; | |||
size_t PW = param.filter_meta.padding[1]; | |||
size_t round_ohw = round_up((size_t)OH * OW, block_n); | |||
size_t round_filter = round_up((size_t)FW, block_k) * FH; | |||
size_t pad_src = (IW + PW * 2) * (IH + PH * 2); | |||
return {nullptr, | |||
{pad_src, round_filter, round_ohw * round_filter, | |||
round_ohw * sizeof(int32_t)}}; | |||
} | |||
WorkspaceBundle get_total_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { | |||
auto sub_bundle = get_sub_bundle(param); | |||
auto sub_bundle_size = sub_bundle.total_size_in_bytes(); | |||
size_t nr_threads = param.nr_threads; | |||
SmallVector<size_t> sizes_in_bytes; | |||
for (size_t i = 0; i < nr_threads; ++i) { | |||
sizes_in_bytes.push_back(sub_bundle_size); | |||
} | |||
WorkspaceBundle total_bundle(nullptr, sizes_in_bytes); | |||
return total_bundle; | |||
} | |||
template <size_t flt_size, size_t stride> | |||
void im2col( | |||
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||
size_t round_filter) { | |||
constexpr size_t FH = flt_size; | |||
constexpr size_t FW = flt_size; | |||
constexpr size_t SH = stride; | |||
constexpr size_t SW = stride; | |||
constexpr size_t FW_ROUND = (FW + 3) / 4 * 4; | |||
int bn = 0; | |||
int ni = 0; | |||
for (size_t oh = 0; oh < OH; ++oh) | |||
for (size_t ow = 0; ow < OW; ++ow) { | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
int bk = 0; | |||
int ki = 0; | |||
for (size_t fh = 0; fh < FH; ++fh) | |||
for (size_t fw = 0; fw < FW_ROUND; ++fw) { | |||
dst[bn * block_n * round_filter + bk * block_n * block_k + | |||
ni * block_k + ki] = src_n[fh * pad_iw + fw]; | |||
++ki; | |||
if (ki == block_k) { | |||
ki = 0; | |||
++bk; | |||
} | |||
} | |||
++ni; | |||
if (ni == block_n) { | |||
ni = 0; | |||
++bn; | |||
} | |||
} | |||
} | |||
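The two specializations below vectorize this packing with vqtbl1q_s8 gathers: each 16-byte table pulls four 4-byte k-windows (overlapping for stride 1, strided for stride 2) out of a single vector load. A minimal scalar model of that shuffle, assuming all table indices are in range (as they are here):

#include <cstdint>

// Scalar model of vqtbl1q_s8: lane i of the result is src[tbl[i]].
// With tbl_array_0 = {0,1,2,3, 1,2,3,4, 2,3,4,5, 3,4,5,6} this yields the
// k-windows of four consecutive output columns from one 16-byte load.
inline void tbl_gather_model(
        const int8_t src[16], const uint8_t tbl[16], int8_t dst[16]) {
    for (int i = 0; i < 16; ++i)
        dst[i] = src[tbl[i]];
}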
template <> | |||
void im2col<9, 1>( | |||
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||
size_t round_filter) { | |||
constexpr size_t FH = 9; | |||
constexpr size_t SH = 1; | |||
constexpr size_t SW = 1; | |||
constexpr size_t k_block_stride = block_k * block_n; | |||
constexpr size_t ow_block = 16; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||
2, 3, 4, 5, 3, 4, 5, 6}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, | |||
6, 7, 8, 9, 7, 8, 9, 10}; | |||
static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, | |||
10, 11, 12, 13, 11, 12, 13, 14}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); | |||
int bn = 0; | |||
int ni = 0; | |||
for (size_t oh = 0; oh < OH; ++oh) | |||
for (size_t ow = 0; ow < OW;) { | |||
if (ow + ow_block <= OW && ni + ow_block <= block_n) { | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
int8x16_t read_w[2]; | |||
read_w[0] = vld1q_s8(src_n); | |||
read_w[1] = vld1q_s8(src_n + 16); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); | |||
int8x16_t ncdef_0 = | |||
vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); | |||
int8x16_t n0123_1 = n4567_0; | |||
int8x16_t n4567_1 = n89ab_0; | |||
int8x16_t n89ab_1 = ncdef_0; | |||
int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t n0123_2 = n89ab_0; | |||
int8x16_t n4567_2 = ncdef_0; | |||
int8x16_t n89ab_2 = ncdef_1; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
vst1q_s8(dst_n + 0 * 16, n0123_0); | |||
vst1q_s8(dst_n + 1 * 16, n4567_0); | |||
vst1q_s8(dst_n + 2 * 16, n89ab_0); | |||
vst1q_s8(dst_n + 3 * 16, ncdef_0); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); | |||
dst_n += 3 * k_block_stride; | |||
src_n += pad_iw; | |||
} | |||
ni += ow_block; | |||
ow += ow_block; | |||
if (ni == block_n) { | |||
ni = 0; | |||
++bn; | |||
} | |||
} else { | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
int8x16_t read_w[1]; | |||
read_w[0] = vld1q_s8(src_n); | |||
//! store the 12 rounded taps of this row as three int32 lanes, one per k-group | |||
int32x4_t read_s32 = vreinterpretq_s32_s8(read_w[0]); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n), read_s32, 0); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n + 1 * k_block_stride), read_s32, 1); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n + 2 * k_block_stride), read_s32, 2); | |||
dst_n += 3 * k_block_stride; | |||
src_n += pad_iw; | |||
} | |||
++ni; | |||
++ow; | |||
if (ni == block_n) { | |||
ni = 0; | |||
++bn; | |||
} | |||
} | |||
} | |||
} | |||
template <> | |||
void im2col<9, 2>( | |||
const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw, | |||
size_t round_filter) { | |||
constexpr size_t FH = 9; | |||
constexpr size_t SH = 2; | |||
constexpr size_t SW = 2; | |||
constexpr size_t k_block_stride = block_k * block_n; | |||
constexpr size_t ow_block = 16; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, | |||
4, 5, 6, 7, 6, 7, 8, 9}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, | |||
8, 9, 10, 11, 10, 11, 12, 13}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
int bn = 0; | |||
int ni = 0; | |||
for (size_t oh = 0; oh < OH; ++oh) | |||
for (size_t ow = 0; ow < OW;) { | |||
if (ow + ow_block <= OW && ni + ow_block <= block_n) { | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
int8x16_t read_w[3]; | |||
read_w[0] = vld1q_s8(src_n); | |||
read_w[1] = vld1q_s8(src_n + 16); | |||
read_w[2] = vld1q_s8(src_n + 32); | |||
int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); | |||
int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); | |||
int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); | |||
int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); | |||
int8x16_t n0123_2 = n4567_0; | |||
int8x16_t n4567_2 = n89ab_0; | |||
int8x16_t n89ab_2 = ncdef_0; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||
vst1q_s8(dst_n + 0 * 16, n0123_0); | |||
vst1q_s8(dst_n + 1 * 16, n4567_0); | |||
vst1q_s8(dst_n + 2 * 16, n89ab_0); | |||
vst1q_s8(dst_n + 3 * 16, ncdef_0); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1); | |||
vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2); | |||
vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2); | |||
dst_n += 3 * k_block_stride; | |||
src_n += pad_iw; | |||
} | |||
ni += ow_block; | |||
ow += ow_block; | |||
if (ni == block_n) { | |||
ni = 0; | |||
++bn; | |||
} | |||
} else { | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k; | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
int8x16_t read_w[1]; | |||
read_w[0] = vld1q_s8(src_n); | |||
//! store the 12 rounded taps of this row as three int32 lanes, one per k-group | |||
int32x4_t read_s32 = vreinterpretq_s32_s8(read_w[0]); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n), read_s32, 0); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n + 1 * k_block_stride), read_s32, 1); | |||
vst1q_lane_s32(reinterpret_cast<int32_t*>(dst_n + 2 * k_block_stride), read_s32, 2); | |||
dst_n += 3 * k_block_stride; | |||
src_n += pad_iw; | |||
} | |||
++ni; | |||
++ow; | |||
if (ni == block_n) { | |||
ni = 0; | |||
++bn; | |||
} | |||
} | |||
} | |||
} | |||
void do_conv( | |||
const WorkspaceBundle& bundle, const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) { | |||
auto&& fm = kern_param.filter_meta; | |||
size_t PH = kern_param.filter_meta.padding[0]; | |||
size_t PW = kern_param.filter_meta.padding[1]; | |||
size_t OH = kern_param.osz[0]; | |||
size_t OW = kern_param.osz[1]; | |||
size_t IH = kern_param.isz[0]; | |||
size_t IW = kern_param.isz[1]; | |||
size_t FH = fm.spatial[0]; | |||
size_t FW = fm.spatial[1]; | |||
size_t SH = fm.stride[0]; | |||
float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale; | |||
float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale; | |||
float scale_dst_div = 1.f / scale_dst; | |||
size_t batch_id = ncb_index.ndrange_id[0]; | |||
size_t group_id = ncb_index.ndrange_id[1]; | |||
int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0)); | |||
int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1)); | |||
int8_t* im2col_ptr = static_cast<int8_t*>(bundle.get(2)); | |||
int32_t* i32_ptr = static_cast<int32_t*>(bundle.get(3)); | |||
const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id); | |||
const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id); | |||
const int8_t* fptr = kern_param.filter<dt_int8>(group_id); | |||
void* dst = kern_param.dst<void>(batch_id, group_id); | |||
size_t round_filter = round_up(FW, block_k) * FH; | |||
size_t pad_iw = IW + 2 * PW; | |||
memset(pad_src_ptr, 0, bundle.get_size(0)); | |||
rep(ih, IH) { | |||
std::memcpy( | |||
pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW, | |||
sizeof(int8_t) * IW); | |||
} | |||
memset(round_filter_ptr, 0, bundle.get_size(1)); | |||
size_t fh_stride = round_up(FW, block_k); | |||
for (size_t fh = 0; fh < FH; ++fh) { | |||
std::memcpy(round_filter_ptr + fh * fh_stride, fptr + fh * FW, FW); | |||
} | |||
memset(im2col_ptr, 0, bundle.get_size(2)); | |||
if (SH == 1) { | |||
im2col<9, 1>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); | |||
} else { | |||
im2col<9, 2>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter); | |||
} | |||
gevm_naive_n32k4_dot( | |||
round_filter_ptr, im2col_ptr, i32_ptr, 1, OH * OW, round_filter, 0, 0, 0); | |||
int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr; | |||
int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128; | |||
int8_t* dst_ptr = (int8_t*)dst; | |||
for (size_t i = 0; i < OH * OW; ++i) { | |||
//! optimize by tbl | |||
int val = roundf(scale_bias * scale_dst_div * (i32_ptr[i] + bias_val)); | |||
val = val < -128 ? -128 : val; | |||
val = val > 127 ? 127 : val; | |||
val = val > relu_val ? val : relu_val; | |||
dst_ptr[i] = val; | |||
} | |||
} | |||
} // namespace | |||
bool ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::usable( | |||
const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
if (!cpuinfo_has_arm_neon_dot()) { | |||
return false; | |||
} | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto FW = fm.spatial[1]; | |||
auto SH = fm.stride[0]; | |||
auto SW = fm.stride[1]; | |||
auto noline = param.nonlineMode; | |||
auto bias_mode = param.bias_mode; | |||
bool available = | |||
//! src and filter are qint8, dst is qint8 | |||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||
(param.dst_type.enumv() == DTypeEnum::QuantizedS8)) && | |||
fm.format == param::Convolution::Format::NCHW && !fm.should_flip && | |||
(noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) && | |||
(bias_mode == BiasMode::NO_BIAS || | |||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && | |||
fm.ocpg == 1; | |||
return available; | |||
} | |||
size_t ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, | |||
midout_iv("AlgoDotS8Im2colChanWiseLarge::get_workspace"_hash)) { | |||
auto bundle = get_total_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge:: | |||
dispatch_kerns(const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel, | |||
midout_iv("AlgoDotS8Im2colChanWiseLarge::dispatch_kerns"_hash)) { | |||
SmallVector<ConvBiasImpl::NCBKern> ret_kerns; | |||
auto fm = param.filter_meta; | |||
size_t N = param.n; | |||
size_t group = fm.group; | |||
WorkspaceBundle wbundle = get_sub_bundle(param); | |||
WorkspaceBundle total_bundle = get_total_bundle(param); | |||
auto exec_one_group = [wbundle, total_bundle]( | |||
const NCBKernParam& kern_param, | |||
const NCBKernIndex& ncb_index) mutable { | |||
WorkspaceBundle temp_total_bundle = total_bundle; | |||
temp_total_bundle.set(kern_param.workspace_ptr); | |||
WorkspaceBundle temp_bundle = wbundle; | |||
temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id)); | |||
do_conv(temp_bundle, kern_param, ncb_index); | |||
}; | |||
ret_kerns.push_back({exec_one_group, {N, group}}); | |||
return ret_kerns; | |||
} | |||
MIDOUT_END(); | |||
return {}; | |||
} | |||
#endif |
@@ -0,0 +1,27 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val); | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val); | |||
#endif |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
static inline void quant_store_s8( | |||
float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3, int8_t* ptr, | |||
int8x16_t relu_reg) { | |||
int32x4_t i0 = vcvtaq_s32_f32(v0); | |||
int32x4_t i1 = vcvtaq_s32_f32(v1); | |||
int32x4_t i2 = vcvtaq_s32_f32(v2); | |||
int32x4_t i3 = vcvtaq_s32_f32(v3); | |||
int16x4_t i16_0 = vqmovn_s32(i0); | |||
int16x4_t i16_1 = vqmovn_s32(i1); | |||
int16x4_t i16_2 = vqmovn_s32(i2); | |||
int16x4_t i16_3 = vqmovn_s32(i3); | |||
int8x8_t i8_0 = vqmovn_s16(vcombine_s16(i16_0, i16_1)); | |||
int8x8_t i8_1 = vqmovn_s16(vcombine_s16(i16_2, i16_3)); | |||
int8x16_t rst = vcombine_s8(i8_0, i8_1); | |||
rst = vmaxq_s8(rst, relu_reg); | |||
vst1q_s8(ptr, rst); | |||
} | |||
#endif |
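For reference, a scalar sketch of what quant_store_s8 does per element, as I read the intrinsics: vcvtaq_s32_f32 rounds to nearest with ties away from zero, the vqmovn pair saturates to the int8 range, and vmaxq_s8 applies ReLU when relu_val is 0 (a relu_val of -128 makes it a no-op). The helper name is mine.

#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t quant_one_s8(float v, int8_t relu_val) {
    long r = std::lround(v);                 // round half away from zero
    r = std::min(std::max(r, -128L), 127L);  // saturating narrow to int8
    return static_cast<int8_t>(std::max<long>(r, relu_val));  // ReLU via max
}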
@@ -0,0 +1,221 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s1.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||
#include "src/common/unroll_macro.h" | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val) { | |||
//! 4x16 | |||
const size_t SH = 1; | |||
const size_t SW = 1; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||
2, 3, 4, 5, 3, 4, 5, 6}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, | |||
6, 7, 8, 9, 7, 8, 9, 10}; | |||
static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, | |||
10, 11, 12, 13, 11, 12, 13, 14}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
//! init | |||
int32x4_t c[4][4]; | |||
#define cb(step) \ | |||
c[step][0] = vdupq_n_s32(bias); \ | |||
c[step][1] = vdupq_n_s32(bias); \ | |||
c[step][2] = vdupq_n_s32(bias); \ | |||
c[step][3] = vdupq_n_s32(bias); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
int8x16_t flt[4]; | |||
flt[0] = vld1q_s8(weight + 0 * 16); | |||
flt[1] = vld1q_s8(weight + 1 * 16); | |||
flt[2] = vld1q_s8(weight + 2 * 16); | |||
flt[3] = vld1q_s8(weight + 3 * 16); | |||
//! row 0 | |||
int8x16_t read_w[2]; | |||
read_w[0] = vld1q_s8(src_n + 0 * pad_iw); | |||
read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); | |||
int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); | |||
int8x16_t n0123_1 = n4567_0; | |||
int8x16_t n4567_1 = n89ab_0; | |||
int8x16_t n89ab_1 = ncdef_0; | |||
int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t n0123_2 = n89ab_0; | |||
int8x16_t n4567_2 = ncdef_0; | |||
int8x16_t n89ab_2 = ncdef_1; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
#define CAL_C(oh, flt_start) \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4); | |||
CAL_C(0, 0); | |||
//! row 1 | |||
#define LOAD_SRC(row_id) \ | |||
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||
n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||
n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \ | |||
ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \ | |||
n0123_1 = n4567_0; \ | |||
n4567_1 = n89ab_0; \ | |||
n89ab_1 = ncdef_0; \ | |||
ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||
n0123_2 = n89ab_0; \ | |||
n4567_2 = ncdef_0; \ | |||
n89ab_2 = ncdef_1; \ | |||
ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
LOAD_SRC(1); | |||
CAL_C(0, 3); | |||
CAL_C(1, 0); | |||
//! row 2 | |||
LOAD_SRC(2); | |||
CAL_C(0, 3 * 2); | |||
CAL_C(1, 3 * 1); | |||
CAL_C(2, 3 * 0); | |||
//! row 3 | |||
LOAD_SRC(3); | |||
CAL_C(0, 3 * 3); | |||
CAL_C(1, 3 * 2); | |||
CAL_C(2, 3 * 1); | |||
CAL_C(3, 3 * 0); | |||
//! row 4 | |||
LOAD_SRC(4); | |||
CAL_C(0, 3 * 4); | |||
CAL_C(1, 3 * 3); | |||
CAL_C(2, 3 * 2); | |||
CAL_C(3, 3 * 1); | |||
//! update flt 4 -> 0 | |||
flt[0] = vld1q_s8(weight + 4 * 16); | |||
//! row 5 | |||
LOAD_SRC(5); | |||
CAL_C(0, 3 * 5); | |||
CAL_C(1, 3 * 4); | |||
CAL_C(2, 3 * 3); | |||
CAL_C(3, 3 * 2); | |||
//! update flt 5 -> 1 | |||
flt[1] = vld1q_s8(weight + 5 * 16); | |||
//! row 6 | |||
LOAD_SRC(6); | |||
CAL_C(0, 3 * 6); | |||
CAL_C(1, 3 * 5); | |||
CAL_C(2, 3 * 4); | |||
CAL_C(3, 3 * 3); | |||
//! update flt 6 -> 2 | |||
flt[2] = vld1q_s8(weight + 6 * 16); | |||
//! row 7 | |||
LOAD_SRC(7); | |||
CAL_C(0, 3 * 7); | |||
CAL_C(1, 3 * 6); | |||
CAL_C(2, 3 * 5); | |||
CAL_C(3, 3 * 4); | |||
//! row 8 | |||
LOAD_SRC(8); | |||
CAL_C(0, 3 * 8); | |||
CAL_C(1, 3 * 7); | |||
CAL_C(2, 3 * 6); | |||
CAL_C(3, 3 * 5); | |||
//! row 9 | |||
LOAD_SRC(9); | |||
CAL_C(1, 3 * 8); | |||
CAL_C(2, 3 * 7); | |||
CAL_C(3, 3 * 6); | |||
//! row 10 | |||
LOAD_SRC(10); | |||
CAL_C(2, 3 * 8); | |||
CAL_C(3, 3 * 7); | |||
//! row 11 | |||
LOAD_SRC(11); | |||
CAL_C(3, 3 * 8); | |||
float32x4_t dst_reg[4][4]; | |||
#define cb(step) \ | |||
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define cb(step) \ | |||
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
int8_t* dst_store = dst + oh * OW + ow; | |||
int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||
#define cb(step) \ | |||
quant_store_s8( \ | |||
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||
dst_store + step * OW, relu_reg); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
} | |||
#endif |
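The hand-unrolled LOAD_SRC/CAL_C sequence above follows a regular schedule: padded source row r updates output row t with filter row r - t whenever 0 <= r - t < 9, and each filter row spans three k-groups of four taps (FW rounded from 9 to 12), hence the 3 * factors. A throwaway sketch that reproduces the schedule:

#include <cstdio>

int main() {
    // 4 output rows (t = 0..3) of the 9x9 stride-1 kernel read 12 source rows
    for (int r = 0; r < 12; ++r)
        for (int t = 0; t < 4; ++t)
            if (r - t >= 0 && r - t < 9)
                std::printf("row %d: CAL_C(%d, 3 * %d)\n", r, t, r - t);
    return 0;
}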
@@ -0,0 +1,250 @@ | |||
/** | |||
* \file | |||
* dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s2.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||
#include "src/common/unroll_macro.h" | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val) { | |||
//! 4x16 | |||
const size_t SH = 2; | |||
const size_t SW = 2; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, | |||
4, 5, 6, 7, 6, 7, 8, 9}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, | |||
8, 9, 10, 11, 10, 11, 12, 13}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
//! init | |||
int32x4_t c[4][4]; | |||
#define cb(step) \ | |||
c[step][0] = vdupq_n_s32(bias); \ | |||
c[step][1] = vdupq_n_s32(bias); \ | |||
c[step][2] = vdupq_n_s32(bias); \ | |||
c[step][3] = vdupq_n_s32(bias); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
constexpr int flt_reg = 7; | |||
constexpr int flt_per_reg = 4; | |||
int8x16_t flt[7]; | |||
flt[0] = vld1q_s8(weight + 0 * 16); | |||
flt[1] = vld1q_s8(weight + 1 * 16); | |||
flt[2] = vld1q_s8(weight + 2 * 16); | |||
flt[3] = vld1q_s8(weight + 3 * 16); | |||
flt[4] = vld1q_s8(weight + 4 * 16); | |||
flt[5] = vld1q_s8(weight + 5 * 16); | |||
flt[6] = vld1q_s8(weight + 6 * 16); | |||
#define CAL_C(oh, flt_start) \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); | |||
#define LOAD_SRC(row_id) \ | |||
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||
read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \ | |||
ext_8 = vextq_s8(read_w[0], read_w[1], 8); \ | |||
ext_24 = vextq_s8(read_w[1], read_w[2], 8); \ | |||
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||
n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \ | |||
n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||
ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \ | |||
n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||
n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \ | |||
n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \ | |||
ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \ | |||
n0123_2 = n4567_0; \ | |||
n4567_2 = n89ab_0; \ | |||
n89ab_2 = ncdef_0; \ | |||
ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||
//! row 0 | |||
int8x16_t read_w[3]; | |||
read_w[0] = vld1q_s8(src_n); | |||
read_w[1] = vld1q_s8(src_n + 16); | |||
read_w[2] = vld1q_s8(src_n + 32); | |||
int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); | |||
int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); | |||
int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); | |||
int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); | |||
int8x16_t n0123_2 = n4567_0; | |||
int8x16_t n4567_2 = n89ab_0; | |||
int8x16_t n89ab_2 = ncdef_0; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||
CAL_C(0, 0); | |||
//! row 1 | |||
LOAD_SRC(1); | |||
CAL_C(0, 3 * 1); | |||
//! row 2 | |||
LOAD_SRC(2); | |||
CAL_C(0, 3 * 2); | |||
CAL_C(1, 3 * 0); | |||
//! row 3 | |||
LOAD_SRC(3); | |||
CAL_C(0, 3 * 3); | |||
CAL_C(1, 3 * 1); | |||
//! row 4 | |||
LOAD_SRC(4); | |||
CAL_C(0, 3 * 4); | |||
CAL_C(1, 3 * 2); | |||
CAL_C(2, 3 * 0); | |||
//! row 5 | |||
LOAD_SRC(5); | |||
CAL_C(0, 3 * 5); | |||
CAL_C(1, 3 * 3); | |||
CAL_C(2, 3 * 1); | |||
//! row 6 | |||
LOAD_SRC(6); | |||
CAL_C(0, 3 * 6); | |||
CAL_C(1, 3 * 4); | |||
CAL_C(2, 3 * 2); | |||
CAL_C(3, 3 * 0); | |||
//! row 7 | |||
LOAD_SRC(7); | |||
CAL_C(0, 3 * 7); | |||
CAL_C(1, 3 * 5); | |||
CAL_C(2, 3 * 3); | |||
CAL_C(3, 3 * 1); | |||
//! row 8 | |||
LOAD_SRC(8); | |||
CAL_C(0, 3 * 8); | |||
CAL_C(1, 3 * 6); | |||
CAL_C(2, 3 * 4); | |||
CAL_C(3, 3 * 2); | |||
//! row 9 | |||
LOAD_SRC(9); | |||
CAL_C(1, 3 * 7); | |||
CAL_C(2, 3 * 5); | |||
CAL_C(3, 3 * 3); | |||
//! row 10 | |||
LOAD_SRC(10); | |||
CAL_C(1, 3 * 8); | |||
CAL_C(2, 3 * 6); | |||
CAL_C(3, 3 * 4); | |||
//! row 11 | |||
LOAD_SRC(11); | |||
CAL_C(2, 3 * 7); | |||
CAL_C(3, 3 * 5); | |||
//! row 12 | |||
LOAD_SRC(12); | |||
CAL_C(2, 3 * 8); | |||
CAL_C(3, 3 * 6); | |||
//! row 13 | |||
LOAD_SRC(13); | |||
CAL_C(3, 3 * 7); | |||
//! row 14 | |||
LOAD_SRC(14); | |||
CAL_C(3, 3 * 8); | |||
float32x4_t dst_reg[4][4]; | |||
#define cb(step) \ | |||
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define cb(step) \ | |||
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
int8_t* dst_store = dst + oh * OW + ow; | |||
int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||
#define cb(step) \ | |||
quant_store_s8( \ | |||
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||
dst_store + step * OW, relu_reg); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
} | |||
#endif |
@@ -54,6 +54,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44; | |||
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44; | |||
AlgoDotS8Im2colChanWiseLarge ds8_im2col_large_chanwise; | |||
AlgoDotS8DirectChanWiseLarge ds8_direct_large_chanwise; | |||
#endif | |||
AlgoI8x8x16Direct i8x8x16_direct; | |||
@@ -75,6 +77,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
public: | |||
AlgoPack() { | |||
#if MGB_ENABLE_DOT | |||
m_direct_algos.emplace_back(&ds8_direct_large_chanwise); | |||
m_direct_algos.emplace_back(&ds8_im2col_large_chanwise); | |||
m_direct_algos.emplace_back(&ds8_direct_stride1); | |||
m_direct_algos.emplace_back(&ds8_direct_stride2); | |||
m_direct_algos.emplace_back(&du8_direct_stride1); | |||
@@ -51,6 +51,8 @@ private: | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
class AlgoDotS8DirectNCHWNCHW44; | |||
class AlgoDotS8DirectChanWiseLarge; | |||
class AlgoDotS8Im2colChanWiseLarge; | |||
class AlgoDotS8DirectStride1; | |||
class AlgoDotS8DirectStride2; | |||
class AlgoDotU8DirectStride1; | |||
@@ -143,8 +143,89 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||
const KernSizeParam&) const { | |||
return int8x8x32_gemv_mk4_kern; | |||
} | |||
#if MGB_ENABLE_DOT | |||
namespace { | |||
void int8x8x32_gevm_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int32>(); | |||
gevm_naive_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoInt8x8x32GevmDot::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
if (!cpuinfo_has_arm_neon_dot()) { | |||
return false; | |||
} | |||
auto M = kern_size_param.M; | |||
bool is_dtype_ok = | |||
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | |||
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && is_dtype_ok && | |||
!kern_size_param.trA && !kern_size_param.trB && M == 1; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x32GevmDot::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
return true; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmDot::get_kern( | |||
const KernSizeParam&) const { | |||
return int8x8x32_gevm_dot_kern; | |||
} | |||
namespace { | |||
void int8x8x32_gevm_n32k4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_n32k4_dot_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int32>(); | |||
gevm_naive_n32k4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
if (!cpuinfo_has_arm_neon_dot()) { | |||
return false; | |||
} | |||
auto M = kern_size_param.M; | |||
bool is_dtype_ok = | |||
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | |||
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::N32K4_DOT && | |||
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && M == 1; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
return true; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::get_kern( | |||
const KernSizeParam&) const { | |||
return int8x8x32_gevm_n32k4_dot_kern; | |||
} | |||
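For orientation, both GEVM algorithms above compute the same M == 1 product and differ only in how B is laid out; a plain int8 -> int32 reference for the DEFAULT (row-major B) case, with names of my choosing:

#include <cstddef>
#include <cstdint>

// C[n] = sum_k A[k] * B[k][n]; the N32K4_DOT variant changes only B's addressing.
inline void gevm_ref(
        const int8_t* A, const int8_t* B, int32_t* C, size_t N, size_t K,
        size_t LDB) {
    for (size_t n = 0; n < N; ++n) {
        int32_t acc = 0;
        for (size_t k = 0; k < K; ++k)
            acc += int32_t(A[k]) * int32_t(B[k * LDB + n]);
        C[n] = acc;
    }
}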
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | |||
namespace { | |||
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -49,8 +49,69 @@ public: | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) | |||
}; | |||
#if MGB_ENABLE_DOT | |||
class MatrixMulImpl::AlgoInt8x8x32GevmDot : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_DOT"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||
WorkspaceBundle get_bundle(const KernSizeParam&) const override { | |||
return WorkspaceBundle{nullptr, {}}; | |||
} | |||
kern_naked_t get_kern_naked(const KernSizeParam&) const override { | |||
megdnn_assert(0, "naked kern no impl"); | |||
} | |||
void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) | |||
const override { | |||
megdnn_assert(0, "pack_A no impl"); | |||
} | |||
void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) | |||
const override { | |||
megdnn_assert(0, "pack_B no impl"); | |||
} | |||
InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; } | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_DOT) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, N32K4_DOT) | |||
WorkspaceBundle get_bundle(const KernSizeParam&) const override { | |||
return WorkspaceBundle{nullptr, {}}; | |||
} | |||
kern_naked_t get_kern_naked(const KernSizeParam&) const override { | |||
megdnn_assert(0, "naked kern no impl"); | |||
} | |||
void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride) | |||
const override { | |||
megdnn_assert(0, "pack_A no impl"); | |||
} | |||
void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax) | |||
const override { | |||
megdnn_assert(0, "pack_B no impl"); | |||
} | |||
InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; } | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
@@ -2,6 +2,7 @@ | |||
#include "megdnn/oprs.h" | |||
#include "src/arm_common/matrix_mul/int8/gemv.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "midout.h" | |||
@@ -430,5 +431,398 @@ void arm_common::gemv_like_mk4_dot( | |||
MIDOUT_END(); | |||
} | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
namespace { | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void gevm_naive_dot_impl( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||
bool load_c) { | |||
constexpr size_t n_block = 32; | |||
const size_t n_end = N / n_block * n_block; | |||
const size_t n_remain = N - n_end; | |||
constexpr size_t k_block = 4; | |||
constexpr size_t k_block_x2 = k_block * 2; | |||
const size_t k_end = (K / k_block_x2) * k_block_x2; | |||
const size_t k_remain = K - k_end; | |||
for (size_t n = 0; n < n_end; n += n_block) { | |||
if (K < k_block_x2) { | |||
if (!load_c) { | |||
for (size_t i = 0; i < n_block; ++i) { | |||
C[n + i] = 0; | |||
} | |||
} | |||
for (size_t k = 0; k < K; ++k) { | |||
for (size_t i = 0; i < n_block; ++i) { | |||
C[n + i] += A[k] * B[k * Bstride + n + i]; | |||
} | |||
} | |||
continue; | |||
} | |||
int32x4_t c[8]; | |||
if (load_c) { | |||
#define cb(step) c[step] = vld1q_s32(C + n + step * 4); | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
} else { | |||
#define cb(step) c[step] = vdupq_n_s32(0); | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
} | |||
int8x16_t a[2]; | |||
a[0] = vreinterpretq_s8_s32(vld1q_dup_s32(reinterpret_cast<const int32_t*>(A))); | |||
int8x16_t b[2][8]; | |||
#define cb(step) \ | |||
b[0][step * 2 + 0] = vld1q_s8(B + (0 + step) * Bstride + n); \ | |||
b[0][step * 2 + 1] = vld1q_s8(B + (0 + step) * Bstride + n + 16); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
size_t k_buffer_end = k_end - k_block_x2; | |||
for (size_t k = 0; k < k_buffer_end; k += k_block_x2) { | |||
//! double buffer main | |||
#define cb(step) \ | |||
b[1][step * 2 + 0] = vld1q_s8(B + (k + step + k_block) * Bstride + n); \ | |||
b[1][step * 2 + 1] = vld1q_s8(B + (k + step + k_block) * Bstride + n + 16); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
a[1] = vreinterpretq_s8_s32( | |||
vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k + k_block))); | |||
int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); | |||
int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); | |||
int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); | |||
int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); | |||
int16x8x2_t abcd0 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[0]), vreinterpretq_s16_s8(cd0.val[0])); | |||
int16x8x2_t abcd1 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[1]), vreinterpretq_s16_s8(cd0.val[1])); | |||
int16x8x2_t abcd2 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[0]), vreinterpretq_s16_s8(cd1.val[0])); | |||
int16x8x2_t abcd3 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[1]), vreinterpretq_s16_s8(cd1.val[1])); | |||
c[0] = vdotq_s32(c[0], vreinterpretq_s8_s16(abcd0.val[0]), a[0]); | |||
c[1] = vdotq_s32(c[1], vreinterpretq_s8_s16(abcd0.val[1]), a[0]); | |||
c[2] = vdotq_s32(c[2], vreinterpretq_s8_s16(abcd1.val[0]), a[0]); | |||
c[3] = vdotq_s32(c[3], vreinterpretq_s8_s16(abcd1.val[1]), a[0]); | |||
c[4] = vdotq_s32(c[4], vreinterpretq_s8_s16(abcd2.val[0]), a[0]); | |||
c[5] = vdotq_s32(c[5], vreinterpretq_s8_s16(abcd2.val[1]), a[0]); | |||
c[6] = vdotq_s32(c[6], vreinterpretq_s8_s16(abcd3.val[0]), a[0]); | |||
c[7] = vdotq_s32(c[7], vreinterpretq_s8_s16(abcd3.val[1]), a[0]); | |||
#define cb(step) \ | |||
b[0][step * 2 + 0] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n); \ | |||
b[0][step * 2 + 1] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n + 16); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
a[0] = vreinterpretq_s8_s32( | |||
vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k + k_block_x2))); | |||
ab0 = vzipq_s8(b[1][0], b[1][2]); | |||
cd0 = vzipq_s8(b[1][4], b[1][6]); | |||
ab1 = vzipq_s8(b[1][1], b[1][3]); | |||
cd1 = vzipq_s8(b[1][5], b[1][7]); | |||
abcd0 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[0]), vreinterpretq_s16_s8(cd0.val[0])); | |||
abcd1 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[1]), vreinterpretq_s16_s8(cd0.val[1])); | |||
abcd2 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[0]), vreinterpretq_s16_s8(cd1.val[0])); | |||
abcd3 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[1]), vreinterpretq_s16_s8(cd1.val[1])); | |||
c[0] = vdotq_s32(c[0], vreinterpretq_s8_s16(abcd0.val[0]), a[1]); | |||
c[1] = vdotq_s32(c[1], vreinterpretq_s8_s16(abcd0.val[1]), a[1]); | |||
c[2] = vdotq_s32(c[2], vreinterpretq_s8_s16(abcd1.val[0]), a[1]); | |||
c[3] = vdotq_s32(c[3], vreinterpretq_s8_s16(abcd1.val[1]), a[1]); | |||
c[4] = vdotq_s32(c[4], vreinterpretq_s8_s16(abcd2.val[0]), a[1]); | |||
c[5] = vdotq_s32(c[5], vreinterpretq_s8_s16(abcd2.val[1]), a[1]); | |||
c[6] = vdotq_s32(c[6], vreinterpretq_s8_s16(abcd3.val[0]), a[1]); | |||
c[7] = vdotq_s32(c[7], vreinterpretq_s8_s16(abcd3.val[1]), a[1]); | |||
} | |||
//! double buffer epilogue: process the last two 4(k) blocks | |||
#define cb(step) \ | |||
b[1][step * 2 + 0] = vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n); \ | |||
b[1][step * 2 + 1] = \ | |||
vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n + 16); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
a[1] = vreinterpretq_s8_s32( | |||
vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k_buffer_end + k_block))); | |||
int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]); | |||
int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]); | |||
int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]); | |||
int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]); | |||
int16x8x2_t abcd0 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[0]), vreinterpretq_s16_s8(cd0.val[0])); | |||
int16x8x2_t abcd1 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[1]), vreinterpretq_s16_s8(cd0.val[1])); | |||
int16x8x2_t abcd2 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[0]), vreinterpretq_s16_s8(cd1.val[0])); | |||
int16x8x2_t abcd3 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[1]), vreinterpretq_s16_s8(cd1.val[1])); | |||
c[0] = vdotq_s32(c[0], vreinterpretq_s8_s16(abcd0.val[0]), a[0]); | |||
c[1] = vdotq_s32(c[1], vreinterpretq_s8_s16(abcd0.val[1]), a[0]); | |||
c[2] = vdotq_s32(c[2], vreinterpretq_s8_s16(abcd1.val[0]), a[0]); | |||
c[3] = vdotq_s32(c[3], vreinterpretq_s8_s16(abcd1.val[1]), a[0]); | |||
c[4] = vdotq_s32(c[4], vreinterpretq_s8_s16(abcd2.val[0]), a[0]); | |||
c[5] = vdotq_s32(c[5], vreinterpretq_s8_s16(abcd2.val[1]), a[0]); | |||
c[6] = vdotq_s32(c[6], vreinterpretq_s8_s16(abcd3.val[0]), a[0]); | |||
c[7] = vdotq_s32(c[7], vreinterpretq_s8_s16(abcd3.val[1]), a[0]); | |||
ab0 = vzipq_s8(b[1][0], b[1][2]); | |||
cd0 = vzipq_s8(b[1][4], b[1][6]); | |||
ab1 = vzipq_s8(b[1][1], b[1][3]); | |||
cd1 = vzipq_s8(b[1][5], b[1][7]); | |||
abcd0 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[0]), vreinterpretq_s16_s8(cd0.val[0])); | |||
abcd1 = vzipq_s16(vreinterpretq_s16_s8(ab0.val[1]), vreinterpretq_s16_s8(cd0.val[1])); | |||
abcd2 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[0]), vreinterpretq_s16_s8(cd1.val[0])); | |||
abcd3 = vzipq_s16(vreinterpretq_s16_s8(ab1.val[1]), vreinterpretq_s16_s8(cd1.val[1])); | |||
c[0] = vdotq_s32(c[0], vreinterpretq_s8_s16(abcd0.val[0]), a[1]); | |||
c[1] = vdotq_s32(c[1], vreinterpretq_s8_s16(abcd0.val[1]), a[1]); | |||
c[2] = vdotq_s32(c[2], vreinterpretq_s8_s16(abcd1.val[0]), a[1]); | |||
c[3] = vdotq_s32(c[3], vreinterpretq_s8_s16(abcd1.val[1]), a[1]); | |||
c[4] = vdotq_s32(c[4], vreinterpretq_s8_s16(abcd2.val[0]), a[1]); | |||
c[5] = vdotq_s32(c[5], vreinterpretq_s8_s16(abcd2.val[1]), a[1]); | |||
c[6] = vdotq_s32(c[6], vreinterpretq_s8_s16(abcd3.val[0]), a[1]); | |||
c[7] = vdotq_s32(c[7], vreinterpretq_s8_s16(abcd3.val[1]), a[1]); | |||
vst1q_s32(C + n + 0 * 4, c[0]); | |||
vst1q_s32(C + n + 1 * 4, c[1]); | |||
vst1q_s32(C + n + 2 * 4, c[2]); | |||
vst1q_s32(C + n + 3 * 4, c[3]); | |||
vst1q_s32(C + n + 4 * 4, c[4]); | |||
vst1q_s32(C + n + 5 * 4, c[5]); | |||
vst1q_s32(C + n + 6 * 4, c[6]); | |||
vst1q_s32(C + n + 7 * 4, c[7]); | |||
if (k_remain > 0) { | |||
for (size_t k = k_end; k < K; ++k) { | |||
for (size_t i = 0; i < n_block; ++i) { | |||
C[n + i] += A[k] * B[k * Bstride + n + i]; | |||
} | |||
} | |||
} | |||
} | |||
if (n_remain > 0) { | |||
for (size_t n = n_end; n < N; ++n) { | |||
if (!load_c) { | |||
C[n] = 0; | |||
} | |||
for (size_t k = 0; k < K; ++k) { | |||
C[n] += A[k] * B[k * Bstride + n]; | |||
} | |||
} | |||
} | |||
} | |||
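//! Note on the vzip ladder in gevm_naive_dot_impl: zipping rows k/k+1 and | |||
//! k+2/k+3 as bytes, then zipping the two results as 16-bit pairs, leaves | |||
//! every 32-bit lane holding [B(k,n), B(k+1,n), B(k+2,n), B(k+3,n)] for one | |||
//! column n. That matches the four A bytes replicated by the dup-load, so | |||
//! each vdotq_s32 lane computes the partial sum | |||
//!     C[n] += A[k]*B(k,n) + A[k+1]*B(k+1,n) + A[k+2]*B(k+2,n) + A[k+3]*B(k+3,n) | |||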
#if MEGDNN_ARMV7 | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void gevm_naive_dot_n32k4_impl( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||
bool load_c) { | |||
//! input B must be packed as (N/32, K/4, 32(n), 4(k)) | |||
//! TODO: add prefetch | |||
//! TODO: add double buffer | |||
constexpr size_t n_block = 32; | |||
constexpr size_t k_block = 4; | |||
for (size_t n = 0; n < N; n += n_block) { | |||
int32x4_t c[n_block / 4]; | |||
#define cb(step) c[step] = vdupq_n_s32(0); | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
const int8_t* b_base = B + n * K; | |||
for (size_t k = 0; k < K; k += k_block) { | |||
int8x16_t a[1]; | |||
int8x16_t b[1][8]; | |||
#define cb(step) b[0][step] = vld1q_s8(b_base + k * 32 + 16 * step); | |||
UNROLL_CALL_RAW(8, cb); | |||
#undef cb | |||
a[0] = vreinterpretq_s8_s32( | |||
vld1q_dup_s32(reinterpret_cast<const int32_t*>(A + k))); | |||
c[0] = vdotq_s32(c[0], b[0][0], a[0]); | |||
c[1] = vdotq_s32(c[1], b[0][1], a[0]); | |||
c[2] = vdotq_s32(c[2], b[0][2], a[0]); | |||
c[3] = vdotq_s32(c[3], b[0][3], a[0]); | |||
c[4] = vdotq_s32(c[4], b[0][4], a[0]); | |||
c[5] = vdotq_s32(c[5], b[0][5], a[0]); | |||
c[6] = vdotq_s32(c[6], b[0][6], a[0]); | |||
c[7] = vdotq_s32(c[7], b[0][7], a[0]); | |||
} | |||
vst1q_s32(C + n + 0 * 4, c[0]); | |||
vst1q_s32(C + n + 1 * 4, c[1]); | |||
vst1q_s32(C + n + 2 * 4, c[2]); | |||
vst1q_s32(C + n + 3 * 4, c[3]); | |||
vst1q_s32(C + n + 4 * 4, c[4]); | |||
vst1q_s32(C + n + 5 * 4, c[5]); | |||
vst1q_s32(C + n + 6 * 4, c[6]); | |||
vst1q_s32(C + n + 7 * 4, c[7]); | |||
} | |||
} | |||
#else | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
inline void n32k4_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t K) { | |||
int oddk = (K & 1); | |||
K = ((K + 1) / 2) - 1; | |||
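//! the main loop consumes two 4(k) blocks per iteration; oddk marks a | |||
//! leftover block handled by the odd tail | |||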
//! C q0-q7 | |||
//! A q8-q9 | |||
//! B q10-q25 | |||
asm volatile( | |||
// clear accumulators | |||
"1:\n" | |||
"eor v0.16b, v0.16b, v0.16b\n" | |||
"eor v1.16b, v1.16b, v1.16b\n" | |||
"eor v2.16b, v2.16b, v2.16b\n" | |||
"eor v3.16b, v3.16b, v3.16b\n" | |||
"eor v4.16b, v4.16b, v4.16b\n" | |||
"eor v5.16b, v5.16b, v5.16b\n" | |||
"eor v6.16b, v6.16b, v6.16b\n" | |||
"eor v7.16b, v7.16b, v7.16b\n" | |||
"ld1r {v8.4s}, [%[a_ptr]]\n" | |||
"ld1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[b_ptr]], 64\n" | |||
"ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[b_ptr]], 64\n" | |||
"add %[a_ptr], %[a_ptr], #4\n" | |||
"cmp %w[k], #0\n" | |||
"beq 4f\n" | |||
"2: \n" | |||
// Loop proper | |||
"3:\n" | |||
"ld1r {v9.4s}, [%[a_ptr]]\n" | |||
"sdot v0.4s, v10.16b, v8.16b\n" | |||
"ldr q18, [%[b_ptr], #0]\n" | |||
"sdot v1.4s, v11.16b, v8.16b\n" | |||
"ldr q19, [%[b_ptr], #16]\n" | |||
"sdot v2.4s, v12.16b, v8.16b\n" | |||
"ldr q20, [%[b_ptr], #32]\n" | |||
"add %[a_ptr], %[a_ptr], #4\n" | |||
"sdot v3.4s, v13.16b, v8.16b\n" | |||
"ldr q21, [%[b_ptr], #48]\n" | |||
"sdot v4.4s, v14.16b, v8.16b\n" | |||
"ldr q22, [%[b_ptr], #64]\n" | |||
"sdot v5.4s, v15.16b, v8.16b\n" | |||
"ldr q23, [%[b_ptr], #80]\n" | |||
"sdot v6.4s, v16.16b, v8.16b\n" | |||
"ldr q24, [%[b_ptr], #96]\n" | |||
"sdot v7.4s, v17.16b, v8.16b\n" | |||
"ldr q25, [%[b_ptr], #112]\n" | |||
"ld1r {v8.4s}, [%[a_ptr]]\n" | |||
"sdot v0.4s, v18.16b, v9.16b\n" | |||
"ldr q10, [%[b_ptr], #128]\n" | |||
"sdot v1.4s, v19.16b, v9.16b\n" | |||
"ldr q11, [%[b_ptr], #144]\n" | |||
"sdot v2.4s, v20.16b, v9.16b\n" | |||
"ldr q12, [%[b_ptr], #160]\n" | |||
"sdot v3.4s, v21.16b, v9.16b\n" | |||
"ldr q13, [%[b_ptr], #176]\n" | |||
"sdot v4.4s, v22.16b, v9.16b\n" | |||
"ldr q14, [%[b_ptr], #192]\n" | |||
"sdot v5.4s, v23.16b, v9.16b\n" | |||
"ldr q15, [%[b_ptr], #208]\n" | |||
"sdot v6.4s, v24.16b, v9.16b\n" | |||
"ldr q16, [%[b_ptr], #224]\n" | |||
"sdot v7.4s, v25.16b, v9.16b\n" | |||
"ldr q17, [%[b_ptr], #240]\n" | |||
"add %[a_ptr], %[a_ptr], #4\n" | |||
"add %[b_ptr], %[b_ptr], #256\n" | |||
"subs %w[k], %w[k], #1\n" | |||
"bne 3b\n" | |||
"4:\n" | |||
"cmp %w[oddk], #1\n" | |||
"beq 5f\n" | |||
// Even tail | |||
"ld1r {v9.4s}, [%[a_ptr]]\n" | |||
"sdot v0.4s, v10.16b, v8.16b\n" | |||
"ldr q18, [%[b_ptr], #0]\n" | |||
"sdot v1.4s, v11.16b, v8.16b\n" | |||
"ldr q19, [%[b_ptr], #16]\n" | |||
"sdot v2.4s, v12.16b, v8.16b\n" | |||
"ldr q20, [%[b_ptr], #32]\n" | |||
"sdot v3.4s, v13.16b, v8.16b\n" | |||
"ldr q21, [%[b_ptr], #48]\n" | |||
"sdot v4.4s, v14.16b, v8.16b\n" | |||
"ldr q22, [%[b_ptr], #64]\n" | |||
"sdot v5.4s, v15.16b, v8.16b\n" | |||
"ldr q23, [%[b_ptr], #80]\n" | |||
"sdot v6.4s, v16.16b, v8.16b\n" | |||
"ldr q24, [%[b_ptr], #96]\n" | |||
"sdot v7.4s, v17.16b, v8.16b\n" | |||
"ldr q25, [%[b_ptr], #112]\n" | |||
"sdot v0.4s, v18.16b, v9.16b\n" | |||
"sdot v1.4s, v19.16b, v9.16b\n" | |||
"sdot v2.4s, v20.16b, v9.16b\n" | |||
"sdot v3.4s, v21.16b, v9.16b\n" | |||
"sdot v4.4s, v22.16b, v9.16b\n" | |||
"sdot v5.4s, v23.16b, v9.16b\n" | |||
"sdot v6.4s, v24.16b, v9.16b\n" | |||
"sdot v7.4s, v25.16b, v9.16b\n" | |||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" | |||
"st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" | |||
"b 6f\n" | |||
"5:\n" | |||
// Odd tail | |||
"sdot v0.4s, v10.16b, v8.16b\n" | |||
"sdot v1.4s, v11.16b, v8.16b\n" | |||
"sdot v2.4s, v12.16b, v8.16b\n" | |||
"sdot v3.4s, v13.16b, v8.16b\n" | |||
"sdot v4.4s, v14.16b, v8.16b\n" | |||
"sdot v5.4s, v15.16b, v8.16b\n" | |||
"sdot v6.4s, v16.16b, v8.16b\n" | |||
"sdot v7.4s, v17.16b, v8.16b\n" | |||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], 64\n" | |||
"st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], 64\n" | |||
"6:\n" | |||
: [a_ptr] "+r"(A), [b_ptr] "+r"(B), [k] "+r"(K), [c_ptr] "+r"(C), | |||
[oddk] "+r"(oddk) | |||
: | |||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
"v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
"v22", "v23", "v24", "v25", "cc", "memory"); | |||
} | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void gevm_naive_dot_n32k4_impl( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, | |||
bool load_c) { | |||
//! input B must be packed as (N/32, K/4, 32(n), 4(k)) | |||
//! TODO: add prefetch | |||
//! TODO: add double buffer | |||
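//! each iteration hands one 32-column stripe of packed B to the asm kernel; | |||
//! K is passed in units of 4(k) blocks, so the packed layout assumes | |||
//! N % 32 == 0 and K % 4 == 0 | |||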
constexpr size_t n_block = 32; | |||
for (size_t n = 0; n < N; n += n_block) { | |||
n32k4_dot(A, B + n * K, C + n, K / 4); | |||
} | |||
} | |||
#endif | |||
} // namespace | |||
void arm_common::gevm_naive_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
megdnn_assert(M == 1); | |||
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot"_hash)) { | |||
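//! heuristic: assume a ~256 KiB L2; when B (N * K bytes) overflows it, split | |||
//! K into groups so each pass streams a slice of B that fits, accumulating | |||
//! onto C from the second pass on (load_c = true for k != 0). e.g. with | |||
//! N = 1024, K = 1024: k_group = 4 and k_per_group = 256 | |||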
size_t cache_size = 256 * 1024; | |||
size_t k_group = N * K / cache_size; | |||
constexpr size_t k_align = 8; | |||
if (k_group >= 2) { | |||
size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align; | |||
for (size_t k = 0; k < K; k += k_per_group) { | |||
size_t real_k = std::min(K - k, k_per_group); | |||
gevm_naive_dot_impl( | |||
A + k, B + k * Bstride, C, M, N, real_k, Astride, Bstride, | |||
Cstride, k != 0); | |||
} | |||
} else { | |||
gevm_naive_dot_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); | |||
} | |||
} | |||
MIDOUT_END(); | |||
} | |||
void arm_common::gevm_naive_n32k4_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
megdnn_assert(M == 1); | |||
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot_nk4"_hash)) { | |||
gevm_naive_dot_n32k4_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false); | |||
} | |||
MIDOUT_END(); | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -22,6 +22,14 @@ void gemv_like_mk4( | |||
void gemv_like_mk4_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||
void gevm_naive_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||
void gevm_naive_n32k4_dot( | |||
const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||
#endif | |||
} // namespace arm_common | |||
@@ -14,6 +14,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | |||
#if MGB_ENABLE_DOT | |||
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | |||
AlgoInt8x8x32GevmDot int8x8x32_gevm_dot; | |||
AlgoInt8x8x32GevmN32K4Dot int8x8x32_gevm_n32k4_dot; | |||
#endif | |||
AlgoGevm gevm; | |||
@@ -28,6 +30,8 @@ public: | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||
m_all_algos.emplace_back(&int8x8x32_gevm_dot); | |||
m_all_algos.emplace_back(&int8x8x32_gevm_n32k4_dot); | |||
#endif | |||
m_all_algos.emplace_back(&int8x8x32_gemv); | |||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
@@ -31,7 +31,9 @@ protected: | |||
class AlgoF16Gemv; | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT | |||
class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT | |||
class AlgoInt8x8x32GevmDot; // Arm_common Int8x8x32 Gevm NCHW DOT | |||
class AlgoInt8x8x32GevmN32K4Dot;  // Arm_common Int8x8x32 Gevm N32K4 DOT | |||
#endif | |||
class AlgoInt8x8x16; // Arm_common Int 8x8x16 | |||
class AlgoPack; | |||
@@ -469,7 +469,7 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { | |||
#endif | |||
#if MEGDNN_ARMV7 | |||
__ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) { | |||
__ai int8x16_t vqtbl1q_s8(int8x16_t a, uint8x16_t idx) { | |||
int8x8_t src_low = vget_low_s8(a); | |||
int8x8_t src_high = vget_high_s8(a); | |||
return vcombine_s8( | |||
@@ -726,6 +726,13 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { | |||
asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :); | |||
return a; | |||
} | |||
#if __ARM_ARCH < 8 | |||
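//! AArch32 fallback: emulate round-to-nearest, ties away from zero, by | |||
//! adding +/-0.5 and truncating with vcvtq_s32_f32 | |||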
__ai int32x4_t vcvtaq_s32_f32(float32x4_t val) { | |||
float32x4_t vinc0 = vbslq_f32( | |||
vcgeq_f32(val, vdupq_n_f32(0.f)), vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f)); | |||
return vcvtq_s32_f32(vaddq_f32(val, vinc0)); | |||
} | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
#undef __ARM_FEATURE_DOTPROD | |||
#endif | |||
@@ -61,6 +61,15 @@ void MatrixMulForward::deduce_layout( | |||
"(transposed) B is (%zu,%zu)", | |||
A0, A1, B0, B1); | |||
C = TensorLayout(TensorShape({A0, B1}), C.dtype); | |||
} else if (param().format == param::MatrixMul::Format::N32K4_DOT) { | |||
A0 = A.shape[0]; | |||
A1 = A.shape[1]; | |||
B0 = B.shape[0]; | |||
B1 = B.shape[1]; | |||
megdnn_assert(!m_param.transposeA && !m_param.transposeB); | |||
megdnn_assert(A0 == 1 && A1 % 4 == 0); | |||
megdnn_assert(B.ndim == 4); | |||
C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype); | |||
} else { | |||
auto do_deduce = [&](size_t pack_size) { | |||
megdnn_assert( | |||
@@ -132,6 +141,18 @@ void MatrixMulForward::check_exec( | |||
megdnn_assert(A0 == C0, "%s", errmsg().c_str()); | |||
megdnn_assert(B1 == C1, "%s", errmsg().c_str()); | |||
megdnn_assert(A1 == B0, "%s", errmsg().c_str()); | |||
} else if (param().format == param::MatrixMul::Format::N32K4_DOT) { | |||
size_t A0 = A.shape[0]; | |||
size_t A1 = A.shape[1]; | |||
size_t B2 = B.shape[2]; | |||
size_t B3 = B.shape[3]; | |||
megdnn_assert(!m_param.transposeA && !m_param.transposeB); | |||
megdnn_assert(A0 == 1 && A1 % 4 == 0); | |||
megdnn_assert(B.ndim == 4); | |||
megdnn_assert(B2 == 32 && B3 == 4); | |||
megdnn_assert_contiguous(A); | |||
megdnn_assert_contiguous(B); | |||
megdnn_assert_contiguous(C); | |||
} else { | |||
megdnn_assert_eq_size_t(A.ndim, 4_z); | |||
megdnn_assert_eq_size_t(B.ndim, 3_z); | |||
@@ -442,9 +442,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||
megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format)); | |||
} | |||
BiasMode bias_mode; | |||
//! when dst has batch 1 and 1x1 spatial, a full-shape BIAS degenerates to BROADCAST_CHANNEL_BIAS | |||
bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1; | |||
if (bias.ndim == 0) { | |||
bias_mode = BiasMode::NO_BIAS; | |||
} else if (bias.eq_shape(dst)) { | |||
} else if (bias.eq_shape(dst) && !dst_only_c) { | |||
bias_mode = BiasMode::BIAS; | |||
} else { | |||
//! just check the ndim, the detail shape check is in check_exec | |||
@@ -258,6 +258,9 @@ public: | |||
ARM_COMMON_CHANWISE_STRD1_NCHW44_S8, | |||
ARM_COMMON_CHANWISE_STRD2_NCHW44_S8, | |||
ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8, | |||
//! LARGE for large filter | |||
ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8, | |||
ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8, | |||
ARM_COMMON_DIRECT_STRD1_DOT_S8, | |||
ARM_COMMON_DIRECT_STRD2_DOT_S8, | |||
ARM_COMMON_DIRECT_NCHW44_DOT_S8, | |||
@@ -195,11 +195,11 @@ MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( | |||
kern_size_param.trB = param().transposeB; | |||
kern_size_param.compute_mode = param().compute_mode; | |||
kern_size_param.format = param().format; | |||
size_t pack_size = MatrixMulForward::pack_size(param().format); | |||
kern_size_param.K *= pack_size; | |||
kern_size_param.M *= pack_size; | |||
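//! for N32K4_DOT, A keeps its plain (1, K) layout, so M and K already count | |||
//! elements and are not scaled by pack_size | |||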
if (param().format != Param::Format::N32K4_DOT) { | |||
size_t pack_size = MatrixMulForward::pack_size(param().format); | |||
kern_size_param.K *= pack_size; | |||
kern_size_param.M *= pack_size; | |||
} | |||
return kern_size_param; | |||
} | |||
@@ -122,6 +122,8 @@ public: | |||
ARM_COMMON_INT8X8X32_GEMV, | |||
ARM_COMMON_INT8X8X32_GEMV_MK4, | |||
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | |||
ARM_COMMON_INT8X8X32_GEVM_DOT, | |||
ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT, | |||
ARM_COMMON_F16_GEMV, | |||
ARM_COMMON_GEVM, | |||
#if MEGDNN_AARCH64 | |||
@@ -175,6 +177,7 @@ public: | |||
enum class AlgoSet : uint32_t { | |||
ALGO_TYPE_GEMM = 0, | |||
ALGO_TYPE_GEMV = 1, | |||
ALGO_TYPE_GEVM = 2, | |||
}; | |||
enum class PackMode : uint32_t { | |||
@@ -105,6 +105,34 @@ void run_matrix_mul_mk4_dot_tpl( | |||
template < | |||
typename itype, typename otype, bool transA, bool transB, | |||
typename comp_type = otype> | |||
void run_matrix_mul_n32k4_dot_tpl( | |||
const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K, | |||
size_t LDA, size_t, size_t, const DType& A_type, const DType& B_type) { | |||
Getter<itype, comp_type> getterA(A_type), getterB(B_type); | |||
megdnn_assert(!transA && !transB); | |||
for (size_t m = 0; m < M; ++m) { | |||
for (size_t n = 0; n < N; n += 32) { | |||
comp_type res[32] = {comp_type(0)}; | |||
for (size_t k = 0; k < K; k += 4) { | |||
for (size_t i = 0; i < 32; i++) { | |||
comp_type av, bv; | |||
for (size_t j = 0; j < 4; j++) { | |||
av = getterA(A[m * LDA + k + j]); | |||
bv = getterB(B[n * K + k * 32 + i * 4 + j]); | |||
res[i] += av * bv; | |||
} | |||
} | |||
} | |||
for (size_t i = 0; i < 32; i++) { | |||
C[n + i] = res[i]; | |||
} | |||
} | |||
} | |||
} | |||
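//! Illustrative self-check (hypothetical, not part of the dispatch path; | |||
//! assumes <cstdint> for int8_t/int32_t): for M == 1 the packed indexing | |||
//! above must agree with a plain row-major GEVM. | |||
static inline bool check_n32k4_reference_small() { | |||
constexpr size_t N = 32, K = 8; | |||
int8_t a[K], b[K * N], b_packed[K * N]; | |||
int32_t c_ref[N] = {0}, c_packed[N] = {0}; | |||
for (size_t i = 0; i < K; ++i) | |||
a[i] = static_cast<int8_t>(i * 7 + 1); | |||
for (size_t i = 0; i < K * N; ++i) | |||
b[i] = static_cast<int8_t>(i * 3 + 2); | |||
//! plain GEVM: C[n] = sum_k A[k] * B[k, n] | |||
for (size_t k = 0; k < K; ++k) | |||
for (size_t n = 0; n < N; ++n) | |||
c_ref[n] += a[k] * b[k * N + n]; | |||
//! pack to (N/32, K/4, 32(n), 4(k)), then evaluate with packed indexing | |||
for (size_t n = 0; n < N; n += 32) | |||
for (size_t k = 0; k < K; k += 4) | |||
for (size_t i = 0; i < 32; ++i) | |||
for (size_t j = 0; j < 4; ++j) | |||
b_packed[n * K + k * 32 + i * 4 + j] = b[(k + j) * N + n + i]; | |||
for (size_t n = 0; n < N; n += 32) | |||
for (size_t k = 0; k < K; k += 4) | |||
for (size_t i = 0; i < 32; ++i) | |||
for (size_t j = 0; j < 4; ++j) | |||
c_packed[n + i] += a[k + j] * b_packed[n * K + k * 32 + i * 4 + j]; | |||
for (size_t n = 0; n < N; ++n) | |||
if (c_ref[n] != c_packed[n]) | |||
return false; | |||
return true; | |||
} | |||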
template < | |||
typename itype, typename otype, bool transA, bool transB, | |||
typename comp_type = otype> | |||
void run_matrix_mul_mk8_tpl( | |||
const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K, | |||
size_t LDA, size_t LDB, size_t LDC, const DType& A_type, const DType& B_type) { | |||
@@ -251,6 +279,10 @@ void dispatch_ta_tb( | |||
return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ | |||
} else if (format == param::MatrixMul::Format::N32K4_DOT) { \ | |||
return run_matrix_mul_n32k4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \ | |||
} else if (format == param::MatrixMul::Format::MK8) { \ | |||
return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \ | |||
static_cast<const _itype*>(A), static_cast<const _itype*>(B), \ | |||
@@ -160,7 +160,7 @@ static void benchmark_convbias( | |||
.set_display(false); | |||
} | |||
auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*"; | |||
#if MGB_ENBALE_DOT | |||
#if MGB_ENABLE_DOT | |||
if (!is_fp32) { | |||
nchw44_algo_regx = ".*DOT.*"; | |||
} | |||
@@ -1626,7 +1626,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) { | |||
#endif | |||
#if MGB_ENBALE_DOT | |||
#if MGB_ENABLE_DOT | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) { | |||
// have to remove the "preferred" restriction in the usable function before running the benchmark | |||
@@ -2011,6 +2011,80 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||
} | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args; | |||
auto run = [&](size_t group, size_t w, size_t h, size_t kernel, size_t stride, | |||
NonlineMode nonline_mode) { | |||
size_t p = kernel / 2; | |||
if (w + 2 * p < kernel || h + 2 * p < kernel) | |||
return; | |||
param::ConvBias param; | |||
param.stride_h = stride; | |||
param.stride_w = stride; | |||
param.pad_h = p; | |||
param.pad_w = p; | |||
param.nonlineMode = nonline_mode; | |||
param.format = param::ConvBias::Format::NCHW; | |||
param.sparse = ConvBiasForward::Param::Sparse::GROUP; | |||
//! channel bias | |||
args.emplace_back( | |||
param, TensorShape{1, group, h, w}, | |||
TensorShape{group, 1, 1, kernel, kernel}, TensorShape{1, group, 1, 1}); | |||
}; | |||
run(64, 64, 64, 9, 1, NonlineMode::RELU); | |||
run(64, 40, 40, 9, 2, NonlineMode::RELU); | |||
run(64, 20, 20, 9, 1, NonlineMode::RELU); | |||
constexpr size_t RUN = 120; | |||
Benchmarker<ConvBias> benchmark0(handle()); | |||
benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
.set_dtype(4, dtype::QuantizedS8(60.25f)); | |||
benchmark0.set_display(false); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"ARMDOTS8_DIRECT_CHANWISE_LARGE")); | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
.set_dtype(4, dtype::QuantizedS8(60.25f)); | |||
benchmark1.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"ARMDOTS8_IM2COL_CHANWISE_LARGE")); | |||
benchmark1.set_display(false); | |||
benchmark1.set_times(RUN); | |||
for (auto&& arg : args) { | |||
TensorLayout dst_layout; | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
opr->param() = arg.param; | |||
opr->deduce_layout( | |||
{arg.src, dtype::Int8()}, {arg.filter, dtype::Int8()}, | |||
{arg.bias, dtype::Int32()}, {}, dst_layout); | |||
//! dst.nr_elems * FH * FW * 2 | |||
float computations = | |||
dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; | |||
auto used0 = benchmark0.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||
RUN; | |||
auto used1 = benchmark1.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||
RUN; | |||
printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " | |||
"speedup: %f\n", | |||
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, | |||
computations / used0, used1, computations / used1, used1 / used0); | |||
} | |||
} | |||
#endif | |||
#endif | |||
@@ -2194,7 +2268,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) { | |||
dtype::QuantizedS8 stype(2.5f); | |||
dtype::QuantizedS32 dtype(6.25f); | |||
#if MEGDNN_AARCH64 | |||
#if MGB_ENBALE_DOT | |||
#if MGB_ENABLE_DOT | |||
benchmark_conv1x1( | |||
"AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | |||
#else | |||
@@ -2212,7 +2286,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) { | |||
dtype::QuantizedS32 dtype(1.2 * 1.2); | |||
#if MEGDNN_AARCH64 | |||
#if MGB_ENBALE_DOT | |||
#if MGB_ENABLE_DOT | |||
benchmark_conv1x1( | |||
"AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype); | |||
#else | |||
@@ -136,6 +136,84 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args( | |||
return args; | |||
} | |||
std::vector<conv_bias::TestArg> get_channel_wise_args( | |||
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | |||
bool no_full_bias, bool support_relu) { | |||
using namespace conv_bias; | |||
using Param = param::ConvBias; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<TestArg> args; | |||
auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel, | |||
size_t stride, NLMode nlmode, bool pad) { | |||
Param param; | |||
param.stride_h = stride; | |||
param.stride_w = stride; | |||
if (pad) { | |||
param.pad_h = kernel / 2; | |||
param.pad_w = kernel / 2; | |||
} else { | |||
param.pad_h = 0; | |||
param.pad_w = 0; | |||
} | |||
param.nonlineMode = nlmode; | |||
param.format = param::ConvBias::Format::NCHW; | |||
param.sparse = param::ConvBias::Sparse::GROUP; | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w}, | |||
TensorShape{group, 1, 1, kernel, kernel}, TensorShape{}); | |||
if (!no_bias) { | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w}, | |||
TensorShape{group, 1, 1, kernel, kernel}, | |||
TensorShape{1, group, 1, 1}); | |||
} | |||
if (!no_full_bias) { | |||
args.emplace_back( | |||
param, TensorShape{n, group, h, w}, | |||
TensorShape{group, 1, 1, kernel, kernel}, | |||
TensorShape{ | |||
n, group, (h + 2 * param.pad_h - kernel) / stride + 1, | |||
(w + 2 * param.pad_w - kernel) / stride + 1}); | |||
} | |||
}; | |||
std::vector<NLMode> nonlinemode = {NLMode::IDENTITY}; | |||
if (!no_nonlinemode) { | |||
nonlinemode.emplace_back(NLMode::RELU); | |||
nonlinemode.emplace_back(NLMode::H_SWISH); | |||
} else if (support_relu) { | |||
nonlinemode.emplace_back(NLMode::RELU); | |||
} | |||
for (size_t n : {1, 2}) { | |||
for (auto nlmode : nonlinemode) { | |||
for (bool pad : {true}) { | |||
for (size_t group : {1, 3, 7}) { | |||
for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) { | |||
for (size_t kern : kernel) { | |||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||
} | |||
} | |||
} | |||
} | |||
for (bool pad : {false}) { | |||
for (size_t group : {7}) { | |||
for (size_t size : {37}) { | |||
for (size_t kern : kernel) { | |||
if (size < kern) | |||
continue; | |||
pack(n, group, size, size, kern, stride, nlmode, pad); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
return args; | |||
} | |||
std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args( | |||
std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode, | |||
bool no_full_bias) { | |||
@@ -226,7 +304,7 @@ void checker_conv_bias_qint8x8x8( | |||
.set_rng(1, &rng) | |||
.set_rng(2, &rng); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}}); | |||
checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
void checker_conv_bias_qint8x8x32( | |||
@@ -532,6 +610,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||
/****************************dot qint8 direct*************************/ | |||
#if MGB_ENABLE_DOT | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 1, false, true, true, true), handle(), | |||
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 2, false, true, true, true), handle(), | |||
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 1, false, true, true, true), handle(), | |||
"ARMDOTS8_IM2COL_CHANWISE_LARGE"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 2, false, true, true, true), handle(), | |||
"ARMDOTS8_IM2COL_CHANWISE_LARGE"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) { | |||
auto args = get_nchw44_conv_bias_args( | |||
{2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true); | |||
@@ -219,6 +219,113 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) { | |||
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) | |||
run(M, K, 1); | |||
} | |||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM_DOT) { | |||
Checker<MatrixMul> checker(handle()); | |||
using Param = MatrixMul::Param; | |||
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_DOT"); | |||
checker.set_before_exec_callback(algo_ck); | |||
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||
checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||
Param param; | |||
param.format = Param::Format::DEFAULT; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
TensorShape A, B; | |||
A = TensorShape{M, K}; | |||
B = TensorShape{K, N}; | |||
checker.set_param(param) | |||
.set_dtype(0, dtype::Int8()) | |||
.set_dtype(1, dtype::Int8()) | |||
.set_dtype(2, dtype::Int32()) | |||
.execs({A, B, {}}); | |||
}; | |||
run(1, 32, 4); | |||
for (int n = 7; n < 43; n += 3) { | |||
for (int k = 1; k < 33; k += 3) { | |||
run(1, n, k); | |||
} | |||
} | |||
} | |||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM_N32K4_DOT) { | |||
Checker<MatrixMul> checker(handle()); | |||
using Param = MatrixMul::Param; | |||
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); | |||
checker.set_before_exec_callback(algo_ck); | |||
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||
checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||
Param param; | |||
param.format = Param::Format::N32K4_DOT; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
TensorShape A, B; | |||
A = TensorShape{M, K}; | |||
B = TensorShape{N / 32, K / 4, 32, 4}; | |||
checker.set_param(param) | |||
.set_dtype(0, dtype::Int8()) | |||
.set_dtype(1, dtype::Int8()) | |||
.set_dtype(2, dtype::Int32()) | |||
.execs({A, B, {}}); | |||
}; | |||
run(1, 32, 4); | |||
for (int n = 32; n < 65; n += 32) { | |||
for (int k = 4; k < 39; k += 4) { | |||
run(1, n, k); | |||
} | |||
} | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
TEST_F(ARM_COMMON, BENCHMARK_QINT8x8x32_GEVM_N32K4_DOT) { | |||
using Param = MatrixMul::Param; | |||
auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"); | |||
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30); | |||
Param param; | |||
param.format = Param::Format::N32K4_DOT; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
constexpr size_t RUNS = 2000; | |||
Benchmarker<MatrixMul> benchmarker_int(handle()); | |||
benchmarker_int.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int32{}) | |||
.set_param(param) | |||
.set_before_exec_callback(algo_ck) | |||
.set_display(false); | |||
Benchmarker<MatrixMul> benchmarker_float(handle()); | |||
benchmarker_float.set_display(false).set_times(RUNS); | |||
auto bench = [&](size_t M, size_t N, size_t K) { | |||
auto int_used = | |||
benchmarker_int.exec({{M, K}, {N / 32, K / 4, 32, 4}, {}}) / RUNS; | |||
auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
float computations = 2.f * M * K * N * 1e-6; | |||
float throughput = (M * K + N * K + M * N) * 1e-6; | |||
printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms " | |||
"%f Gflops speedup: %f, throughput %f GB/s\n", | |||
M, K, N, float_used, computations / float_used, int_used, | |||
computations / int_used, float_used / int_used, throughput / int_used); | |||
}; | |||
bench(1, 256, 512); | |||
bench(1, 256, 1024); | |||
bench(1, 512, 512); | |||
bench(1, 512, 1024); | |||
} | |||
#endif | |||
#endif | |||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | |||