@@ -407,6 +407,11 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern( | |||
return kern_mk8_16x12x1; | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( | |||
AlgoF16MK8_16x12x1, megdnn_aarch64_matmul_kern, "AlogF16MK8_16x12x1Impl"_hash, | |||
aarch64::matmul::hgemm_mk8_16x12, dt_float16, dt_float16, AlgoDataType::FLOAT16, | |||
MK8); | |||
#endif | |||
#if MGB_ENABLE_DOT | |||
@@ -93,7 +93,7 @@ public: | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8); | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); | |||
}; | |||
@@ -9,8 +9,8 @@ | |||
template <> | |||
void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | |||
const dt_float16* packedA, const dt_float16* packedB, int K, | |||
dt_float16* out, int LDC, bool is_first_k) { | |||
const dt_float16* packedA, const dt_float16* packedB, int K, dt_float16* out, | |||
int LDC, bool is_first_k) { | |||
#define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | |||
#define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | |||
// clang-format off | |||
@@ -26,6 +26,8 @@ static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
format = param::MatrixMul::Format::MK4; | |||
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | |||
format = param::MatrixMul::Format::MK4_DOT; | |||
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||
format = param::MatrixMul::Format::MK8; | |||
} | |||
size_t M = oc_tile_size; | |||
size_t N = ohw_tile_size; | |||
@@ -329,9 +331,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
if (format != param::ConvBias::Format::NCHW && | |||
format != param::ConvBias::Format::NCHW44 && | |||
format != param::ConvBias::Format::NCHW44_DOT) { | |||
format != param::ConvBias::Format::NCHW44_DOT && | |||
format != param::ConvBias::Format::NCHW88) { | |||
return false; | |||
} | |||
if (format == param::ConvBias::Format::NCHW88) { | |||
if (matmul_desc.packmode != Pack_Mode::DEFAULT) { | |||
return false; | |||
} | |||
} | |||
if (format == param::ConvBias::Format::NCHW44 || | |||
format == param::ConvBias::Format::NCHW44_DOT) { | |||
//! current NCHW44 im2col only support DEFAULT mode matmul | |||
@@ -248,8 +248,18 @@ public: | |||
break; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
case StrategyType::FLOAT_FP16: | |||
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW88) { | |||
cb1(NCHW88, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||
"DefaultStrategyTypeNCHW88::FLOAT_FP16"_hash); | |||
} else { | |||
megdnn_throw(ssprintf( | |||
"Current only support layout NCHW/NCHW88 for im2col algo " | |||
"of float 16, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
#endif | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
@@ -348,6 +348,32 @@ template < | |||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||
class Strategy< | |||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||
PackMode::DEFAULT, FormatMode::NCHW88> | |||
: public Strategy< | |||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
postprocess_mode, PackMode::DEFAULT> { | |||
public: | |||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||
Strategy() = default; | |||
void exec_im2col( | |||
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
template < | |||
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||
class Strategy< | |||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||
PackMode::NO_PACK> | |||
: public StrategyBridge< | |||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
@@ -0,0 +1,98 @@ | |||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||
#include "src/fallback/convolution/img2col_helper.h" | |||
#if MEGDNN_X86 | |||
#include "src/x86/conv_bias/postprocess_helper.h" | |||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||
#else | |||
#include "src/common/postprocess_helper.h" | |||
#endif | |||
using namespace megdnn; | |||
#if MEGDNN_X86 | |||
using namespace x86; | |||
#endif | |||
namespace megdnn { | |||
template < | |||
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||
void Strategy< | |||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||
PackMode::DEFAULT, FormatMode::NCHW88>:: | |||
exec_im2col( | |||
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||
size_t fh = param.filter_meta.spatial[0]; | |||
size_t fw = param.filter_meta.spatial[1]; | |||
bool is_xcoor = !param.filter_meta.should_flip; | |||
constexpr static size_t pack_size = 8; | |||
size_t input_offset = | |||
ic * ih * iw * | |||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||
sizeof(src_ctype); | |||
src_ctype* src = reinterpret_cast<src_ctype*>( | |||
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||
input_offset); | |||
bool is_phpwzero = | |||
(param.filter_meta.padding[0] == 0 && param.filter_meta.padding[1] == 0); | |||
if (is_phpwzero) { | |||
src = const_cast<src_ctype*>( | |||
param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||
} | |||
src_ctype* im2col_dst = | |||
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||
if (sh == 1 && sw == 1) { | |||
if (is_xcoor) { | |||
img2col_nchw8<true>( | |||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} else { | |||
img2col_nchw8<false>( | |||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||
sparam.output_block_size); | |||
} | |||
} else { | |||
if (is_xcoor) { | |||
img2col_stride_nchw8<true>( | |||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} else { | |||
img2col_stride_nchw8<false>( | |||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||
sparam.ohw_cur_index, sparam.output_block_size); | |||
} | |||
} | |||
matmul_param.M = sparam.output_block_oc_size; | |||
matmul_param.N = sparam.output_block_size; | |||
matmul_param.LDB = pack_size * sparam.output_block_size; | |||
matmul_param.LDC = pack_size * sparam.output_block_size; | |||
matmul_param.B_ptr = im2col_dst; | |||
src_ctype* b_panel = | |||
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)); | |||
matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||
} | |||
#define INSTANTIAL_CLASS( \ | |||
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, _postprocess_mode) \ | |||
template class Strategy< \ | |||
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, \ | |||
_postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW88>; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
INSTANTIAL_CLASS( | |||
dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||
megdnn::PostprocessMode::FLOAT); | |||
#endif | |||
#undef INSTANTIAL_CLASS | |||
} // namespace megdnn |
@@ -348,6 +348,441 @@ void img2col_nchw4( | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_nchw8( | |||
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||
const int IH, const int IW, const int FH, const int FW, const int cur_index, | |||
const int block_size) { | |||
int start_h = cur_index / OW; | |||
int cur_n_remain = cur_index % OW; | |||
int end_h = (cur_index + block_size) / OW; | |||
int end_n_remain = (cur_index + block_size) % OW; | |||
bool same_line = (start_h == end_h); | |||
int IC_div_8 = IC / 8; | |||
if (sizeof(dtype) == 2) { | |||
if (same_line) { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
//! TODO: Substitute GI for arm intrinsic when GI supports FP16 | |||
//! data type. | |||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + | |||
cur_n_remain + fw2); | |||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8; | |||
} | |||
#else | |||
int src_idx = 2 * (ic * IH * IW + (start_h + fh2) * IW + | |||
cur_n_remain + fw2); | |||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||
for (int w = cur_n_remain; w < end_n_remain; w++) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
} else { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||
cur_n_remain); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8; | |||
} | |||
src_idx = 8 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
int _src_idx = src_idx; | |||
rep(w, OW) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
_src_idx)); | |||
dst_idx += 8; | |||
_src_idx += 8; | |||
} | |||
src_idx += IW * 8; | |||
} | |||
src_idx = 8 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8; | |||
} | |||
#else | |||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||
cur_n_remain); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2; | |||
} | |||
src_idx = 2 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
int _src_idx = src_idx; | |||
rep(w, OW) { | |||
u64_dst[dst_idx] = u64_src[_src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||
dst_idx += 2; | |||
_src_idx += 2; | |||
} | |||
src_idx += IW * 2; | |||
} | |||
src_idx = 2 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
if (same_line) { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||
cur_n_remain); | |||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||
cur_n_remain); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
} | |||
src_idx = 8 * (ic * IH * IW + (start_h + 1 + fh2) * IW + fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
rep(w, OW) { | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
} | |||
} | |||
src_idx = 8 * (ic * IH * IW + (end_h + fh2) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
dst[dst_idx++] = src[src_idx++]; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_stride_nchw8( | |||
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||
const int IH, const int IW, const int FH, const int FW, const int SH, | |||
const int SW, const int cur_index, const int block_size) { | |||
int start_h = cur_index / OW; | |||
int cur_n_remain = cur_index % OW; | |||
int end_h = (cur_index + block_size) / OW; | |||
int end_n_remain = (cur_index + block_size) % OW; | |||
bool same_line = (start_h == end_h); | |||
int IC_div_8 = IC / 8; | |||
if (sizeof(dtype) == 2) { | |||
if (same_line) { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||
cur_n_remain * SW + fw2); | |||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8 * SW; | |||
} | |||
#else | |||
int src_idx = 2 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||
cur_n_remain * SW + fw2); | |||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||
for (int w = cur_n_remain; w < end_n_remain; w++) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2 * SW; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
} else { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||
fw2 + cur_n_remain * SW); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8 * SW; | |||
} | |||
src_idx = 8 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||
fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
int _src_idx = src_idx; | |||
rep(w, OW) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
_src_idx)); | |||
dst_idx += 8; | |||
_src_idx += 8 * SW; | |||
} | |||
src_idx += IW * 8 * SH; | |||
} | |||
src_idx = 8 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
vst1q_f16( | |||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||
vld1q_f16( | |||
reinterpret_cast<const __fp16*>(src) + | |||
src_idx)); | |||
dst_idx += 8; | |||
src_idx += 8 * SW; | |||
} | |||
#else | |||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||
fw2 + cur_n_remain * SW); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2 * SW; | |||
} | |||
src_idx = 2 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||
fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
int _src_idx = src_idx; | |||
rep(w, OW) { | |||
u64_dst[dst_idx] = u64_src[_src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||
dst_idx += 2; | |||
_src_idx += 2 * SW; | |||
} | |||
src_idx += IW * 2 * SH; | |||
} | |||
src_idx = 2 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
u64_dst[dst_idx] = u64_src[src_idx]; | |||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||
dst_idx += 2; | |||
src_idx += 2 * SW; | |||
} | |||
#endif | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
if (same_line) { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||
fw2 + cur_n_remain * SW); | |||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||
dst[dst_idx++] = src[src_idx]; | |||
dst[dst_idx++] = src[src_idx + 1]; | |||
dst[dst_idx++] = src[src_idx + 2]; | |||
dst[dst_idx++] = src[src_idx + 3]; | |||
dst[dst_idx++] = src[src_idx + 4]; | |||
dst[dst_idx++] = src[src_idx + 5]; | |||
dst[dst_idx++] = src[src_idx + 6]; | |||
dst[dst_idx++] = src[src_idx + 7]; | |||
src_idx += 8 * SW; | |||
} | |||
} | |||
} | |||
} | |||
} else { | |||
int dst_idx = 0; | |||
rep(ic, IC_div_8) { | |||
rep(fh, FH) { | |||
rep(fw, FW) { | |||
int fh2 = fh, fw2 = fw; | |||
if (!is_xcorr) { | |||
fh2 = FH - fh - 1; | |||
fw2 = FW - fw - 1; | |||
} | |||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||
fw2 + cur_n_remain * SW); | |||
for (int w = cur_n_remain; w < OW; ++w) { | |||
dst[dst_idx++] = src[src_idx]; | |||
dst[dst_idx++] = src[src_idx + 1]; | |||
dst[dst_idx++] = src[src_idx + 2]; | |||
dst[dst_idx++] = src[src_idx + 3]; | |||
dst[dst_idx++] = src[src_idx + 4]; | |||
dst[dst_idx++] = src[src_idx + 5]; | |||
dst[dst_idx++] = src[src_idx + 6]; | |||
dst[dst_idx++] = src[src_idx + 7]; | |||
src_idx += 8 * SW; | |||
} | |||
src_idx = 8 * (ic * IH * IW + ((start_h + 1) * SH + fh2) * IW + | |||
fw2); | |||
for (int h = start_h + 1; h < end_h; ++h) { | |||
rep(w, OW) { | |||
dst[dst_idx++] = src[src_idx]; | |||
dst[dst_idx++] = src[src_idx + 1]; | |||
dst[dst_idx++] = src[src_idx + 2]; | |||
dst[dst_idx++] = src[src_idx + 3]; | |||
dst[dst_idx++] = src[src_idx + 4]; | |||
dst[dst_idx++] = src[src_idx + 5]; | |||
dst[dst_idx++] = src[src_idx + 6]; | |||
dst[dst_idx++] = src[src_idx + 7]; | |||
src_idx += 8 * SW; | |||
} | |||
} | |||
src_idx = 8 * (ic * IH * IW + (end_h * SH + fh2) * IW + fw2); | |||
rep(w, end_n_remain) { | |||
dst[dst_idx++] = src[src_idx]; | |||
dst[dst_idx++] = src[src_idx + 1]; | |||
dst[dst_idx++] = src[src_idx + 2]; | |||
dst[dst_idx++] = src[src_idx + 3]; | |||
dst[dst_idx++] = src[src_idx + 4]; | |||
dst[dst_idx++] = src[src_idx + 5]; | |||
dst[dst_idx++] = src[src_idx + 6]; | |||
dst[dst_idx++] = src[src_idx + 7]; | |||
src_idx += 8 * SW; | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
template <bool is_xcorr, typename dtype> | |||
void img2col_stride( | |||
const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, | |||
const int OW, const int IC, const int IH, const int IW, const int FH, | |||
@@ -68,6 +68,87 @@ void benchmark_impl( | |||
multi_thread_config.nr_thread); | |||
} | |||
} | |||
void benchmark_with_contrast( | |||
const std::vector<conv_bias::TestArg>& args, const std::string algo_name, | |||
std::vector<DType>& data_type, | |||
const std::vector<conv_bias::TestArg>& args_contrast, | |||
const std::string algo_name_contrast, std::vector<DType>& data_type_contrast, | |||
size_t RUNS, TaskExecutorConfig&& single_thread_config) { | |||
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); | |||
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | |||
auto benchmarker_contrast = Benchmarker<ConvBias>(single_thread_handle.get()); | |||
benchmarker.set_times(RUNS) | |||
.set_display(false) | |||
.set_dtype(0, data_type[0]) | |||
.set_dtype(1, data_type[1]) | |||
.set_dtype(2, data_type[2]) | |||
.set_dtype(4, data_type[3]) | |||
.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||
benchmarker_contrast.set_times(RUNS) | |||
.set_display(false) | |||
.set_dtype(0, data_type_contrast[0]) | |||
.set_dtype(1, data_type_contrast[1]) | |||
.set_dtype(2, data_type_contrast[2]) | |||
.set_dtype(4, data_type_contrast[3]) | |||
.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||
algo_name_contrast.c_str())); | |||
size_t arg_size = args.size(), arg_contrast_size = args_contrast.size(); | |||
megdnn_assert(arg_size == arg_contrast_size); | |||
rep(i, arg_size) { | |||
TensorLayout dst_layout, dst_layout_contrast; | |||
auto opr = single_thread_handle.get()->create_operator<ConvBias>(); | |||
auto&& arg = args[i]; | |||
opr->param() = arg.param; | |||
opr->deduce_layout( | |||
{arg.src, data_type[0]}, {arg.filter, data_type[1]}, | |||
{arg.bias, data_type[2]}, {}, dst_layout); | |||
float computation = (dst_layout.total_nr_elems() * arg.filter[1] * | |||
arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) / | |||
(1024 * 1024 * 1024) * 1e3; | |||
benchmarker.set_param(arg.param); | |||
auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS; | |||
auto&& arg_contrast = args_contrast[i]; | |||
opr->param() = arg_contrast.param; | |||
opr->deduce_layout( | |||
{arg_contrast.src, data_type_contrast[0]}, | |||
{arg_contrast.filter, data_type_contrast[1]}, | |||
{arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast); | |||
float computation_contrast = | |||
(dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] * | |||
arg_contrast.filter[2] * arg_contrast.filter[3] * | |||
arg_contrast.filter[4] * 2.0) / | |||
(1024 * 1024 * 1024) * 1e3; | |||
benchmarker_contrast.set_param(arg_contrast.param); | |||
auto used_contrast = benchmarker_contrast.exec( | |||
{arg_contrast.src, | |||
arg_contrast.filter, | |||
arg_contrast.bias, | |||
{}, | |||
{}}) / | |||
RUNS; | |||
printf("Bench case: \n"); | |||
printf("padding: %u, stride: %u, nonline mode: %u\n", arg.param.pad_h, | |||
arg.param.stride_h, arg.param.nonlineMode); | |||
printf("%s %s %s\n", arg.src.to_string().c_str(), | |||
arg.filter.to_string().c_str(), arg.bias.to_string().c_str()); | |||
printf("%s %s %s\n", arg_contrast.src.to_string().c_str(), | |||
arg_contrast.filter.to_string().c_str(), | |||
arg_contrast.bias.to_string().c_str()); | |||
printf("%s: %f gflops;\n%s: %f gflops\n" | |||
"spead up = %f\n", | |||
algo_name.c_str(), computation / used, algo_name_contrast.c_str(), | |||
computation_contrast / used_contrast, used_contrast / used); | |||
} | |||
} | |||
} // namespace | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -1591,6 +1672,91 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) { | |||
data_type); | |||
shapes_and_computation.clear(); | |||
} | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_NCHW88) { | |||
constexpr size_t RUNS = 50; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<conv_bias::TestArg> args_nchw88, args_nchw44; | |||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||
size_t group) { | |||
param::ConvBias param_nchw88, param_nchw44; | |||
param_nchw88.format = param::ConvBias::Format::NCHW88; | |||
param_nchw44.format = param::ConvBias::Format::NCHW44; | |||
for (size_t pad : {1, 2, 4}) { | |||
for (size_t stride : {1, 2, 3}) { | |||
for (auto nlmode : | |||
{NLMode::RELU, NLMode::IDENTITY, NLMode::SIGMOID, | |||
NLMode::H_SWISH}) { | |||
param_nchw88.nonlineMode = nlmode; | |||
param_nchw88.pad_h = pad; | |||
param_nchw88.pad_w = pad; | |||
param_nchw88.stride_h = stride; | |||
param_nchw88.stride_w = stride; | |||
param_nchw44.nonlineMode = nlmode; | |||
param_nchw44.pad_h = pad; | |||
param_nchw44.pad_w = pad; | |||
param_nchw44.stride_h = stride; | |||
param_nchw44.stride_w = stride; | |||
args_nchw88.emplace_back( | |||
param_nchw88, TensorShape{N, IC / 8, H, W, 8}, | |||
TensorShape{OC / 8, IC / group / 8, FS, FS, 8, 8}, | |||
TensorShape{1, OC / 8, 1, 1, 8}); | |||
args_nchw44.emplace_back( | |||
param_nchw44, TensorShape{N, IC / 4, H, W, 4}, | |||
TensorShape{OC / 4, IC / group / 4, FS, FS, 4, 4}, | |||
TensorShape{1, OC / 4, 1, 1, 4}); | |||
} | |||
} | |||
} | |||
}; | |||
std::vector<DType> data_type_fp16 = { | |||
dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; | |||
std::vector<DType> data_type_fp32 = { | |||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||
bench_case(1, 32, 32, 300, 300, 3, 1); | |||
bench_case(1, 32, 32, 400, 400, 3, 1); | |||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||
bench_case(1, 32, 64, 200, 200, 3, 1); | |||
bench_case(1, 32, 64, 128, 128, 3, 1); | |||
bench_case(1, 32, 64, 100, 100, 3, 1); | |||
bench_case(1, 32, 64, 80, 80, 3, 1); | |||
bench_case(1, 32, 128, 200, 200, 3, 1); | |||
bench_case(1, 32, 128, 128, 128, 3, 1); | |||
bench_case(1, 32, 128, 100, 100, 3, 1); | |||
bench_case(1, 32, 128, 80, 80, 3, 1); | |||
bench_case(1, 64, 32, 7, 7, 3, 1); | |||
bench_case(1, 64, 64, 7, 7, 3, 1); | |||
bench_case(1, 64, 128, 7, 7, 3, 1); | |||
bench_case(1, 64, 256, 7, 7, 3, 1); | |||
bench_case(1, 64, 512, 7, 7, 3, 1); | |||
bench_case(1, 64, 1024, 7, 7, 3, 1); | |||
bench_case(1, 64, 32, 14, 14, 3, 1); | |||
bench_case(1, 64, 64, 14, 14, 3, 1); | |||
bench_case(1, 64, 128, 14, 14, 3, 1); | |||
bench_case(1, 64, 256, 14, 14, 3, 1); | |||
bench_case(1, 64, 512, 14, 14, 3, 1); | |||
bench_case(1, 64, 1024, 14, 14, 3, 1); | |||
bench_case(1, 128, 128, 14, 14, 3, 1); | |||
bench_case(1, 128, 256, 14, 14, 3, 1); | |||
bench_case(1, 512, 512, 14, 14, 3, 1); | |||
bench_case(1, 256, 512, 14, 14, 3, 1); | |||
bench_case(1, 512, 1024, 14, 14, 3, 1); | |||
bench_case(1, 1024, 1024, 14, 14, 3, 1); | |||
std::string algo_name_nchw88 = "IM2COLMATMUL:AARCH64_F16_MK8_16X12X1:96"; | |||
std::string algo_name_nchw44 = "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1:96"; | |||
benchmark_with_contrast( | |||
args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44, | |||
algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); | |||
} | |||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | |||
BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { | |||
constexpr size_t RUNS = 50; | |||
@@ -362,6 +362,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { | |||
#endif | |||
#undef cb | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_MK8_FP16) { | |||
using namespace conv_bias; | |||
std::vector<conv_bias::TestArg> args = get_nchw88_conv_bias_args( | |||
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||
auto args1 = get_nchw88_conv_bias_args( | |||
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 2, 3); | |||
args.insert(args.begin(), args1.begin(), args1.begin()); | |||
args1 = get_nchw88_conv_bias_args( | |||
{2, 3, 4, 5, 6, 7, 9}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 3, 4); | |||
args.insert(args.begin(), args1.begin(), args1.begin()); | |||
NormalRNG rng(1); | |||
#define cb(name) \ | |||
checker_conv_bias_common( \ | |||
args, handle(), &rng, 0.03, dtype::Float16{}, dtype::Float16{}, \ | |||
dtype::Float16{}, dtype::Float16{}, name); | |||
#if MEGDNN_AARCH64 | |||
cb("IM2COLMATMUL:AARCH64_F16_MK8_16X12X1"); | |||
#endif | |||
#undef cb | |||
} | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
@@ -161,6 +161,24 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { | |||
run(); | |||
} | |||
TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) { | |||
using Mode = ElemwiseForward::Param::Mode; | |||
Checker<ElemwiseForward> checker(handle()); | |||
checker.set_epsilon(1e-3); | |||
checker.set_dtype(0, dtype::Float16()); | |||
checker.set_param(Mode::SIGMOID); | |||
for (size_t n : {1, 2, 3}) { | |||
for (size_t ic : {8, 16, 24, 32}) { | |||
for (size_t ih : {5, 10, 15, 20, 21, 37}) { | |||
for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) { | |||
checker.exec({{n, ic, ih, iw}, {}}); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | |||
using Mode = ElemwiseForward::Param::Mode; | |||
Checker<ElemwiseForward> checker(handle()); | |||
@@ -98,6 +98,9 @@ TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) { | |||
BENCHMARK_CASES_INT(shape, dtype::Int16()); | |||
BENCHMARK_CASES_INT(shape, dtype::Int8()); | |||
BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
BENCHMARK_CASES_FLOAT(shape, dtype::Float16()); | |||
#endif | |||
#undef BENCHMARK_CASES_INT | |||
#undef BENCHMARK_CASES_FLOAT | |||
#undef RUN | |||
@@ -1580,17 +1580,19 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||
std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | |||
std::vector<size_t> kernel_vec, | |||
std::vector<param::ConvBias::NonlineMode> nlmode_vec, | |||
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) { | |||
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride, int pad) { | |||
using namespace conv_bias; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<TestArg> args; | |||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, | |||
size_t stride, size_t group, NLMode nlmode, | |||
size_t stride, int pad, size_t group, NLMode nlmode, | |||
megdnn::BiasMode bias_mode) { | |||
constexpr int pack_c = 8; | |||
const size_t pad = kernel / 2; | |||
if (pad == -1) { | |||
pad = kernel / 2; | |||
} | |||
auto oc_per_group = oc / group; | |||
auto ic_per_group = ic / group; | |||
@@ -1651,8 +1653,8 @@ std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | |||
if (kernel < h || kernel < w) { | |||
continue; | |||
} | |||
pack(n, oc, ic, h, w, kernel, stride, group, | |||
nlmode, bias); | |||
pack(n, oc, ic, h, w, kernel, stride, pad, | |||
group, nlmode, bias); | |||
} | |||
} | |||
return args; | |||