GitOrigin-RevId: a28a97fcb5
release-1.10
@@ -557,7 +557,10 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
            'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
        Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod:'
            'M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
            'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'),
        Doc('N32K4_DOT = 4', 'Split 32 from N and 4 from K, better for neon gevm dotprod:'
            'N/32, K/4, 32(n), 4(k)')
        )
    )
(pdef('SVD').
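
To make the new N32K4_DOT enum concrete: below is a minimal sketch of the packing it names, rearranging a row-major (K, N) int8 matrix into (N/32, K/4, 32(n), 4(k)) order. The helper name pack_n32k4 and the row-major source are assumptions for illustration, not part of this patch; it presumes K % 4 == 0 and N % 32 == 0 (the kernels pad to those multiples).

#include <cstddef>
#include <cstdint>

void pack_n32k4(const int8_t* src, int8_t* dst, size_t K, size_t N) {
    const size_t bn = 32, bk = 4;
    for (size_t n = 0; n < N; ++n)
        for (size_t k = 0; k < K; ++k)
            // block-of-32 columns, then block-of-4 rows, then n, then k
            dst[(n / bn) * (bn * K) + (k / bk) * (bn * bk) + (n % bn) * bk +
                (k % bk)] = src[k * N + n];
}
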
@@ -127,6 +127,37 @@ public:
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8)
};

class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase {
public:
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; }
    bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
            const override;
    size_t get_workspace(const NCBKernSizeParam&) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override;
    ConvAlgoTypePack get_algo_type() const override {
        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
    }
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8)
};

class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase {
public:
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; }
    bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
            const override;
    size_t get_workspace(const NCBKernSizeParam&) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override;
    ConvAlgoTypePack get_algo_type() const override {
        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
    }
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8)
};

class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
public:
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
@@ -0,0 +1,270 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include <arm_neon.h>
#include <cstring>
#include <functional>
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/common/unroll_macro.h"

#if MGB_ENABLE_DOT
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel)

using namespace megdnn;
using namespace arm_common;
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

namespace {
class DirectConvRunner {
public:
    DirectConvRunner(size_t flt_size, size_t stride) {
        if (flt_size == 9 && stride == 1) {
            m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16;
        } else {
            megdnn_assert(flt_size == 9 && stride == 2);
            m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16;
        }
    }
    size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const {
        auto&& fm = param.filter_meta;
        auto FW = fm.spatial[1];
        return round_up((size_t)FW, m_block_k);
    }
    size_t get_round_iw(const ConvBiasImpl::NCBKernSizeParam& param) const {
        auto&& fm = param.filter_meta;
        size_t SW = fm.stride[1];
        size_t OW = param.osz[1];
        size_t round_ow = round_up(OW, m_block_ow);
        size_t round_fw = get_round_fw(param);
        size_t pad_iw = round_ow * SW - SW + round_fw;
        return round_up(pad_iw, m_align_iw);
    }
    size_t get_round_ih(const ConvBiasImpl::NCBKernSizeParam& param) const {
        auto&& fm = param.filter_meta;
        size_t SH = fm.stride[0];
        size_t OH = param.osz[0];
        auto FH = fm.spatial[0];
        size_t round_oh = round_up(OH, m_block_oh);
        return round_oh * SH - SH + FH;
    }
    WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) const {
        auto&& fm = param.filter_meta;
        auto FH = fm.spatial[0];
        size_t round_filter = get_round_fw(param) * FH;
        size_t round_ih = get_round_ih(param);
        size_t round_iw = get_round_iw(param);
        size_t pad_src = round_iw * round_ih;
        return {nullptr, {pad_src, round_filter}};
    }
    WorkspaceBundle get_total_bundle(
            const ConvBiasImpl::NCBKernSizeParam& param) const {
        auto sub_bundle = get_sub_bundle(param);
        auto sub_bundle_size = sub_bundle.total_size_in_bytes();
        size_t nr_threads = param.nr_threads;
        SmallVector<size_t> sizes_in_bytes;
        for (size_t i = 0; i < nr_threads; ++i) {
            sizes_in_bytes.push_back(sub_bundle_size);
        }
        WorkspaceBundle total_bundle(nullptr, sizes_in_bytes);
        return total_bundle;
    }
    void run(
            const int8_t* pad_src_ptr, const int8_t* round_filter_ptr, int32_t bias,
            int8_t* dst_ptr, size_t OH, size_t OW, size_t pad_iw, float scale,
            int8_t relu_val) const {
        const size_t ow_end = OW / m_block_ow * m_block_ow;
        const size_t ow_remain = OW - ow_end;
        const size_t oh_end = OH / m_block_oh * m_block_oh;
        const size_t oh_remain = OH - oh_end;
        int8_t cache[4 * 16];
        for (size_t oh = 0; oh < oh_end; oh += m_block_oh) {
            for (size_t ow = 0; ow < ow_end; ow += m_block_ow) {
                m_func(pad_src_ptr, round_filter_ptr, bias, dst_ptr, oh, ow, OH, OW,
                       pad_iw, scale, relu_val);
            }
            if (ow_remain > 0) {
                m_func(pad_src_ptr, round_filter_ptr, bias,
                       &cache[0] - (oh * m_block_ow + ow_end), oh, ow_end, OH,
                       m_block_ow, pad_iw, scale, relu_val);
                for (size_t i = 0; i < m_block_oh; ++i) {
                    for (size_t j = 0; j < ow_remain; ++j) {
                        dst_ptr[(i + oh) * OW + (j + ow_end)] = cache[i * 16 + j];
                    }
                }
            }
        }
        if (oh_remain > 0) {
            for (size_t ow = 0; ow < ow_end; ow += m_block_ow) {
                m_func(pad_src_ptr, round_filter_ptr, bias,
                       &cache[0] - (oh_end * m_block_ow + ow), oh_end, ow, OH,
                       m_block_ow, pad_iw, scale, relu_val);
                for (size_t i = 0; i < oh_remain; ++i) {
                    for (size_t j = 0; j < m_block_ow; ++j) {
                        dst_ptr[(i + oh_end) * OW + (j + ow)] = cache[i * 16 + j];
                    }
                }
            }
            if (ow_remain > 0) {
                m_func(pad_src_ptr, round_filter_ptr, bias,
                       &cache[0] - (oh_end * m_block_ow + ow_end), oh_end, ow_end, OH,
                       m_block_ow, pad_iw, scale, relu_val);
                for (size_t i = 0; i < oh_remain; ++i) {
                    for (size_t j = 0; j < ow_remain; ++j) {
                        dst_ptr[(i + oh_end) * OW + (j + ow_end)] = cache[i * 16 + j];
                    }
                }
            }
        }
    }

private:
    std::function<void(
            const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst,
            size_t oh, size_t ow, size_t OH, size_t OW, size_t pad_iw,
            const float scale, int8_t relu_val)>
            m_func;
    size_t m_block_oh{4};
    size_t m_block_ow{16};
    size_t m_block_k{4};
    size_t m_align_iw{16};
};

void do_conv(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index, const DirectConvRunner& runner) {
    auto&& fm = kern_param.filter_meta;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t FH = fm.spatial[0];
    size_t FW = fm.spatial[1];
    float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
    float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
    float scale_dst_div = 1.f / scale_dst;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0));
    int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1));
    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
    void* dst = kern_param.dst<void>(batch_id, group_id);
    size_t pad_iw = runner.get_round_iw(kern_param);
    memset(pad_src_ptr, 0, bundle.get_size(0));
    rep(ih, IH) {
        std::memcpy(
                pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW,
                sizeof(int8_t) * IW);
    }
    memset(round_filter_ptr, 0, bundle.get_size(1));
    size_t round_fw = runner.get_round_fw(kern_param);
    for (size_t fh = 0; fh < FH; ++fh) {
        std::memcpy(round_filter_ptr + fh * round_fw, fptr + fh * FW, FW);
    }
    int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128;
    int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr;
    int8_t* dst_ptr = (int8_t*)dst;
    runner.run(
            pad_src_ptr, round_filter_ptr, bias_val, dst_ptr, OH, OW, pad_iw,
            scale_bias * scale_dst_div, relu_val);
}
} // namespace

bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    auto FW = fm.spatial[1];
    auto SH = fm.stride[0];
    auto SW = fm.stride[1];
    auto noline = param.nonlineMode;
    auto bias_mode = param.bias_mode;
    bool available =
            //! src, filter and dst are all qint8
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.dst_type.enumv() == DTypeEnum::QuantizedS8) &&
            fm.format == param::Convolution::Format::NCHW && !fm.should_flip &&
            (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) &&
            (bias_mode == BiasMode::NO_BIAS ||
             bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
            SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 &&
            fm.ocpg == 1;
    return available;
}

size_t ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel,
            midout_iv("AlgoDotS8DirectChanWiseLarge::get_workspace"_hash)) {
        auto&& fm = param.filter_meta;
        DirectConvRunner runner(fm.spatial[0], fm.stride[0]);
        auto total_bundle = runner.get_total_bundle(param);
        return total_bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::
        dispatch_kerns(const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8_direct_dot_large_kernel,
            midout_iv("AlgoDotS8DirectChanWiseLarge::dispatch_kerns"_hash)) {
        SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
        auto&& fm = param.filter_meta;
        DirectConvRunner runner(fm.spatial[0], fm.stride[0]);
        WorkspaceBundle wbundle = runner.get_sub_bundle(param);
        WorkspaceBundle total_bundle = runner.get_total_bundle(param);
        auto exec_one_group = [wbundle, total_bundle, runner](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            WorkspaceBundle temp_total_bundle = total_bundle;
            temp_total_bundle.set(kern_param.workspace_ptr);
            WorkspaceBundle temp_bundle = wbundle;
            temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id));
            do_conv(temp_bundle, kern_param, ncb_index, runner);
        };
        size_t N = param.n;
        size_t group = fm.group;
        ret_kerns.push_back({exec_one_group, {N, group}});
        return ret_kerns;
    }
    MIDOUT_END();
    return {};
}
#endif
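
Two pieces of arithmetic in DirectConvRunner are worth making concrete. First, the padded-width math of get_round_iw(); second, the edge-tile trick in run(), where passing dst = &cache[0] - (oh * m_block_ow + ow) together with OW = m_block_ow makes the kernel's internal dst + oh * OW + ow indexing land exactly at cache[0]. A standalone check with illustrative numbers, not part of the patch:

#include <cassert>
#include <cstddef>

static size_t round_up_to(size_t v, size_t align) {
    return (v + align - 1) / align * align;
}

int main() {
    // get_round_iw() for FW = 9, SW = 1, OW = 20:
    size_t round_ow = round_up_to(20, 16);                     // ow tiled by 16 -> 32
    size_t round_fw = round_up_to(9, 4);                       // filter row -> 12
    size_t pad_iw = round_up_to(round_ow - 1 + round_fw, 16);  // 43 -> 48
    assert(pad_iw == 48);

    // Edge tiles: the base offset and the kernel's own indexing cancel.
    for (size_t oh = 0; oh < 64; oh += 4) {
        for (size_t ow = 0; ow < 64; ow += 16) {
            ptrdiff_t base = -(ptrdiff_t)(oh * 16 + ow);    // base passed to m_func
            assert(base + (ptrdiff_t)(oh * 16 + ow) == 0);  // dst_store == cache[0]
        }
    }
    return 0;
}
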
@@ -0,0 +1,425 @@
/**
 * \file dnn/src/arm_common/conv_bias/int8/chanwise_im2col_dot.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include <cmath>
#include <cstring>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "midout.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel)

using namespace megdnn;
using namespace arm_common;
using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam;
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

namespace {
constexpr size_t block_n = 32;
constexpr size_t block_k = 4;

WorkspaceBundle get_sub_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    auto&& fm = param.filter_meta;
    auto OH = param.osz[0];
    auto OW = param.osz[1];
    size_t IH = param.isz[0];
    size_t IW = param.isz[1];
    auto FH = fm.spatial[0];
    auto FW = fm.spatial[1];
    size_t PH = param.filter_meta.padding[0];
    size_t PW = param.filter_meta.padding[1];
    size_t round_ohw = round_up((size_t)OH * OW, block_n);
    size_t round_filter = round_up((size_t)FW, block_k) * FH;
    size_t pad_src = (IW + PW * 2) * (IH + PH * 2);
    return {nullptr,
            {pad_src, round_filter, round_ohw * round_filter,
             round_ohw * sizeof(int32_t)}};
}

WorkspaceBundle get_total_bundle(const ConvBiasImpl::NCBKernSizeParam& param) {
    auto sub_bundle = get_sub_bundle(param);
    auto sub_bundle_size = sub_bundle.total_size_in_bytes();
    size_t nr_threads = param.nr_threads;
    SmallVector<size_t> sizes_in_bytes;
    for (size_t i = 0; i < nr_threads; ++i) {
        sizes_in_bytes.push_back(sub_bundle_size);
    }
    WorkspaceBundle total_bundle(nullptr, sizes_in_bytes);
    return total_bundle;
}

template <size_t flt_size, size_t stride>
void im2col(
        const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
        size_t round_filter) {
    constexpr size_t FH = flt_size;
    constexpr size_t FW = flt_size;
    constexpr size_t SH = stride;
    constexpr size_t SW = stride;
    constexpr size_t FW_ROUND = (FW + 3) / 4 * 4;
    int bn = 0;
    int ni = 0;
    for (size_t oh = 0; oh < OH; ++oh)
        for (size_t ow = 0; ow < OW; ++ow) {
            const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
            int bk = 0;
            int ki = 0;
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW_ROUND; ++fw) {
                    dst[bn * block_n * round_filter + bk * block_n * block_k +
                        ni * block_k + ki] = src_n[fh * pad_iw + fw];
                    ++ki;
                    if (ki == block_k) {
                        ki = 0;
                        ++bk;
                    }
                }
            ++ni;
            if (ni == block_n) {
                ni = 0;
                ++bn;
            }
        }
}

template <>
void im2col<9, 1>(
        const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
        size_t round_filter) {
    constexpr size_t FH = 9;
    constexpr size_t SH = 1;
    constexpr size_t SW = 1;
    constexpr size_t k_block_stride = block_k * block_n;
    constexpr size_t ow_block = 16;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                            2, 3, 4, 5, 3, 4, 5, 6};
    static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
                                            6, 7, 8, 9, 7, 8, 9, 10};
    static const uint8_t tbl_array_2[16] = {8,  9,  10, 11, 9,  10, 11, 12,
                                            10, 11, 12, 13, 11, 12, 13, 14};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);
    int bn = 0;
    int ni = 0;
    for (size_t oh = 0; oh < OH; ++oh)
        for (size_t ow = 0; ow < OW;) {
            if (ow + ow_block <= OW && ni + ow_block <= block_n) {
                const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
                int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
                for (size_t fh = 0; fh < FH; ++fh) {
                    int8x16_t read_w[2];
                    read_w[0] = vld1q_s8(src_n);
                    read_w[1] = vld1q_s8(src_n + 16);
                    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
                    int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
                    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
                    int8x16_t ncdef_0 =
                            vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);
                    int8x16_t n0123_1 = n4567_0;
                    int8x16_t n4567_1 = n89ab_0;
                    int8x16_t n89ab_1 = ncdef_0;
                    int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);
                    int8x16_t n0123_2 = n89ab_0;
                    int8x16_t n4567_2 = ncdef_0;
                    int8x16_t n89ab_2 = ncdef_1;
                    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
                    vst1q_s8(dst_n + 0 * 16, n0123_0);
                    vst1q_s8(dst_n + 1 * 16, n4567_0);
                    vst1q_s8(dst_n + 2 * 16, n89ab_0);
                    vst1q_s8(dst_n + 3 * 16, ncdef_0);
                    vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1);
                    vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2);
                    dst_n += 3 * k_block_stride;
                    src_n += pad_iw;
                }
                ni += ow_block;
                ow += ow_block;
                if (ni == block_n) {
                    ni = 0;
                    ++bn;
                }
            } else {
                const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
                int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
                for (size_t fh = 0; fh < FH; ++fh) {
                    int8x16_t read_w = vld1q_s8(src_n);
                    int32x4_t read_w_s32 = vreinterpretq_s32_s8(read_w);
                    vst1q_lane_s32((int32_t*)dst_n, read_w_s32, 0);
                    vst1q_lane_s32(
                            (int32_t*)(dst_n + 1 * k_block_stride), read_w_s32, 1);
                    vst1q_lane_s32(
                            (int32_t*)(dst_n + 2 * k_block_stride), read_w_s32, 2);
                    dst_n += 3 * k_block_stride;
                    src_n += pad_iw;
                }
                ++ni;
                ++ow;
                if (ni == block_n) {
                    ni = 0;
                    ++bn;
                }
            }
        }
}

template <>
void im2col<9, 2>(
        const int8_t* src, int8_t* dst, size_t OH, size_t OW, size_t pad_iw,
        size_t round_filter) {
    constexpr size_t FH = 9;
    constexpr size_t SH = 2;
    constexpr size_t SW = 2;
    constexpr size_t k_block_stride = block_k * block_n;
    constexpr size_t ow_block = 16;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
                                            4, 5, 6, 7, 6, 7, 8, 9};
    static const uint8_t tbl_array_1[16] = {4, 5, 6,  7,  6,  7,  8,  9,
                                            8, 9, 10, 11, 10, 11, 12, 13};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    int bn = 0;
    int ni = 0;
    for (size_t oh = 0; oh < OH; ++oh)
        for (size_t ow = 0; ow < OW;) {
            if (ow + ow_block <= OW && ni + ow_block <= block_n) {
                const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
                int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
                for (size_t fh = 0; fh < FH; ++fh) {
                    int8x16_t read_w[3];
                    read_w[0] = vld1q_s8(src_n);
                    read_w[1] = vld1q_s8(src_n + 16);
                    read_w[2] = vld1q_s8(src_n + 32);
                    int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
                    int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);
                    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
                    int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
                    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
                    int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);
                    int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
                    int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
                    int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
                    int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);
                    int8x16_t n0123_2 = n4567_0;
                    int8x16_t n4567_2 = n89ab_0;
                    int8x16_t n89ab_2 = ncdef_0;
                    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
                    vst1q_s8(dst_n + 0 * 16, n0123_0);
                    vst1q_s8(dst_n + 1 * 16, n4567_0);
                    vst1q_s8(dst_n + 2 * 16, n89ab_0);
                    vst1q_s8(dst_n + 3 * 16, ncdef_0);
                    vst1q_s8(dst_n + 1 * k_block_stride + 0 * 16, n0123_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 1 * 16, n4567_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 2 * 16, n89ab_1);
                    vst1q_s8(dst_n + 1 * k_block_stride + 3 * 16, ncdef_1);
                    vst1q_s8(dst_n + 2 * k_block_stride + 0 * 16, n0123_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 1 * 16, n4567_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 2 * 16, n89ab_2);
                    vst1q_s8(dst_n + 2 * k_block_stride + 3 * 16, ncdef_2);
                    dst_n += 3 * k_block_stride;
                    src_n += pad_iw;
                }
                ni += ow_block;
                ow += ow_block;
                if (ni == block_n) {
                    ni = 0;
                    ++bn;
                }
            } else {
                const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
                int8_t* dst_n = dst + bn * block_n * round_filter + ni * block_k;
                for (size_t fh = 0; fh < FH; ++fh) {
                    int8x16_t read_w = vld1q_s8(src_n);
                    int32x4_t read_w_s32 = vreinterpretq_s32_s8(read_w);
                    vst1q_lane_s32((int32_t*)dst_n, read_w_s32, 0);
                    vst1q_lane_s32(
                            (int32_t*)(dst_n + 1 * k_block_stride), read_w_s32, 1);
                    vst1q_lane_s32(
                            (int32_t*)(dst_n + 2 * k_block_stride), read_w_s32, 2);
                    dst_n += 3 * k_block_stride;
                    src_n += pad_iw;
                }
                ++ni;
                ++ow;
                if (ni == block_n) {
                    ni = 0;
                    ++bn;
                }
            }
        }
}

void do_conv(
        const WorkspaceBundle& bundle, const NCBKernParam& kern_param,
        const NCBKernIndex& ncb_index) {
    auto&& fm = kern_param.filter_meta;
    size_t PH = kern_param.filter_meta.padding[0];
    size_t PW = kern_param.filter_meta.padding[1];
    size_t OH = kern_param.osz[0];
    size_t OW = kern_param.osz[1];
    size_t IH = kern_param.isz[0];
    size_t IW = kern_param.isz[1];
    size_t FH = fm.spatial[0];
    size_t FW = fm.spatial[1];
    size_t SH = fm.stride[0];
    float scale_bias = kern_param.bias_type.param<dtype::QuantizedS32>().scale;
    float scale_dst = kern_param.dst_type.param<dtype::QuantizedS8>().scale;
    float scale_dst_div = 1.f / scale_dst;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    int8_t* pad_src_ptr = static_cast<int8_t*>(bundle.get(0));
    int8_t* round_filter_ptr = static_cast<int8_t*>(bundle.get(1));
    int8_t* im2col_ptr = static_cast<int8_t*>(bundle.get(2));
    int32_t* i32_ptr = static_cast<int32_t*>(bundle.get(3));
    const int8_t* sptr = kern_param.src<dt_int8>(batch_id, group_id);
    const int32_t* bptr = kern_param.bias<dt_int32>(batch_id, group_id);
    const int8_t* fptr = kern_param.filter<dt_int8>(group_id);
    void* dst = kern_param.dst<void>(batch_id, group_id);
    size_t round_filter = round_up(FW, block_k) * FH;
    size_t pad_iw = IW + 2 * PW;
    memset(pad_src_ptr, 0, bundle.get_size(0));
    rep(ih, IH) {
        std::memcpy(
                pad_src_ptr + (ih + PH) * pad_iw + PW, sptr + ih * IW,
                sizeof(int8_t) * IW);
    }
    memset(round_filter_ptr, 0, bundle.get_size(1));
    size_t fh_stride = round_up(FW, block_k);
    for (size_t fh = 0; fh < FH; ++fh) {
        std::memcpy(round_filter_ptr + fh * fh_stride, fptr + fh * FW, FW);
    }
    memset(im2col_ptr, 0, bundle.get_size(2));
    if (SH == 1) {
        im2col<9, 1>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter);
    } else {
        im2col<9, 2>(pad_src_ptr, im2col_ptr, OH, OW, pad_iw, round_filter);
    }
    gevm_naive_n32k4_dot(
            round_filter_ptr, im2col_ptr, i32_ptr, 1, OH * OW, round_filter, 0, 0, 0);
    int32_t bias_val = kern_param.bias_mode == BiasMode::NO_BIAS ? 0 : *bptr;
    int8_t relu_val = kern_param.nonlineMode == NonlineMode::RELU ? 0 : -128;
    int8_t* dst_ptr = (int8_t*)dst;
    for (size_t i = 0; i < OH * OW; ++i) {
        //! optimize by tbl
        int val = roundf(scale_bias * scale_dst_div * (i32_ptr[i] + bias_val));
        val = val < -128 ? -128 : val;
        val = val > 127 ? 127 : val;
        val = val > relu_val ? val : relu_val;
        dst_ptr[i] = val;
    }
}
} // namespace

bool ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::usable(
        const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    auto&& fm = param.filter_meta;
    auto FH = fm.spatial[0];
    auto FW = fm.spatial[1];
    auto SH = fm.stride[0];
    auto SW = fm.stride[1];
    auto noline = param.nonlineMode;
    auto bias_mode = param.bias_mode;
    bool available =
            //! src, filter and dst are all qint8
            (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
             param.dst_type.enumv() == DTypeEnum::QuantizedS8) &&
            fm.format == param::Convolution::Format::NCHW && !fm.should_flip &&
            (noline == NonlineMode::IDENTITY || noline == NonlineMode::RELU) &&
            (bias_mode == BiasMode::NO_BIAS ||
             bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
            fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
            SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 &&
            fm.ocpg == 1;
    return available;
}

size_t ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::get_workspace(
        const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel,
            midout_iv("AlgoDotS8Im2colChanWiseLarge::get_workspace"_hash)) {
        auto bundle = get_total_bundle(param);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}

SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge::
        dispatch_kerns(const NCBKernSizeParam& param) const {
    MIDOUT_BEGIN(
            megdnn_arm_common_conv_bias_int8_im2col_dot_large_kernel,
            midout_iv("AlgoDotS8Im2colChanWiseLarge::dispatch_kerns"_hash)) {
        SmallVector<ConvBiasImpl::NCBKern> ret_kerns;
        auto fm = param.filter_meta;
        size_t N = param.n;
        size_t group = fm.group;
        WorkspaceBundle wbundle = get_sub_bundle(param);
        WorkspaceBundle total_bundle = get_total_bundle(param);
        auto exec_one_group = [wbundle, total_bundle](
                                      const NCBKernParam& kern_param,
                                      const NCBKernIndex& ncb_index) mutable {
            WorkspaceBundle temp_total_bundle = total_bundle;
            temp_total_bundle.set(kern_param.workspace_ptr);
            WorkspaceBundle temp_bundle = wbundle;
            temp_bundle.set(temp_total_bundle.get(ncb_index.thread_id));
            do_conv(temp_bundle, kern_param, ncb_index);
        };
        ret_kerns.push_back({exec_one_group, {N, group}});
        return ret_kerns;
    }
    MIDOUT_END();
    return {};
}
#endif
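
The vqtbl1q_s8 lookup tables in both im2col specializations (and in the direct kernels below) encode four overlapping k-windows per 16-byte shuffle. A scalar model of what one table lookup gathers, with the table contents copied from the code above:

#include <cassert>
#include <cstdint>

int main() {
    const uint8_t s1_tbl[16] = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6};
    const uint8_t s2_tbl[16] = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9};
    for (int n = 0; n < 4; ++n)
        for (int k = 0; k < 4; ++k) {
            // Lane n*4+k of the shuffled vector is source byte n*SW + k:
            // four consecutive k taps for each of four output columns.
            assert(s1_tbl[n * 4 + k] == n * 1 + k);  // stride 1
            assert(s2_tbl[n * 4 + k] == n * 2 + k);  // stride 2
        }
    return 0;
}
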
@@ -0,0 +1,27 @@
/**
 * \file
 * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include <cstddef>
#include <cstdint>
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val);

void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val);
#endif
@@ -0,0 +1,40 @@
/**
 * \file
 * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"

static inline void quant_store_s8(
        float32x4_t v0, float32x4_t v1, float32x4_t v2, float32x4_t v3, int8_t* ptr,
        int8x16_t relu_reg) {
    int32x4_t i0 = vcvtaq_s32_f32(v0);
    int32x4_t i1 = vcvtaq_s32_f32(v1);
    int32x4_t i2 = vcvtaq_s32_f32(v2);
    int32x4_t i3 = vcvtaq_s32_f32(v3);
    int16x4_t i16_0 = vqmovn_s32(i0);
    int16x4_t i16_1 = vqmovn_s32(i1);
    int16x4_t i16_2 = vqmovn_s32(i2);
    int16x4_t i16_3 = vqmovn_s32(i3);
    int8x8_t i8_0 = vqmovn_s16(vcombine_s16(i16_0, i16_1));
    int8x8_t i8_1 = vqmovn_s16(vcombine_s16(i16_2, i16_3));
    int8x16_t rst = vcombine_s8(i8_0, i8_1);
    rst = vmaxq_s8(rst, relu_reg);
    vst1q_s8(ptr, rst);
}
#endif
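
This narrowing chain is the vector form of the scalar epilogue in chanwise_im2col_dot.cpp: vcvtaq_s32_f32 rounds to nearest with ties away from zero, the two vqmovn steps saturate into int8 range, and vmaxq_s8 applies the ReLU floor (relu_val = 0 for RELU; -128 is a no-op floor for IDENTITY). A one-lane scalar sketch of the same computation, for reference only:

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t requantize_lane(int32_t acc, float scale, int8_t relu_val) {
    long v = std::lround(acc * scale);            // vcvtaq rounding mode
    v = std::min(std::max(v, -128L), 127L);       // vqmovn saturation to int8
    return (int8_t)std::max(v, (long)relu_val);   // vmaxq_s8 against relu_reg
}
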
@@ -0,0 +1,221 @@
/**
 * \file
 * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s1.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"

MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16 output tile
    const size_t SH = 1;
    const size_t SW = 1;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                            2, 3, 4, 5, 3, 4, 5, 6};
    static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
                                            6, 7, 8, 9, 7, 8, 9, 10};
    static const uint8_t tbl_array_2[16] = {8,  9,  10, 11, 9,  10, 11, 12,
                                            10, 11, 12, 13, 11, 12, 13, 14};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);
    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
    //! init accumulators with the bias
    int32x4_t c[4][4];
#define cb(step)                    \
    c[step][0] = vdupq_n_s32(bias); \
    c[step][1] = vdupq_n_s32(bias); \
    c[step][2] = vdupq_n_s32(bias); \
    c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8x16_t flt[4];
    flt[0] = vld1q_s8(weight + 0 * 16);
    flt[1] = vld1q_s8(weight + 1 * 16);
    flt[2] = vld1q_s8(weight + 2 * 16);
    flt[3] = vld1q_s8(weight + 3 * 16);
    //! row 0
    int8x16_t read_w[2];
    read_w[0] = vld1q_s8(src_n + 0 * pad_iw);
    read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16);
    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
    int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);
    int8x16_t n0123_1 = n4567_0;
    int8x16_t n4567_1 = n89ab_0;
    int8x16_t n89ab_1 = ncdef_0;
    int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);
    int8x16_t n0123_2 = n89ab_0;
    int8x16_t n4567_2 = ncdef_0;
    int8x16_t n89ab_2 = ncdef_1;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
#define CAL_C(oh, flt_start)                                                          \
    c[oh][0] = vdotq_laneq_s32(                                                       \
            c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4);    \
    c[oh][1] = vdotq_laneq_s32(                                                       \
            c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4);    \
    c[oh][2] = vdotq_laneq_s32(                                                       \
            c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4);    \
    c[oh][3] = vdotq_laneq_s32(                                                       \
            c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % 4], (flt_start + 0) % 4);    \
    c[oh][0] = vdotq_laneq_s32(                                                       \
            c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4);    \
    c[oh][1] = vdotq_laneq_s32(                                                       \
            c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4);    \
    c[oh][2] = vdotq_laneq_s32(                                                       \
            c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4);    \
    c[oh][3] = vdotq_laneq_s32(                                                       \
            c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % 4], (flt_start + 1) % 4);    \
    c[oh][0] = vdotq_laneq_s32(                                                       \
            c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4);    \
    c[oh][1] = vdotq_laneq_s32(                                                       \
            c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4);    \
    c[oh][2] = vdotq_laneq_s32(                                                       \
            c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4);    \
    c[oh][3] = vdotq_laneq_s32(                                                       \
            c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % 4], (flt_start + 2) % 4);
    CAL_C(0, 0);
    //! row 1
#define LOAD_SRC(row_id)                                                  \
    read_w[0] = vld1q_s8(src_n + row_id * pad_iw);                        \
    read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16);                   \
    n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);                           \
    n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);                           \
    n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);                           \
    ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);  \
    n0123_1 = n4567_0;                                                    \
    n4567_1 = n89ab_0;                                                    \
    n89ab_1 = ncdef_0;                                                    \
    ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);                           \
    n0123_2 = n89ab_0;                                                    \
    n4567_2 = ncdef_0;                                                    \
    n89ab_2 = ncdef_1;                                                    \
    ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    LOAD_SRC(1);
    CAL_C(0, 3);
    CAL_C(1, 0);
    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 1);
    CAL_C(2, 3 * 0);
    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 1);
    CAL_C(3, 3 * 0);
    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 1);
    //! update flt 4 -> 0
    flt[0] = vld1q_s8(weight + 4 * 16);
    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 2);
    //! update flt 5 -> 1
    flt[1] = vld1q_s8(weight + 5 * 16);
    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 3);
    //! update flt 6 -> 2
    flt[2] = vld1q_s8(weight + 6 * 16);
    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 4);
    //! row 8
    LOAD_SRC(8);
    CAL_C(0, 3 * 8);
    CAL_C(1, 3 * 7);
    CAL_C(2, 3 * 6);
    CAL_C(3, 3 * 5);
    //! row 9
    LOAD_SRC(9);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 6);
    //! row 10
    LOAD_SRC(10);
    CAL_C(2, 3 * 8);
    CAL_C(3, 3 * 7);
    //! row 11
    LOAD_SRC(11);
    CAL_C(3, 3 * 8);
    float32x4_t dst_reg[4][4];
#define cb(step)                                 \
    dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
    dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
    dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
    dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step)                                          \
    dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
    dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
    dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
    dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step)                                                                   \
    quant_store_s8(                                                                \
            dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
            dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif
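
The flt[(...) / 4 % 4] arithmetic in CAL_C treats the four filter registers as a ring over the 27 four-tap blocks of the padded 9x12 filter (9 rows, 3 blocks per row); the reloads between rows 4 and 6 refill slots 0-2 just before blocks 16 and up are consumed. A standalone check of that bookkeeping, illustrative only:

#include <cassert>

int main() {
    for (int block = 0; block < 27; ++block) {
        int load = block / 4;        // which 16-byte vld1q_s8 provides this block
        int slot = (block / 4) % 4;  // flt[] register selected by CAL_C
        assert(load <= 6);           // 4 initial loads + 3 reloads cover all blocks
        assert(slot == load % 4);    // each reload lands in the slot CAL_C reads
    }
    return 0;
}
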
@@ -0,0 +1,250 @@
/**
 * \file
 * dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_9x9s2.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"

MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16 output tile
    const size_t SH = 2;
    const size_t SW = 2;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
                                            4, 5, 6, 7, 6, 7, 8, 9};
    static const uint8_t tbl_array_1[16] = {4, 5, 6,  7,  6,  7,  8,  9,
                                            8, 9, 10, 11, 10, 11, 12, 13};
    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;
    //! init accumulators with the bias
    int32x4_t c[4][4];
#define cb(step)                    \
    c[step][0] = vdupq_n_s32(bias); \
    c[step][1] = vdupq_n_s32(bias); \
    c[step][2] = vdupq_n_s32(bias); \
    c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    constexpr int flt_reg = 7;
    constexpr int flt_per_reg = 4;
    int8x16_t flt[7];
    flt[0] = vld1q_s8(weight + 0 * 16);
    flt[1] = vld1q_s8(weight + 1 * 16);
    flt[2] = vld1q_s8(weight + 2 * 16);
    flt[3] = vld1q_s8(weight + 3 * 16);
    flt[4] = vld1q_s8(weight + 4 * 16);
    flt[5] = vld1q_s8(weight + 5 * 16);
    flt[6] = vld1q_s8(weight + 6 * 16);
#define CAL_C(oh, flt_start)                                                  \
    c[oh][0] = vdotq_laneq_s32(                                               \
            c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg],  \
            (flt_start + 0) % flt_per_reg);                                   \
    c[oh][1] = vdotq_laneq_s32(                                               \
            c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg],  \
            (flt_start + 0) % flt_per_reg);                                   \
    c[oh][2] = vdotq_laneq_s32(                                               \
            c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg],  \
            (flt_start + 0) % flt_per_reg);                                   \
    c[oh][3] = vdotq_laneq_s32(                                               \
            c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg],  \
            (flt_start + 0) % flt_per_reg);                                   \
    c[oh][0] = vdotq_laneq_s32(                                               \
            c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg],  \
            (flt_start + 1) % flt_per_reg);                                   \
    c[oh][1] = vdotq_laneq_s32(                                               \
            c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg],  \
            (flt_start + 1) % flt_per_reg);                                   \
    c[oh][2] = vdotq_laneq_s32(                                               \
            c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg],  \
            (flt_start + 1) % flt_per_reg);                                   \
    c[oh][3] = vdotq_laneq_s32(                                               \
            c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg],  \
            (flt_start + 1) % flt_per_reg);                                   \
    c[oh][0] = vdotq_laneq_s32(                                               \
            c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg],  \
            (flt_start + 2) % flt_per_reg);                                   \
    c[oh][1] = vdotq_laneq_s32(                                               \
            c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg],  \
            (flt_start + 2) % flt_per_reg);                                   \
    c[oh][2] = vdotq_laneq_s32(                                               \
            c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg],  \
            (flt_start + 2) % flt_per_reg);                                   \
    c[oh][3] = vdotq_laneq_s32(                                               \
            c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg],  \
            (flt_start + 2) % flt_per_reg);
#define LOAD_SRC(row_id)                                 \
    read_w[0] = vld1q_s8(src_n + row_id * pad_iw);       \
    read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16);  \
    read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32);  \
    ext_8 = vextq_s8(read_w[0], read_w[1], 8);           \
    ext_24 = vextq_s8(read_w[1], read_w[2], 8);          \
    n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);          \
    n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);              \
    n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);          \
    ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);             \
    n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);          \
    n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);              \
    n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);          \
    ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);             \
    n0123_2 = n4567_0;                                   \
    n4567_2 = n89ab_0;                                   \
    n89ab_2 = ncdef_0;                                   \
    ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
    //! row 0
    int8x16_t read_w[3];
    read_w[0] = vld1q_s8(src_n);
    read_w[1] = vld1q_s8(src_n + 16);
    read_w[2] = vld1q_s8(src_n + 32);
    int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
    int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);
    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
    int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);
    int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
    int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);
    int8x16_t n0123_2 = n4567_0;
    int8x16_t n4567_2 = n89ab_0;
    int8x16_t n89ab_2 = ncdef_0;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
    CAL_C(0, 0);
    //! row 1
    LOAD_SRC(1);
    CAL_C(0, 3 * 1);
    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 0);
    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 1);
    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 0);
    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 1);
    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 0);
    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 1);
    //! row 8
    LOAD_SRC(8);
    CAL_C(0, 3 * 8);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 2);
    //! row 9
    LOAD_SRC(9);
    CAL_C(1, 3 * 7);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 3);
    //! row 10
    LOAD_SRC(10);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 6);
    CAL_C(3, 3 * 4);
    //! row 11
    LOAD_SRC(11);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 5);
    //! row 12
    LOAD_SRC(12);
    CAL_C(2, 3 * 8);
    CAL_C(3, 3 * 6);
    //! row 13
    LOAD_SRC(13);
    CAL_C(3, 3 * 7);
    //! row 14
    LOAD_SRC(14);
    CAL_C(3, 3 * 8);
    float32x4_t dst_reg[4][4];
#define cb(step)                                 \
    dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
    dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
    dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
    dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step)                                          \
    dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
    dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
    dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
    dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step)                                                                   \
    quant_store_s8(                                                                \
            dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
            dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif
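
Why the stride-2 kernel loads input rows 0 through 14: output row r of the 4-row tile reads input rows 2r .. 2r+8, so the whole tile touches rows 0..14, matching the LOAD_SRC(1) .. LOAD_SRC(14) sequence above. A tiny standalone check, illustrative only:

#include <algorithm>
#include <cassert>

int main() {
    const int SH = 2, FH = 9;
    int last_row = 0;
    for (int r = 0; r < 4; ++r)          // four output rows per tile
        for (int fh = 0; fh < FH; ++fh)  // nine vertical filter taps
            last_row = std::max(last_row, r * SH + fh);
    assert(last_row == 14);  // rows 0..14, i.e. 15 input rows per tile
    return 0;
}
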
@@ -54,6 +54,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
    AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
    AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
    AlgoDotS8Im2colChanWiseLarge ds8_im2col_large_chanwise;
    AlgoDotS8DirectChanWiseLarge ds8_direct_large_chanwise;
#endif
    AlgoI8x8x16Direct i8x8x16_direct;

@@ -75,6 +77,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
    AlgoPack() {
#if MGB_ENABLE_DOT
        m_direct_algos.emplace_back(&ds8_direct_large_chanwise);
        m_direct_algos.emplace_back(&ds8_im2col_large_chanwise);
        m_direct_algos.emplace_back(&ds8_direct_stride1);
        m_direct_algos.emplace_back(&ds8_direct_stride2);
        m_direct_algos.emplace_back(&du8_direct_stride1);

@@ -51,6 +51,8 @@ private:
#endif
#if MGB_ENABLE_DOT
    class AlgoDotS8DirectNCHWNCHW44;
    class AlgoDotS8DirectChanWiseLarge;
    class AlgoDotS8Im2colChanWiseLarge;
    class AlgoDotS8DirectStride1;
    class AlgoDotS8DirectStride2;
    class AlgoDotU8DirectStride1;
@@ -143,8 +143,89 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gemv_mk4_kern;
}

#if MGB_ENABLE_DOT
namespace {
void int8x8x32_gevm_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gevm_dot_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        gevm_naive_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x32GevmDot::usable(
        const KernSizeParam& kern_size_param) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    auto M = kern_size_param.M;
    bool is_dtype_ok =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT && is_dtype_ok &&
           !kern_size_param.trA && !kern_size_param.trB && M == 1;
}

bool MatrixMulImpl::AlgoInt8x8x32GevmDot::preferred(const KernSizeParam&) const {
    return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmDot::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gevm_dot_kern;
}

namespace {
void int8x8x32_gevm_n32k4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(
            megdnn_arm_exec_int8832,
            midout_iv("int8x8x32_gevm_n32k4_dot_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        gevm_naive_n32k4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::usable(
        const KernSizeParam& kern_size_param) const {
    if (!cpuinfo_has_arm_neon_dot()) {
        return false;
    }
    auto M = kern_size_param.M;
    bool is_dtype_ok =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
            (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
             kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::N32K4_DOT &&
           is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && M == 1;
}

bool MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::preferred(const KernSizeParam&) const {
    return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gevm_n32k4_dot_kern;
}
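
Both GEVM algos accept only M == 1, i.e. a row vector times a matrix. For reference, the result they must reproduce, with B in plain row-major as the DEFAULT-format algo reads it (the N32K4_DOT variant reads B in the packed layout instead); gevm_ref is an illustrative name, not a library function:

#include <cstddef>
#include <cstdint>

void gevm_ref(const int8_t* A, const int8_t* B, int32_t* C, size_t N, size_t K,
              size_t LDB) {
    for (size_t n = 0; n < N; ++n) {
        int32_t acc = 0;  // int8 x int8 products accumulated in int32
        for (size_t k = 0; k < K; ++k)
            acc += (int32_t)A[k] * (int32_t)B[k * LDB + n];
        C[n] = acc;
    }
}
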
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -49,8 +49,69 @@ public:
    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
};

#if MGB_ENABLE_DOT
class MatrixMulImpl::AlgoInt8x8x32GevmDot : public AlgoBase {
public:
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_DOT"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, DEFAULT)
    WorkspaceBundle get_bundle(const KernSizeParam&) const override {
        return WorkspaceBundle{nullptr, {}};
    }
    kern_naked_t get_kern_naked(const KernSizeParam&) const override {
        megdnn_assert(0, "naked kern no impl");
    }
    void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride)
            const override {
        megdnn_assert(0, "pack_A no impl");
    }
    void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax)
            const override {
        megdnn_assert(0, "pack_B no impl");
    }
    InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; }
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_DOT)
};

class MatrixMulImpl::AlgoInt8x8x32GevmN32K4Dot : public AlgoBase {
public:
    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
    }
    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEVM; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
    MEGDNN_OVERRIDE_MATMUL_DESC(1, 32, 4, 2, AlgoDataType::QINT8X8X32, N32K4_DOT)
    WorkspaceBundle get_bundle(const KernSizeParam&) const override {
        return WorkspaceBundle{nullptr, {}};
    }
    kern_naked_t get_kern_naked(const KernSizeParam&) const override {
        megdnn_assert(0, "naked kern no impl");
    }
    void pack_A(const KernParam& kern_param, void* out, size_t index, size_t stride)
            const override {
        megdnn_assert(0, "pack_A no impl");
    }
    void pack_B(const KernParam& kern_param, void* out, size_t x0, size_t xmax)
            const override {
        megdnn_assert(0, "pack_B no impl");
    }
    InnerBlockSize get_inner_block_size() const override { return {1, 32, 4}; }
    MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT)
};

class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
    AlgoAttribute attribute() const override {
@@ -2,6 +2,7 @@
 #include "megdnn/oprs.h"
 #include "src/arm_common/matrix_mul/int8/gemv.h"
+#include "src/common/unroll_macro.h"
 #include "src/common/utils.h"
 #include "midout.h"
@@ -430,5 +431,398 @@ void arm_common::gemv_like_mk4_dot(
     MIDOUT_END();
 }
 #endif
+#if MGB_ENABLE_DOT
+namespace {
+//! broadcast the 4 bytes at p (one k-block of A) into every 32-bit lane
+inline int8x16_t broadcast_k4(const int8_t* p) {
+    return vreinterpretq_s8_s32(vld1q_dup_s32(reinterpret_cast<const int32_t*>(p)));
+}
+//! vzipq_s16 on s8 vectors, wrapping the reinterpret casts it needs
+inline int8x16x2_t zipq_s16_s8(int8x16_t a, int8x16_t b) {
+    int16x8x2_t zipped = vzipq_s16(vreinterpretq_s16_s8(a), vreinterpretq_s16_s8(b));
+    return {{vreinterpretq_s8_s16(zipped.val[0]),
+             vreinterpretq_s8_s16(zipped.val[1])}};
+}
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gevm_naive_dot_impl(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
+        bool load_c) {
+    constexpr size_t n_block = 32;
+    const size_t n_end = N / n_block * n_block;
+    const size_t n_remain = N - n_end;
+    constexpr size_t k_block = 4;
+    constexpr size_t k_block_x2 = k_block * 2;
+    const size_t k_end = (K / k_block_x2) * k_block_x2;
+    const size_t k_remain = K - k_end;
+    for (size_t n = 0; n < n_end; n += n_block) {
+        if (K < k_block_x2) {
+            //! K too small for the double-buffered path: plain scalar loop
+            if (!load_c) {
+                for (size_t i = 0; i < n_block; ++i) {
+                    C[n + i] = 0;
+                }
+            }
+            for (size_t k = 0; k < K; ++k) {
+                for (size_t i = 0; i < n_block; ++i) {
+                    C[n + i] += A[k] * B[k * Bstride + n + i];
+                }
+            }
+            continue;
+        }
+        int32x4_t c[8];
+        if (load_c) {
+#define cb(step) c[step] = vld1q_s32(C + n + step * 4);
+            UNROLL_CALL_RAW(8, cb);
+#undef cb
+        } else {
+#define cb(step) c[step] = vdupq_n_s32(0);
+            UNROLL_CALL_RAW(8, cb);
+#undef cb
+        }
+        int8x16_t a[2];
+        a[0] = broadcast_k4(A);
+        int8x16_t b[2][8];
+#define cb(step)                                                      \
+    b[0][step * 2 + 0] = vld1q_s8(B + (0 + step) * Bstride + n);      \
+    b[0][step * 2 + 1] = vld1q_s8(B + (0 + step) * Bstride + n + 16);
+        UNROLL_CALL_RAW(4, cb);
+#undef cb
+        size_t k_buffer_end = k_end - k_block_x2;
+        for (size_t k = 0; k < k_buffer_end; k += k_block_x2) {
+            //! double buffer main
+#define cb(step)                                                               \
+    b[1][step * 2 + 0] = vld1q_s8(B + (k + step + k_block) * Bstride + n);     \
+    b[1][step * 2 + 1] = vld1q_s8(B + (k + step + k_block) * Bstride + n + 16);
+            UNROLL_CALL_RAW(4, cb);
+#undef cb
+            a[1] = broadcast_k4(A + k + k_block);
+            int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]);
+            int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]);
+            int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]);
+            int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]);
+            int8x16x2_t abcd0 = zipq_s16_s8(ab0.val[0], cd0.val[0]);
+            int8x16x2_t abcd1 = zipq_s16_s8(ab0.val[1], cd0.val[1]);
+            int8x16x2_t abcd2 = zipq_s16_s8(ab1.val[0], cd1.val[0]);
+            int8x16x2_t abcd3 = zipq_s16_s8(ab1.val[1], cd1.val[1]);
+            c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]);
+            c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]);
+            c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]);
+            c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]);
+            c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]);
+            c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]);
+            c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]);
+            c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]);
+#define cb(step)                                                                  \
+    b[0][step * 2 + 0] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n);     \
+    b[0][step * 2 + 1] = vld1q_s8(B + (k + step + k_block_x2) * Bstride + n + 16);
+            UNROLL_CALL_RAW(4, cb);
+#undef cb
+            a[0] = broadcast_k4(A + k + k_block_x2);
+            ab0 = vzipq_s8(b[1][0], b[1][2]);
+            cd0 = vzipq_s8(b[1][4], b[1][6]);
+            ab1 = vzipq_s8(b[1][1], b[1][3]);
+            cd1 = vzipq_s8(b[1][5], b[1][7]);
+            abcd0 = zipq_s16_s8(ab0.val[0], cd0.val[0]);
+            abcd1 = zipq_s16_s8(ab0.val[1], cd0.val[1]);
+            abcd2 = zipq_s16_s8(ab1.val[0], cd1.val[0]);
+            abcd3 = zipq_s16_s8(ab1.val[1], cd1.val[1]);
+            c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]);
+            c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]);
+            c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]);
+            c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]);
+            c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]);
+            c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]);
+            c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]);
+            c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]);
+        }
+        //! double buffer remain
+#define cb(step)                                                          \
+    b[1][step * 2 + 0] =                                                  \
+            vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n);  \
+    b[1][step * 2 + 1] =                                                  \
+            vld1q_s8(B + (k_buffer_end + step + k_block) * Bstride + n + 16);
+        UNROLL_CALL_RAW(4, cb);
+#undef cb
+        a[1] = broadcast_k4(A + k_buffer_end + k_block);
+        int8x16x2_t ab0 = vzipq_s8(b[0][0], b[0][2]);
+        int8x16x2_t cd0 = vzipq_s8(b[0][4], b[0][6]);
+        int8x16x2_t ab1 = vzipq_s8(b[0][1], b[0][3]);
+        int8x16x2_t cd1 = vzipq_s8(b[0][5], b[0][7]);
+        int8x16x2_t abcd0 = zipq_s16_s8(ab0.val[0], cd0.val[0]);
+        int8x16x2_t abcd1 = zipq_s16_s8(ab0.val[1], cd0.val[1]);
+        int8x16x2_t abcd2 = zipq_s16_s8(ab1.val[0], cd1.val[0]);
+        int8x16x2_t abcd3 = zipq_s16_s8(ab1.val[1], cd1.val[1]);
+        c[0] = vdotq_s32(c[0], abcd0.val[0], a[0]);
+        c[1] = vdotq_s32(c[1], abcd0.val[1], a[0]);
+        c[2] = vdotq_s32(c[2], abcd1.val[0], a[0]);
+        c[3] = vdotq_s32(c[3], abcd1.val[1], a[0]);
+        c[4] = vdotq_s32(c[4], abcd2.val[0], a[0]);
+        c[5] = vdotq_s32(c[5], abcd2.val[1], a[0]);
+        c[6] = vdotq_s32(c[6], abcd3.val[0], a[0]);
+        c[7] = vdotq_s32(c[7], abcd3.val[1], a[0]);
+        ab0 = vzipq_s8(b[1][0], b[1][2]);
+        cd0 = vzipq_s8(b[1][4], b[1][6]);
+        ab1 = vzipq_s8(b[1][1], b[1][3]);
+        cd1 = vzipq_s8(b[1][5], b[1][7]);
+        abcd0 = zipq_s16_s8(ab0.val[0], cd0.val[0]);
+        abcd1 = zipq_s16_s8(ab0.val[1], cd0.val[1]);
+        abcd2 = zipq_s16_s8(ab1.val[0], cd1.val[0]);
+        abcd3 = zipq_s16_s8(ab1.val[1], cd1.val[1]);
+        c[0] = vdotq_s32(c[0], abcd0.val[0], a[1]);
+        c[1] = vdotq_s32(c[1], abcd0.val[1], a[1]);
+        c[2] = vdotq_s32(c[2], abcd1.val[0], a[1]);
+        c[3] = vdotq_s32(c[3], abcd1.val[1], a[1]);
+        c[4] = vdotq_s32(c[4], abcd2.val[0], a[1]);
+        c[5] = vdotq_s32(c[5], abcd2.val[1], a[1]);
+        c[6] = vdotq_s32(c[6], abcd3.val[0], a[1]);
+        c[7] = vdotq_s32(c[7], abcd3.val[1], a[1]);
+        vst1q_s32(C + n + 0 * 4, c[0]);
+        vst1q_s32(C + n + 1 * 4, c[1]);
+        vst1q_s32(C + n + 2 * 4, c[2]);
+        vst1q_s32(C + n + 3 * 4, c[3]);
+        vst1q_s32(C + n + 4 * 4, c[4]);
+        vst1q_s32(C + n + 5 * 4, c[5]);
+        vst1q_s32(C + n + 6 * 4, c[6]);
+        vst1q_s32(C + n + 7 * 4, c[7]);
+        if (k_remain > 0) {
+            for (size_t k = k_end; k < K; ++k) {
+                for (size_t i = 0; i < n_block; ++i) {
+                    C[n + i] += A[k] * B[k * Bstride + n + i];
+                }
+            }
+        }
+    }
+    if (n_remain > 0) {
+        for (size_t n = n_end; n < N; ++n) {
+            if (!load_c) {
+                C[n] = 0;
+            }
+            for (size_t k = 0; k < K; ++k) {
+                C[n] += A[k] * B[k * Bstride + n];
+            }
+        }
+    }
+}
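The zip chain above (vzipq_s8 on rows k/k+1 and k+2/k+3, then the 16-bit zip across the two pairs) leaves each 4-byte lane group holding the four k values of one output column, which is exactly the operand shape vdotq_s32 expects against the broadcast A bytes. A scalar model of what a single lane accumulates (an illustrative sketch, not part of the patch):

    #include <cstddef>
    #include <cstdint>

    // One sdot lane: a 4-deep dot product for output column `col`,
    // with A[k..k+3] broadcast to every lane of the other operand.
    int32_t dot_lane(const int8_t* A, const int8_t* B, size_t Bstride,
                     size_t k, size_t col) {
        int32_t acc = 0;
        for (size_t j = 0; j < 4; ++j)
            acc += int32_t(A[k + j]) * int32_t(B[(k + j) * Bstride + col]);
        return acc;
    }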
+#if MEGDNN_ARMV7
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gevm_naive_dot_n32k4_impl(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
+        bool load_c) {
+    //! B layout must be (N/32, K/4, 32(n), 4(k))
+    //! TODO: add prefetch
+    //! TODO: add double buffer
+    constexpr size_t n_block = 32;
+    constexpr size_t k_block = 4;
+    for (size_t n = 0; n < N; n += n_block) {
+        int32x4_t c[n_block / 4];
+#define cb(step) c[step] = vdupq_n_s32(0);
+        UNROLL_CALL_RAW(8, cb);
+#undef cb
+        const int8_t* b_base = B + n * K;
+        for (size_t k = 0; k < K; k += k_block) {
+            int8x16_t a[1];
+            int8x16_t b[1][8];
+            //! one packed k-block holds 32(n) * 4(k) = 128 bytes of B
+#define cb(step) b[0][step] = vld1q_s8(b_base + k * 32 + 16 * step);
+            UNROLL_CALL_RAW(8, cb);
+#undef cb
+            a[0] = broadcast_k4(A + k);
+            c[0] = vdotq_s32(c[0], b[0][0], a[0]);
+            c[1] = vdotq_s32(c[1], b[0][1], a[0]);
+            c[2] = vdotq_s32(c[2], b[0][2], a[0]);
+            c[3] = vdotq_s32(c[3], b[0][3], a[0]);
+            c[4] = vdotq_s32(c[4], b[0][4], a[0]);
+            c[5] = vdotq_s32(c[5], b[0][5], a[0]);
+            c[6] = vdotq_s32(c[6], b[0][6], a[0]);
+            c[7] = vdotq_s32(c[7], b[0][7], a[0]);
+        }
+        vst1q_s32(C + n + 0 * 4, c[0]);
+        vst1q_s32(C + n + 1 * 4, c[1]);
+        vst1q_s32(C + n + 2 * 4, c[2]);
+        vst1q_s32(C + n + 3 * 4, c[3]);
+        vst1q_s32(C + n + 4 * 4, c[4]);
+        vst1q_s32(C + n + 5 * 4, c[5]);
+        vst1q_s32(C + n + 6 * 4, c[6]);
+        vst1q_s32(C + n + 7 * 4, c[7]);
+    }
+}
+#else
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+inline void n32k4_dot(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t K) {
+    //! the main loop consumes two k-blocks per iteration, plus an odd tail
+    int oddk = (K & 1);
+    K = ((K + 1) / 2) - 1;
+    //! C v0-v7
+    //! A v8-v9
+    //! B v10-v25
+    asm volatile(
+            // zero accumulator C
+            "1:\n"
+            "eor v0.16b, v0.16b, v0.16b\n"
+            "eor v1.16b, v1.16b, v1.16b\n"
+            "eor v2.16b, v2.16b, v2.16b\n"
+            "eor v3.16b, v3.16b, v3.16b\n"
+            "eor v4.16b, v4.16b, v4.16b\n"
+            "eor v5.16b, v5.16b, v5.16b\n"
+            "eor v6.16b, v6.16b, v6.16b\n"
+            "eor v7.16b, v7.16b, v7.16b\n"
+            "ld1r {v8.4s}, [%[a_ptr]]\n"
+            "ld1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[b_ptr]], #64\n"
+            "ld1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[b_ptr]], #64\n"
+            "add %[a_ptr], %[a_ptr], #4\n"
+            "cmp %w[k], #0\n"
+            "beq 4f\n"
+            "2:\n"
+            // Loop proper
+            "3:\n"
+            "ld1r {v9.4s}, [%[a_ptr]]\n"
+            "sdot v0.4s, v10.16b, v8.16b\n"
+            "ldr q18, [%[b_ptr], #0]\n"
+            "sdot v1.4s, v11.16b, v8.16b\n"
+            "ldr q19, [%[b_ptr], #16]\n"
+            "sdot v2.4s, v12.16b, v8.16b\n"
+            "ldr q20, [%[b_ptr], #32]\n"
+            "add %[a_ptr], %[a_ptr], #4\n"
+            "sdot v3.4s, v13.16b, v8.16b\n"
+            "ldr q21, [%[b_ptr], #48]\n"
+            "sdot v4.4s, v14.16b, v8.16b\n"
+            "ldr q22, [%[b_ptr], #64]\n"
+            "sdot v5.4s, v15.16b, v8.16b\n"
+            "ldr q23, [%[b_ptr], #80]\n"
+            "sdot v6.4s, v16.16b, v8.16b\n"
+            "ldr q24, [%[b_ptr], #96]\n"
+            "sdot v7.4s, v17.16b, v8.16b\n"
+            "ldr q25, [%[b_ptr], #112]\n"
+            "ld1r {v8.4s}, [%[a_ptr]]\n"
+            "sdot v0.4s, v18.16b, v9.16b\n"
+            "ldr q10, [%[b_ptr], #128]\n"
+            "sdot v1.4s, v19.16b, v9.16b\n"
+            "ldr q11, [%[b_ptr], #144]\n"
+            "sdot v2.4s, v20.16b, v9.16b\n"
+            "ldr q12, [%[b_ptr], #160]\n"
+            "sdot v3.4s, v21.16b, v9.16b\n"
+            "ldr q13, [%[b_ptr], #176]\n"
+            "sdot v4.4s, v22.16b, v9.16b\n"
+            "ldr q14, [%[b_ptr], #192]\n"
+            "sdot v5.4s, v23.16b, v9.16b\n"
+            "ldr q15, [%[b_ptr], #208]\n"
+            "sdot v6.4s, v24.16b, v9.16b\n"
+            "ldr q16, [%[b_ptr], #224]\n"
+            "sdot v7.4s, v25.16b, v9.16b\n"
+            "ldr q17, [%[b_ptr], #240]\n"
+            "add %[a_ptr], %[a_ptr], #4\n"
+            "add %[b_ptr], %[b_ptr], #256\n"
+            "subs %w[k], %w[k], #1\n"
+            "bne 3b\n"
+            "4:\n"
+            "cmp %w[oddk], #1\n"
+            "beq 5f\n"
+            // Even tail
+            "ld1r {v9.4s}, [%[a_ptr]]\n"
+            "sdot v0.4s, v10.16b, v8.16b\n"
+            "ldr q18, [%[b_ptr], #0]\n"
+            "sdot v1.4s, v11.16b, v8.16b\n"
+            "ldr q19, [%[b_ptr], #16]\n"
+            "sdot v2.4s, v12.16b, v8.16b\n"
+            "ldr q20, [%[b_ptr], #32]\n"
+            "sdot v3.4s, v13.16b, v8.16b\n"
+            "ldr q21, [%[b_ptr], #48]\n"
+            "sdot v4.4s, v14.16b, v8.16b\n"
+            "ldr q22, [%[b_ptr], #64]\n"
+            "sdot v5.4s, v15.16b, v8.16b\n"
+            "ldr q23, [%[b_ptr], #80]\n"
+            "sdot v6.4s, v16.16b, v8.16b\n"
+            "ldr q24, [%[b_ptr], #96]\n"
+            "sdot v7.4s, v17.16b, v8.16b\n"
+            "ldr q25, [%[b_ptr], #112]\n"
+            "sdot v0.4s, v18.16b, v9.16b\n"
+            "sdot v1.4s, v19.16b, v9.16b\n"
+            "sdot v2.4s, v20.16b, v9.16b\n"
+            "sdot v3.4s, v21.16b, v9.16b\n"
+            "sdot v4.4s, v22.16b, v9.16b\n"
+            "sdot v5.4s, v23.16b, v9.16b\n"
+            "sdot v6.4s, v24.16b, v9.16b\n"
+            "sdot v7.4s, v25.16b, v9.16b\n"
+            "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], #64\n"
+            "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], #64\n"
+            "b 6f\n"
+            "5:\n"
+            // Odd tail
+            "sdot v0.4s, v10.16b, v8.16b\n"
+            "sdot v1.4s, v11.16b, v8.16b\n"
+            "sdot v2.4s, v12.16b, v8.16b\n"
+            "sdot v3.4s, v13.16b, v8.16b\n"
+            "sdot v4.4s, v14.16b, v8.16b\n"
+            "sdot v5.4s, v15.16b, v8.16b\n"
+            "sdot v6.4s, v16.16b, v8.16b\n"
+            "sdot v7.4s, v17.16b, v8.16b\n"
+            "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[c_ptr]], #64\n"
+            "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[c_ptr]], #64\n"
+            "6:\n"
+            : [a_ptr] "+r"(A), [b_ptr] "+r"(B), [k] "+r"(K), [c_ptr] "+r"(C),
+              [oddk] "+r"(oddk)
+            :
+            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
+              "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21",
+              "v22", "v23", "v24", "v25", "cc", "memory");
+}
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gevm_naive_dot_n32k4_impl(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride,
+        bool load_c) {
+    //! B layout must be (N/32, K/4, 32(n), 4(k)); the kernel takes K in k-blocks
+    //! TODO: add prefetch
+    //! TODO: add double buffer
+    constexpr size_t n_block = 32;
+    for (size_t n = 0; n < N; n += n_block) {
+        n32k4_dot(A, B + n * K, C + n, K / 4);
+    }
+}
+#endif
+}  // namespace
+void arm_common::gevm_naive_dot(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
+    megdnn_assert(M == 1);
+    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot"_hash)) {
+        size_t cache_size = 256 * 1024;
+        size_t k_group = N * K / cache_size;
+        constexpr size_t k_align = 8;
+        if (k_group >= 2) {
+            size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align;
+            for (size_t k = 0; k < K; k += k_per_group) {
+                size_t real_k = std::min(K - k, k_per_group);
+                gevm_naive_dot_impl(
+                        A + k, B + k * Bstride, C, M, N, real_k, Astride, Bstride,
+                        Cstride, k != 0);
+            }
+        } else {
+            gevm_naive_dot_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false);
+        }
+    }
+    MIDOUT_END();
+}
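The split keeps each pass over B within the assumed 256 KiB cache budget and accumulates into C from the second chunk onward (load_c is k != 0). A worked example of the arithmetic under the same assumed constants (the sizes here are illustrative):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>

    int main() {
        constexpr size_t cache_size = 256 * 1024, k_align = 8;
        size_t N = 1024, K = 1024;            // 1 MiB of int8 B
        size_t k_group = N * K / cache_size;  // == 4
        size_t k_per_group = ((K / k_group) + k_align - 1) / k_align * k_align;
        assert(k_group == 4 && k_per_group == 256);  // four K-chunks of 256
        for (size_t k = 0; k < K; k += k_per_group) {
            size_t real_k = std::min(K - k, k_per_group);
            bool load_c = (k != 0);  // chunks after the first accumulate into C
            (void)real_k;
            (void)load_c;
        }
        return 0;
    }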
+void arm_common::gevm_naive_n32k4_dot(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
+    megdnn_assert(M == 1);
+    MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gevm_dot_nk4"_hash)) {
+        gevm_naive_dot_n32k4_impl(A, B, C, M, N, K, Astride, Bstride, Cstride, false);
+    }
+    MIDOUT_END();
+}
+#endif
 // vim: syntax=cpp.doxygen
@@ -22,6 +22,14 @@ void gemv_like_mk4(
 void gemv_like_mk4_dot(
         const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
         size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);
+void gevm_naive_dot(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);
+void gevm_naive_n32k4_dot(
+        const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C,
+        size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);
 #endif
 }  // namespace arm_common
@@ -14,6 +14,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
 #if MGB_ENABLE_DOT
     AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
+    AlgoInt8x8x32GevmDot int8x8x32_gevm_dot;
+    AlgoInt8x8x32GevmN32K4Dot int8x8x32_gevm_n32k4_dot;
 #endif
     AlgoGevm gevm;
@@ -28,6 +30,8 @@ public:
 #endif
 #if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
+        m_all_algos.emplace_back(&int8x8x32_gevm_dot);
+        m_all_algos.emplace_back(&int8x8x32_gevm_n32k4_dot);
 #endif
         m_all_algos.emplace_back(&int8x8x32_gemv);
         m_all_algos.emplace_back(&int8x8x32_gemv_mk4);
@@ -31,7 +31,9 @@ protected:
     class AlgoF16Gemv;
 #endif
 #if MGB_ENABLE_DOT
-    class AlgoInt8x8x32GemvMK4Dot;  // Arm_common Int8x8x32 Gemv NCHW44_DOT
+    class AlgoInt8x8x32GemvMK4Dot;    // Arm_common Int8x8x32 Gemv NCHW44_DOT
+    class AlgoInt8x8x32GevmDot;       // Arm_common Int8x8x32 Gevm NCHW DOT
+    class AlgoInt8x8x32GevmN32K4Dot;  // Arm_common Int8x8x32 Gevm N32K4_DOT
 #endif
     class AlgoInt8x8x16;  // Arm_common Int 8x8x16
     class AlgoPack;
@@ -469,7 +469,7 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) {
 #endif
 #if MEGDNN_ARMV7
-__ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) {
+__ai int8x16_t vqtbl1q_s8(int8x16_t a, uint8x16_t idx) {
     int8x8_t src_low = vget_low_s8(a);
     int8x8_t src_high = vget_high_s8(a);
     return vcombine_s8(
@@ -726,6 +726,13 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) {
     asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :);
     return a;
 }
+#if __ARM_ARCH < 8
+//! round-to-nearest-away fallback for vcvtaq_s32_f32 on pre-ARMv8 targets
+__ai int32x4_t vcvtaq_s32_f32(float32x4_t val) {
+    float32x4_t vinc0 = vbslq_f32(
+            vcgeq_f32(val, vdupq_n_f32(0.f)), vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f));
+    return vcvtq_s32_f32(vaddq_f32(val, vinc0));
+}
+#endif
 #if MGB_ENABLE_DOT
 #undef __ARM_FEATURE_DOTPROD
 #endif
@@ -61,6 +61,15 @@ void MatrixMulForward::deduce_layout(
             "(transposed) B is (%zu,%zu)",
             A0, A1, B0, B1);
         C = TensorLayout(TensorShape({A0, B1}), C.dtype);
+    } else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
+        A0 = A.shape[0];
+        A1 = A.shape[1];
+        B0 = B.shape[0];
+        B1 = B.shape[1];
+        megdnn_assert(!m_param.transposeA && !m_param.transposeB);
+        megdnn_assert(A0 == 1 && A1 % 4 == 0);
+        megdnn_assert(B.ndim == 4);
+        C = TensorLayout(TensorShape({A0, B0 * 32}), C.dtype);
     } else {
         auto do_deduce = [&](size_t pack_size) {
             megdnn_assert(
@@ -132,6 +141,18 @@ void MatrixMulForward::check_exec(
         megdnn_assert(A0 == C0, "%s", errmsg().c_str());
         megdnn_assert(B1 == C1, "%s", errmsg().c_str());
         megdnn_assert(A1 == B0, "%s", errmsg().c_str());
+    } else if (param().format == param::MatrixMul::Format::N32K4_DOT) {
+        size_t A0 = A.shape[0];
+        size_t A1 = A.shape[1];
+        size_t B2 = B.shape[2];
+        size_t B3 = B.shape[3];
+        megdnn_assert(!m_param.transposeA && !m_param.transposeB);
+        megdnn_assert(A0 == 1 && A1 % 4 == 0);
+        megdnn_assert(B.ndim == 4);
+        megdnn_assert(B2 == 32 && B3 == 4);
+        megdnn_assert_contiguous(A);
+        megdnn_assert_contiguous(B);
+        megdnn_assert_contiguous(C);
     } else {
         megdnn_assert_eq_size_t(A.ndim, 4_z);
         megdnn_assert_eq_size_t(B.ndim, 3_z);
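Concretely, N32K4_DOT accepts A = (1, K) with K % 4 == 0 and B = (N/32, K/4, 32, 4), and deduces C = (1, B.shape[0] * 32). A standalone illustration of the same checks with hypothetical shapes:

    #include <array>
    #include <cassert>
    #include <cstddef>

    int main() {
        std::array<size_t, 2> A{1, 64};         // M = 1, K = 64 (K % 4 == 0)
        std::array<size_t, 4> B{2, 16, 32, 4};  // (N/32, K/4, 32, 4), N = 64
        assert(A[1] % 4 == 0 && A[1] / 4 == B[1]);
        assert(B[2] == 32 && B[3] == 4);
        std::array<size_t, 2> C{A[0], B[0] * 32};  // deduced -> {1, 64}
        assert(C[0] == 1 && C[1] == 64);
        return 0;
    }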
@@ -442,9 +442,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
         megdnn_assert(0, "invalid conv format %d", static_cast<int>(param().format));
     }
     BiasMode bias_mode;
+    //! when dst has batch and spatial dims of 1, a full-shape bias degenerates
+    //! to BROADCAST_CHANNEL_BIAS
+    bool dst_only_c = dst[0] == 1 && dst[spatial_pos] == 1 && dst[spatial_pos + 1] == 1;
     if (bias.ndim == 0) {
         bias_mode = BiasMode::NO_BIAS;
-    } else if (bias.eq_shape(dst)) {
+    } else if (bias.eq_shape(dst) && !dst_only_c) {
         bias_mode = BiasMode::BIAS;
     } else {
         //! just check the ndim, the detail shape check is in check_exec
@@ -258,6 +258,9 @@ public:
         ARM_COMMON_CHANWISE_STRD1_NCHW44_S8,
         ARM_COMMON_CHANWISE_STRD2_NCHW44_S8,
         ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8,
+        //! the *_LARGE algos target large filters
+        ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8,
+        ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8,
         ARM_COMMON_DIRECT_STRD1_DOT_S8,
         ARM_COMMON_DIRECT_STRD2_DOT_S8,
         ARM_COMMON_DIRECT_NCHW44_DOT_S8,
@@ -195,11 +195,11 @@ MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param(
     kern_size_param.trB = param().transposeB;
     kern_size_param.compute_mode = param().compute_mode;
     kern_size_param.format = param().format;
-    size_t pack_size = MatrixMulForward::pack_size(param().format);
-    kern_size_param.K *= pack_size;
-    kern_size_param.M *= pack_size;
+    //! N32K4_DOT already encodes the packing in B's layout, so M/K stay logical
+    if (param().format != Param::Format::N32K4_DOT) {
+        size_t pack_size = MatrixMulForward::pack_size(param().format);
+        kern_size_param.K *= pack_size;
+        kern_size_param.M *= pack_size;
+    }
     return kern_size_param;
 }
@@ -122,6 +122,8 @@ public:
         ARM_COMMON_INT8X8X32_GEMV,
         ARM_COMMON_INT8X8X32_GEMV_MK4,
         ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
+        ARM_COMMON_INT8X8X32_GEVM_DOT,
+        ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT,
         ARM_COMMON_F16_GEMV,
         ARM_COMMON_GEVM,
 #if MEGDNN_AARCH64
@@ -175,6 +177,7 @@ public:
     enum class AlgoSet : uint32_t {
         ALGO_TYPE_GEMM = 0,
         ALGO_TYPE_GEMV = 1,
+        ALGO_TYPE_GEVM = 2,
     };
     enum class PackMode : uint32_t {
@@ -105,6 +105,34 @@ void run_matrix_mul_mk4_dot_tpl(
 template <
         typename itype, typename otype, bool transA, bool transB,
         typename comp_type = otype>
+void run_matrix_mul_n32k4_dot_tpl(
+        const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K,
+        size_t LDA, size_t, size_t, const DType& A_type, const DType& B_type) {
+    Getter<itype, comp_type> getterA(A_type), getterB(B_type);
+    megdnn_assert(!transA && !transB);
+    for (size_t m = 0; m < M; ++m) {
+        for (size_t n = 0; n < N; n += 32) {
+            comp_type res[32] = {comp_type(0)};
+            for (size_t k = 0; k < K; k += 4) {
+                for (size_t i = 0; i < 32; i++) {
+                    comp_type av, bv;
+                    for (size_t j = 0; j < 4; j++) {
+                        av = getterA(A[m * LDA + k + j]);
+                        bv = getterB(B[n * K + k * 32 + i * 4 + j]);
+                        res[i] += av * bv;
+                    }
+                }
+            }
+            for (size_t i = 0; i < 32; i++) {
+                C[n + i] = res[i];
+            }
+        }
+    }
+}
+template <
+        typename itype, typename otype, bool transA, bool transB,
+        typename comp_type = otype>
 void run_matrix_mul_mk8_tpl(
         const itype* A, const itype* B, otype* C, size_t M, size_t N, size_t K,
         size_t LDA, size_t LDB, size_t LDC, const DType& A_type, const DType& B_type) {
@@ -251,6 +279,10 @@ void dispatch_ta_tb(
             return run_matrix_mul_mk4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
                     static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
                     static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \
+        } else if (format == param::MatrixMul::Format::N32K4_DOT) { \
+            return run_matrix_mul_n32k4_dot_tpl<_itype, _otype, TA, TB, _comp_type>( \
+                    static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
+                    static_cast<_otype*>(C), M, N, K, LDA, LDB, LDC, A_type, B_type); \
         } else if (format == param::MatrixMul::Format::MK8) { \
             return run_matrix_mul_mk8_tpl<_itype, _otype, TA, TB, _comp_type>( \
                     static_cast<const _itype*>(A), static_cast<const _itype*>(B), \
@@ -160,7 +160,7 @@ static void benchmark_convbias(
             .set_display(false);
     }
     auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
-#if MGB_ENBALE_DOT
+#if MGB_ENABLE_DOT
     if (!is_fp32) {
         nchw44_algo_regx = ".*DOT.*";
     }
@@ -1626,7 +1626,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
 #endif
-#if MGB_ENBALE_DOT
+#if MGB_ENABLE_DOT
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
     // have to remove preferred restrict in usable func before run the benchmark
@@ -2011,6 +2011,80 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) {
     }
 }
+TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
+    using namespace conv_bias;
+    std::vector<TestArg> args;
+    auto run = [&](size_t group, size_t w, size_t h, size_t kernel, size_t stride,
+                   NonlineMode nonline_mode) {
+        size_t p = kernel / 2;
+        if (w + 2 * p < kernel || h + 2 * p < kernel)
+            return;
+        param::ConvBias param;
+        param.stride_h = stride;
+        param.stride_w = stride;
+        param.pad_h = p;
+        param.pad_w = p;
+        param.nonlineMode = nonline_mode;
+        param.format = param::ConvBias::Format::NCHW;
+        param.sparse = ConvBiasForward::Param::Sparse::GROUP;
+        //! channel bias
+        args.emplace_back(
+                param, TensorShape{1, group, h, w},
+                TensorShape{group, 1, 1, kernel, kernel}, TensorShape{1, group, 1, 1});
+    };
+    run(64, 64, 64, 9, 1, NonlineMode::RELU);
+    run(64, 40, 40, 9, 2, NonlineMode::RELU);
+    run(64, 20, 20, 9, 1, NonlineMode::RELU);
+    constexpr size_t RUN = 120;
+    Benchmarker<ConvBias> benchmark0(handle());
+    benchmark0.set_dtype(0, dtype::QuantizedS8(2.5f))
+            .set_dtype(1, dtype::QuantizedS8(2.5f))
+            .set_dtype(2, dtype::QuantizedS32(6.25f))
+            .set_dtype(4, dtype::QuantizedS8(60.25f));
+    benchmark0.set_display(false);
+    benchmark0.set_times(RUN);
+    benchmark0.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
+            "ARMDOTS8_DIRECT_CHANWISE_LARGE"));
+    Benchmarker<ConvBias> benchmark1(handle());
+    benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f))
+            .set_dtype(1, dtype::QuantizedS8(2.5f))
+            .set_dtype(2, dtype::QuantizedS32(6.25f))
+            .set_dtype(4, dtype::QuantizedS8(60.25f));
+    benchmark1.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
+            "ARMDOTS8_IM2COL_CHANWISE_LARGE"));
+    benchmark1.set_display(false);
+    benchmark1.set_times(RUN);
+    for (auto&& arg : args) {
+        TensorLayout dst_layout;
+        auto opr = handle()->create_operator<ConvBias>();
+        opr->param() = arg.param;
+        opr->deduce_layout(
+                {arg.src, dtype::Int8()}, {arg.filter, dtype::Int8()},
+                {arg.bias, dtype::Int32()}, {}, dst_layout);
+        //! dst.nr_elems * FH * FW * 2
+        float computations =
+                dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6;
+        auto used0 = benchmark0.set_param(arg.param).exec(
+                             {arg.src, arg.filter, arg.bias, {}, {}}) /
+                     RUN;
+        auto used1 = benchmark1.set_param(arg.param).exec(
+                             {arg.src, arg.filter, arg.bias, {}, {}}) /
+                     RUN;
+        printf("%s %s: direct: %f ms %f Gflops im2col: %f ms %f Gflops "
+               "speedup: %f\n",
+               arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0,
+               computations / used0, used1, computations / used1, used1 / used0);
+    }
+}
 #endif
 #endif
@@ -2194,7 +2268,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) {
     dtype::QuantizedS8 stype(2.5f);
     dtype::QuantizedS32 dtype(6.25f);
 #if MEGDNN_AARCH64
-#if MGB_ENBALE_DOT
+#if MGB_ENABLE_DOT
     benchmark_conv1x1(
             "AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype, dtype, dtype, dtype);
 #else
@@ -2212,7 +2286,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) {
     dtype::QuantizedS32 dtype(1.2 * 1.2);
 #if MEGDNN_AARCH64
-#if MGB_ENBALE_DOT
+#if MGB_ENABLE_DOT
     benchmark_conv1x1(
             "AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype, dtype, dtype);
 #else
@@ -136,6 +136,84 @@ std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
     return args;
 }
+std::vector<conv_bias::TestArg> get_channel_wise_args(
+        std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
+        bool no_full_bias, bool support_relu) {
+    using namespace conv_bias;
+    using Param = param::ConvBias;
+    using NLMode = param::ConvBias::NonlineMode;
+    std::vector<TestArg> args;
+    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
+                    size_t stride, NLMode nlmode, bool pad) {
+        Param param;
+        param.stride_h = stride;
+        param.stride_w = stride;
+        if (pad) {
+            param.pad_h = kernel / 2;
+            param.pad_w = kernel / 2;
+        } else {
+            param.pad_h = 0;
+            param.pad_w = 0;
+        }
+        param.nonlineMode = nlmode;
+        param.format = param::ConvBias::Format::NCHW;
+        param.sparse = param::ConvBias::Sparse::GROUP;
+        args.emplace_back(
+                param, TensorShape{n, group, h, w},
+                TensorShape{group, 1, 1, kernel, kernel}, TensorShape{});
+        if (!no_bias) {
+            args.emplace_back(
+                    param, TensorShape{n, group, h, w},
+                    TensorShape{group, 1, 1, kernel, kernel},
+                    TensorShape{1, group, 1, 1});
+        }
+        if (!no_full_bias) {
+            args.emplace_back(
+                    param, TensorShape{n, group, h, w},
+                    TensorShape{group, 1, 1, kernel, kernel},
+                    TensorShape{
+                            n, group, (h + 2 * param.pad_h - kernel) / stride + 1,
+                            (w + 2 * param.pad_w - kernel) / stride + 1});
+        }
+    };
+    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
+    if (!no_nonlinemode) {
+        nonlinemode.emplace_back(NLMode::RELU);
+        nonlinemode.emplace_back(NLMode::H_SWISH);
+    } else if (support_relu) {
+        nonlinemode.emplace_back(NLMode::RELU);
+    }
+    for (size_t n : {1, 2}) {
+        for (auto nlmode : nonlinemode) {
+            for (bool pad : {true}) {
+                for (size_t group : {1, 3, 7}) {
+                    for (size_t size : {4, 6, 7, 9, 16, 20, 32, 55}) {
+                        for (size_t kern : kernel) {
+                            pack(n, group, size, size, kern, stride, nlmode, pad);
+                        }
+                    }
+                }
+            }
+            for (bool pad : {false}) {
+                for (size_t group : {7}) {
+                    for (size_t size : {37}) {
+                        for (size_t kern : kernel) {
+                            if (size < kern)
+                                continue;
+                            pack(n, group, size, size, kern, stride, nlmode, pad);
+                        }
+                    }
+                }
+            }
+        }
+    }
+    return args;
+}
 std::vector<conv_bias::TestArg> get_nchw88_channel_wise_args(
         std::vector<size_t> kernel, size_t stride, bool no_bias, bool no_nonlinemode,
         bool no_full_bias) {
@@ -226,7 +304,7 @@ void checker_conv_bias_qint8x8x8(
             .set_rng(1, &rng)
             .set_rng(2, &rng);
     for (auto&& arg : args) {
-        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
+        checker.set_param(arg.param).execs({arg.src, arg.filter, arg.bias, {}, {}});
     }
 }
 void checker_conv_bias_qint8x8x32(
@@ -532,6 +610,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
 /****************************dot qint8 direct*************************/
 #if MGB_ENABLE_DOT
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
+    checker_conv_bias_qint8x8x8(
+            get_channel_wise_args({9}, 1, false, true, true, true), handle(),
+            "ARMDOTS8_DIRECT_CHANWISE_LARGE");
+}
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
+    checker_conv_bias_qint8x8x8(
+            get_channel_wise_args({9}, 2, false, true, true, true), handle(),
+            "ARMDOTS8_DIRECT_CHANWISE_LARGE");
+}
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S1) {
+    checker_conv_bias_qint8x8x8(
+            get_channel_wise_args({9}, 1, false, true, true, true), handle(),
+            "ARMDOTS8_IM2COL_CHANWISE_LARGE");
+}
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_IM2COL_LARGE_S2) {
+    checker_conv_bias_qint8x8x8(
+            get_channel_wise_args({9}, 2, false, true, true, true), handle(),
+            "ARMDOTS8_IM2COL_CHANWISE_LARGE");
+}
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
     auto args = get_nchw44_conv_bias_args(
             {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2, false, true);
@@ -219,6 +219,113 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
         for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
             run(M, K, 1);
 }
+TEST_F(ARM_COMMON, QINT8x8x32_GEVM_DOT) {
+    Checker<MatrixMul> checker(handle());
+    using Param = MatrixMul::Param;
+    auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_DOT");
+    checker.set_before_exec_callback(algo_ck);
+    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
+    checker.set_rng(0, rng.get()).set_rng(1, rng.get());
+    Param param;
+    param.format = Param::Format::DEFAULT;
+    param.transposeA = false;
+    param.transposeB = false;
+    auto run = [&](size_t M, size_t N, size_t K) {
+        TensorShape A, B;
+        A = TensorShape{M, K};
+        B = TensorShape{K, N};
+        checker.set_param(param)
+                .set_dtype(0, dtype::Int8())
+                .set_dtype(1, dtype::Int8())
+                .set_dtype(2, dtype::Int32())
+                .execs({A, B, {}});
+    };
+    run(1, 32, 4);
+    for (int n = 7; n < 43; n += 3) {
+        for (int k = 1; k < 33; k += 3) {
+            run(1, n, k);
+        }
+    }
+}
+TEST_F(ARM_COMMON, QINT8x8x32_GEVM_N32K4_DOT) {
+    Checker<MatrixMul> checker(handle());
+    using Param = MatrixMul::Param;
+    auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT");
+    checker.set_before_exec_callback(algo_ck);
+    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
+    checker.set_rng(0, rng.get()).set_rng(1, rng.get());
+    Param param;
+    param.format = Param::Format::N32K4_DOT;
+    param.transposeA = false;
+    param.transposeB = false;
+    auto run = [&](size_t M, size_t N, size_t K) {
+        TensorShape A, B;
+        A = TensorShape{M, K};
+        B = TensorShape{N / 32, K / 4, 32, 4};
+        checker.set_param(param)
+                .set_dtype(0, dtype::Int8())
+                .set_dtype(1, dtype::Int8())
+                .set_dtype(2, dtype::Int32())
+                .execs({A, B, {}});
+    };
+    run(1, 32, 4);
+    for (int n = 32; n < 65; n += 32) {
+        for (int k = 4; k < 39; k += 4) {
+            run(1, n, k);
+        }
+    }
+}
+#if MEGDNN_WITH_BENCHMARK
+TEST_F(ARM_COMMON, BENCHMARK_QINT8x8x32_GEVM_N32K4_DOT) {
+    using Param = MatrixMul::Param;
+    auto algo_ck = AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT");
+    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-30, 30);
+    Param param;
+    param.format = Param::Format::N32K4_DOT;
+    param.transposeA = false;
+    param.transposeB = false;
+    constexpr size_t RUNS = 2000;
+    Benchmarker<MatrixMul> benchmarker_int(handle());
+    benchmarker_int.set_times(RUNS)
+            .set_dtype(0, dtype::Int8{})
+            .set_dtype(1, dtype::Int8{})
+            .set_dtype(2, dtype::Int32{})
+            .set_param(param)
+            .set_before_exec_callback(algo_ck)
+            .set_display(false);
+    Benchmarker<MatrixMul> benchmarker_float(handle());
+    benchmarker_float.set_display(false).set_times(RUNS);
+    auto bench = [&](size_t M, size_t N, size_t K) {
+        auto int_used =
+                benchmarker_int.exec({{M, K}, {N / 32, K / 4, 32, 4}, {}}) / RUNS;
+        auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS;
+        float computations = 2.f * M * K * N * 1e-6;
+        float through_put = (M * K + N * K + M * N) * 1e-6;
+        printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f ms "
+               "%f Gflops speedup: %f, throughput: %f G\n",
+               M, K, N, float_used, computations / float_used, int_used,
+               computations / int_used, float_used / int_used, through_put / int_used);
+    };
+    bench(1, 256, 512);
+    bench(1, 256, 1024);
+    bench(1, 512, 512);
+    bench(1, 512, 1024);
+}
+#endif
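As a sanity check on the printed metrics (a sketch with an assumed 0.1 ms timing, not a measured result): computations is 2 * M * K * N scaled by 1e-6, so dividing by milliseconds yields Gflop/s, while through_put counts elements touched rather than bytes:

    #include <cstdio>

    int main() {
        // Derived metrics for bench(1, 256, 512), assuming a 0.1 ms run.
        double M = 1, N = 256, K = 512, ms = 0.1;
        double computations = 2.0 * M * K * N * 1e-6;         // ~0.262
        double through_put = (M * K + N * K + M * N) * 1e-6;  // ~0.132 M elements
        std::printf("%f Gflops, %f G elems/s\n", computations / ms, through_put / ms);
        return 0;
    }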
 #endif
 TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {