@@ -407,6 +407,11 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern( | |||||
return kern_mk8_16x12x1; | return kern_mk8_16x12x1; | ||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( | |||||
AlgoF16MK8_16x12x1, megdnn_aarch64_matmul_kern, "AlogF16MK8_16x12x1Impl"_hash, | |||||
aarch64::matmul::hgemm_mk8_16x12, dt_float16, dt_float16, AlgoDataType::FLOAT16, | |||||
MK8); | |||||
#endif | #endif | ||||
#if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
@@ -93,7 +93,7 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8); | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||||
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); | MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); | ||||
}; | }; | ||||
@@ -9,8 +9,8 @@ | |||||
template <> | template <> | ||||
void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>( | ||||
const dt_float16* packedA, const dt_float16* packedB, int K, | |||||
dt_float16* out, int LDC, bool is_first_k) { | |||||
const dt_float16* packedA, const dt_float16* packedB, int K, dt_float16* out, | |||||
int LDC, bool is_first_k) { | |||||
#define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | #define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" | ||||
#define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | #define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" | ||||
// clang-format off | // clang-format off | ||||
@@ -26,6 +26,8 @@ static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||||
format = param::MatrixMul::Format::MK4; | format = param::MatrixMul::Format::MK4; | ||||
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | } else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | ||||
format = param::MatrixMul::Format::MK4_DOT; | format = param::MatrixMul::Format::MK4_DOT; | ||||
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
format = param::MatrixMul::Format::MK8; | |||||
} | } | ||||
size_t M = oc_tile_size; | size_t M = oc_tile_size; | ||||
size_t N = ohw_tile_size; | size_t N = ohw_tile_size; | ||||
@@ -329,9 +331,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
if (format != param::ConvBias::Format::NCHW && | if (format != param::ConvBias::Format::NCHW && | ||||
format != param::ConvBias::Format::NCHW44 && | format != param::ConvBias::Format::NCHW44 && | ||||
format != param::ConvBias::Format::NCHW44_DOT) { | |||||
format != param::ConvBias::Format::NCHW44_DOT && | |||||
format != param::ConvBias::Format::NCHW88) { | |||||
return false; | return false; | ||||
} | } | ||||
if (format == param::ConvBias::Format::NCHW88) { | |||||
if (matmul_desc.packmode != Pack_Mode::DEFAULT) { | |||||
return false; | |||||
} | |||||
} | |||||
if (format == param::ConvBias::Format::NCHW44 || | if (format == param::ConvBias::Format::NCHW44 || | ||||
format == param::ConvBias::Format::NCHW44_DOT) { | format == param::ConvBias::Format::NCHW44_DOT) { | ||||
//! current NCHW44 im2col only support DEFAULT mode matmul | //! current NCHW44 im2col only support DEFAULT mode matmul | ||||
@@ -248,8 +248,18 @@ public: | |||||
break; | break; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
case StrategyType::FLOAT_FP16: | case StrategyType::FLOAT_FP16: | ||||
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"DefaultStrategyType::FLOAT_FP16"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW88) { | |||||
cb1(NCHW88, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, | |||||
"DefaultStrategyTypeNCHW88::FLOAT_FP16"_hash); | |||||
} else { | |||||
megdnn_throw(ssprintf( | |||||
"Current only support layout NCHW/NCHW88 for im2col algo " | |||||
"of float 16, but got %d\n", | |||||
uint32_t(format))); | |||||
} | |||||
break; | break; | ||||
#endif | #endif | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
@@ -348,6 +348,32 @@ template < | |||||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | typename op_dtype, megdnn::PostprocessMode postprocess_mode> | ||||
class Strategy< | class Strategy< | ||||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | ||||
PackMode::DEFAULT, FormatMode::NCHW88> | |||||
: public Strategy< | |||||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
postprocess_mode, PackMode::DEFAULT> { | |||||
public: | |||||
constexpr static size_t BUNDLE_PADDING_INDEX = 0; | |||||
constexpr static size_t BUNDLE_PACKA_INDEX = 1; | |||||
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; | |||||
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; | |||||
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; | |||||
Strategy() = default; | |||||
void exec_im2col( | |||||
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | |||||
template < | |||||
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||||
class Strategy< | |||||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||||
PackMode::NO_PACK> | PackMode::NO_PACK> | ||||
: public StrategyBridge< | : public StrategyBridge< | ||||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | ||||
@@ -0,0 +1,98 @@ | |||||
#include "src/fallback/conv_bias/im2col/strategy_base.h" | |||||
#include "src/fallback/convolution/img2col_helper.h" | |||||
#if MEGDNN_X86 | |||||
#include "src/x86/conv_bias/postprocess_helper.h" | |||||
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||||
#include "src/arm_common/conv_bias/postprocess_helper.h" | |||||
#else | |||||
#include "src/common/postprocess_helper.h" | |||||
#endif | |||||
using namespace megdnn; | |||||
#if MEGDNN_X86 | |||||
using namespace x86; | |||||
#endif | |||||
namespace megdnn { | |||||
template < | |||||
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, | |||||
typename op_dtype, megdnn::PostprocessMode postprocess_mode> | |||||
void Strategy< | |||||
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, | |||||
PackMode::DEFAULT, FormatMode::NCHW88>:: | |||||
exec_im2col( | |||||
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, | |||||
const StrategyParam& sparam, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernParam matmul_param, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
size_t sh = param.filter_meta.stride[0]; | |||||
size_t sw = param.filter_meta.stride[1]; | |||||
size_t ow = param.osz[1]; | |||||
size_t ic = param.filter_meta.icpg; | |||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||||
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; | |||||
size_t fh = param.filter_meta.spatial[0]; | |||||
size_t fw = param.filter_meta.spatial[1]; | |||||
bool is_xcoor = !param.filter_meta.should_flip; | |||||
constexpr static size_t pack_size = 8; | |||||
size_t input_offset = | |||||
ic * ih * iw * | |||||
(sparam.group_id + param.filter_meta.group * sparam.batch_id) * | |||||
sizeof(src_ctype); | |||||
src_ctype* src = reinterpret_cast<src_ctype*>( | |||||
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) + | |||||
input_offset); | |||||
bool is_phpwzero = | |||||
(param.filter_meta.padding[0] == 0 && param.filter_meta.padding[1] == 0); | |||||
if (is_phpwzero) { | |||||
src = const_cast<src_ctype*>( | |||||
param.src<src_ctype>(sparam.batch_id, sparam.group_id)); | |||||
} | |||||
src_ctype* im2col_dst = | |||||
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); | |||||
if (sh == 1 && sw == 1) { | |||||
if (is_xcoor) { | |||||
img2col_nchw8<true>( | |||||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
} else { | |||||
img2col_nchw8<false>( | |||||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, | |||||
sparam.output_block_size); | |||||
} | |||||
} else { | |||||
if (is_xcoor) { | |||||
img2col_stride_nchw8<true>( | |||||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} else { | |||||
img2col_stride_nchw8<false>( | |||||
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, | |||||
sparam.ohw_cur_index, sparam.output_block_size); | |||||
} | |||||
} | |||||
matmul_param.M = sparam.output_block_oc_size; | |||||
matmul_param.N = sparam.output_block_size; | |||||
matmul_param.LDB = pack_size * sparam.output_block_size; | |||||
matmul_param.LDC = pack_size * sparam.output_block_size; | |||||
matmul_param.B_ptr = im2col_dst; | |||||
src_ctype* b_panel = | |||||
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)); | |||||
matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); | |||||
} | |||||
#define INSTANTIAL_CLASS( \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, _postprocess_mode) \ | |||||
template class Strategy< \ | |||||
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, \ | |||||
_postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW88>; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
INSTANTIAL_CLASS( | |||||
dt_float16, dt_float16, dt_float16, __fp16, __fp16, | |||||
megdnn::PostprocessMode::FLOAT); | |||||
#endif | |||||
#undef INSTANTIAL_CLASS | |||||
} // namespace megdnn |
@@ -348,6 +348,441 @@ void img2col_nchw4( | |||||
} | } | ||||
template <bool is_xcorr, typename dtype> | template <bool is_xcorr, typename dtype> | ||||
void img2col_nchw8( | |||||
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||||
const int IH, const int IW, const int FH, const int FW, const int cur_index, | |||||
const int block_size) { | |||||
int start_h = cur_index / OW; | |||||
int cur_n_remain = cur_index % OW; | |||||
int end_h = (cur_index + block_size) / OW; | |||||
int end_n_remain = (cur_index + block_size) % OW; | |||||
bool same_line = (start_h == end_h); | |||||
int IC_div_8 = IC / 8; | |||||
if (sizeof(dtype) == 2) { | |||||
if (same_line) { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
//! TODO: Substitute GI for arm intrinsic when GI supports FP16 | |||||
//! data type. | |||||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + | |||||
cur_n_remain + fw2); | |||||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8; | |||||
} | |||||
#else | |||||
int src_idx = 2 * (ic * IH * IW + (start_h + fh2) * IW + | |||||
cur_n_remain + fw2); | |||||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
for (int w = cur_n_remain; w < end_n_remain; w++) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2; | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||||
cur_n_remain); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
int _src_idx = src_idx; | |||||
rep(w, OW) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
_src_idx)); | |||||
dst_idx += 8; | |||||
_src_idx += 8; | |||||
} | |||||
src_idx += IW * 8; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8; | |||||
} | |||||
#else | |||||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + | |||||
cur_n_remain); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2; | |||||
} | |||||
src_idx = 2 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
int _src_idx = src_idx; | |||||
rep(w, OW) { | |||||
u64_dst[dst_idx] = u64_src[_src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||||
dst_idx += 2; | |||||
_src_idx += 2; | |||||
} | |||||
src_idx += IW * 2; | |||||
} | |||||
src_idx = 2 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2; | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
if (same_line) { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||||
cur_n_remain); | |||||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + | |||||
cur_n_remain); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (start_h + 1 + fh2) * IW + fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
rep(w, OW) { | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
} | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (end_h + fh2) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
dst[dst_idx++] = src[src_idx++]; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <bool is_xcorr, typename dtype> | |||||
void img2col_stride_nchw8( | |||||
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, | |||||
const int IH, const int IW, const int FH, const int FW, const int SH, | |||||
const int SW, const int cur_index, const int block_size) { | |||||
int start_h = cur_index / OW; | |||||
int cur_n_remain = cur_index % OW; | |||||
int end_h = (cur_index + block_size) / OW; | |||||
int end_n_remain = (cur_index + block_size) % OW; | |||||
bool same_line = (start_h == end_h); | |||||
int IC_div_8 = IC / 8; | |||||
if (sizeof(dtype) == 2) { | |||||
if (same_line) { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
cur_n_remain * SW + fw2); | |||||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8 * SW; | |||||
} | |||||
#else | |||||
int src_idx = 2 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
cur_n_remain * SW + fw2); | |||||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
for (int w = cur_n_remain; w < end_n_remain; w++) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2 * SW; | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||||
fw2 + cur_n_remain * SW); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8 * SW; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||||
fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
int _src_idx = src_idx; | |||||
rep(w, OW) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
_src_idx)); | |||||
dst_idx += 8; | |||||
_src_idx += 8 * SW; | |||||
} | |||||
src_idx += IW * 8 * SH; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
vst1q_f16( | |||||
reinterpret_cast<__fp16*>(dst) + dst_idx, | |||||
vld1q_f16( | |||||
reinterpret_cast<const __fp16*>(src) + | |||||
src_idx)); | |||||
dst_idx += 8; | |||||
src_idx += 8 * SW; | |||||
} | |||||
#else | |||||
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src); | |||||
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst); | |||||
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h * SH) * IW + | |||||
fw2 + cur_n_remain * SW); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2 * SW; | |||||
} | |||||
src_idx = 2 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + | |||||
fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
int _src_idx = src_idx; | |||||
rep(w, OW) { | |||||
u64_dst[dst_idx] = u64_src[_src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; | |||||
dst_idx += 2; | |||||
_src_idx += 2 * SW; | |||||
} | |||||
src_idx += IW * 2 * SH; | |||||
} | |||||
src_idx = 2 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
u64_dst[dst_idx] = u64_src[src_idx]; | |||||
u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; | |||||
dst_idx += 2; | |||||
src_idx += 2 * SW; | |||||
} | |||||
#endif | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
if (same_line) { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
fw2 + cur_n_remain * SW); | |||||
for (int w = cur_n_remain; w < end_n_remain; ++w) { | |||||
dst[dst_idx++] = src[src_idx]; | |||||
dst[dst_idx++] = src[src_idx + 1]; | |||||
dst[dst_idx++] = src[src_idx + 2]; | |||||
dst[dst_idx++] = src[src_idx + 3]; | |||||
dst[dst_idx++] = src[src_idx + 4]; | |||||
dst[dst_idx++] = src[src_idx + 5]; | |||||
dst[dst_idx++] = src[src_idx + 6]; | |||||
dst[dst_idx++] = src[src_idx + 7]; | |||||
src_idx += 8 * SW; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
int dst_idx = 0; | |||||
rep(ic, IC_div_8) { | |||||
rep(fh, FH) { | |||||
rep(fw, FW) { | |||||
int fh2 = fh, fw2 = fw; | |||||
if (!is_xcorr) { | |||||
fh2 = FH - fh - 1; | |||||
fw2 = FW - fw - 1; | |||||
} | |||||
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + | |||||
fw2 + cur_n_remain * SW); | |||||
for (int w = cur_n_remain; w < OW; ++w) { | |||||
dst[dst_idx++] = src[src_idx]; | |||||
dst[dst_idx++] = src[src_idx + 1]; | |||||
dst[dst_idx++] = src[src_idx + 2]; | |||||
dst[dst_idx++] = src[src_idx + 3]; | |||||
dst[dst_idx++] = src[src_idx + 4]; | |||||
dst[dst_idx++] = src[src_idx + 5]; | |||||
dst[dst_idx++] = src[src_idx + 6]; | |||||
dst[dst_idx++] = src[src_idx + 7]; | |||||
src_idx += 8 * SW; | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + ((start_h + 1) * SH + fh2) * IW + | |||||
fw2); | |||||
for (int h = start_h + 1; h < end_h; ++h) { | |||||
rep(w, OW) { | |||||
dst[dst_idx++] = src[src_idx]; | |||||
dst[dst_idx++] = src[src_idx + 1]; | |||||
dst[dst_idx++] = src[src_idx + 2]; | |||||
dst[dst_idx++] = src[src_idx + 3]; | |||||
dst[dst_idx++] = src[src_idx + 4]; | |||||
dst[dst_idx++] = src[src_idx + 5]; | |||||
dst[dst_idx++] = src[src_idx + 6]; | |||||
dst[dst_idx++] = src[src_idx + 7]; | |||||
src_idx += 8 * SW; | |||||
} | |||||
} | |||||
src_idx = 8 * (ic * IH * IW + (end_h * SH + fh2) * IW + fw2); | |||||
rep(w, end_n_remain) { | |||||
dst[dst_idx++] = src[src_idx]; | |||||
dst[dst_idx++] = src[src_idx + 1]; | |||||
dst[dst_idx++] = src[src_idx + 2]; | |||||
dst[dst_idx++] = src[src_idx + 3]; | |||||
dst[dst_idx++] = src[src_idx + 4]; | |||||
dst[dst_idx++] = src[src_idx + 5]; | |||||
dst[dst_idx++] = src[src_idx + 6]; | |||||
dst[dst_idx++] = src[src_idx + 7]; | |||||
src_idx += 8 * SW; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
template <bool is_xcorr, typename dtype> | |||||
void img2col_stride( | void img2col_stride( | ||||
const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, | const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, | ||||
const int OW, const int IC, const int IH, const int IW, const int FH, | const int OW, const int IC, const int IH, const int IW, const int FH, | ||||
@@ -68,6 +68,87 @@ void benchmark_impl( | |||||
multi_thread_config.nr_thread); | multi_thread_config.nr_thread); | ||||
} | } | ||||
} | } | ||||
void benchmark_with_contrast( | |||||
const std::vector<conv_bias::TestArg>& args, const std::string algo_name, | |||||
std::vector<DType>& data_type, | |||||
const std::vector<conv_bias::TestArg>& args_contrast, | |||||
const std::string algo_name_contrast, std::vector<DType>& data_type_contrast, | |||||
size_t RUNS, TaskExecutorConfig&& single_thread_config) { | |||||
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); | |||||
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get()); | |||||
auto benchmarker_contrast = Benchmarker<ConvBias>(single_thread_handle.get()); | |||||
benchmarker.set_times(RUNS) | |||||
.set_display(false) | |||||
.set_dtype(0, data_type[0]) | |||||
.set_dtype(1, data_type[1]) | |||||
.set_dtype(2, data_type[2]) | |||||
.set_dtype(4, data_type[3]) | |||||
.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str())); | |||||
benchmarker_contrast.set_times(RUNS) | |||||
.set_display(false) | |||||
.set_dtype(0, data_type_contrast[0]) | |||||
.set_dtype(1, data_type_contrast[1]) | |||||
.set_dtype(2, data_type_contrast[2]) | |||||
.set_dtype(4, data_type_contrast[3]) | |||||
.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||||
algo_name_contrast.c_str())); | |||||
size_t arg_size = args.size(), arg_contrast_size = args_contrast.size(); | |||||
megdnn_assert(arg_size == arg_contrast_size); | |||||
rep(i, arg_size) { | |||||
TensorLayout dst_layout, dst_layout_contrast; | |||||
auto opr = single_thread_handle.get()->create_operator<ConvBias>(); | |||||
auto&& arg = args[i]; | |||||
opr->param() = arg.param; | |||||
opr->deduce_layout( | |||||
{arg.src, data_type[0]}, {arg.filter, data_type[1]}, | |||||
{arg.bias, data_type[2]}, {}, dst_layout); | |||||
float computation = (dst_layout.total_nr_elems() * arg.filter[1] * | |||||
arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
benchmarker.set_param(arg.param); | |||||
auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS; | |||||
auto&& arg_contrast = args_contrast[i]; | |||||
opr->param() = arg_contrast.param; | |||||
opr->deduce_layout( | |||||
{arg_contrast.src, data_type_contrast[0]}, | |||||
{arg_contrast.filter, data_type_contrast[1]}, | |||||
{arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast); | |||||
float computation_contrast = | |||||
(dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] * | |||||
arg_contrast.filter[2] * arg_contrast.filter[3] * | |||||
arg_contrast.filter[4] * 2.0) / | |||||
(1024 * 1024 * 1024) * 1e3; | |||||
benchmarker_contrast.set_param(arg_contrast.param); | |||||
auto used_contrast = benchmarker_contrast.exec( | |||||
{arg_contrast.src, | |||||
arg_contrast.filter, | |||||
arg_contrast.bias, | |||||
{}, | |||||
{}}) / | |||||
RUNS; | |||||
printf("Bench case: \n"); | |||||
printf("padding: %u, stride: %u, nonline mode: %u\n", arg.param.pad_h, | |||||
arg.param.stride_h, arg.param.nonlineMode); | |||||
printf("%s %s %s\n", arg.src.to_string().c_str(), | |||||
arg.filter.to_string().c_str(), arg.bias.to_string().c_str()); | |||||
printf("%s %s %s\n", arg_contrast.src.to_string().c_str(), | |||||
arg_contrast.filter.to_string().c_str(), | |||||
arg_contrast.bias.to_string().c_str()); | |||||
printf("%s: %f gflops;\n%s: %f gflops\n" | |||||
"spead up = %f\n", | |||||
algo_name.c_str(), computation / used, algo_name_contrast.c_str(), | |||||
computation_contrast / used_contrast, used_contrast / used); | |||||
} | |||||
} | |||||
} // namespace | } // namespace | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -1591,6 +1672,91 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) { | |||||
data_type); | data_type); | ||||
shapes_and_computation.clear(); | shapes_and_computation.clear(); | ||||
} | } | ||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_NCHW88) { | |||||
constexpr size_t RUNS = 50; | |||||
using NLMode = param::ConvBias::NonlineMode; | |||||
std::vector<conv_bias::TestArg> args_nchw88, args_nchw44; | |||||
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, | |||||
size_t group) { | |||||
param::ConvBias param_nchw88, param_nchw44; | |||||
param_nchw88.format = param::ConvBias::Format::NCHW88; | |||||
param_nchw44.format = param::ConvBias::Format::NCHW44; | |||||
for (size_t pad : {1, 2, 4}) { | |||||
for (size_t stride : {1, 2, 3}) { | |||||
for (auto nlmode : | |||||
{NLMode::RELU, NLMode::IDENTITY, NLMode::SIGMOID, | |||||
NLMode::H_SWISH}) { | |||||
param_nchw88.nonlineMode = nlmode; | |||||
param_nchw88.pad_h = pad; | |||||
param_nchw88.pad_w = pad; | |||||
param_nchw88.stride_h = stride; | |||||
param_nchw88.stride_w = stride; | |||||
param_nchw44.nonlineMode = nlmode; | |||||
param_nchw44.pad_h = pad; | |||||
param_nchw44.pad_w = pad; | |||||
param_nchw44.stride_h = stride; | |||||
param_nchw44.stride_w = stride; | |||||
args_nchw88.emplace_back( | |||||
param_nchw88, TensorShape{N, IC / 8, H, W, 8}, | |||||
TensorShape{OC / 8, IC / group / 8, FS, FS, 8, 8}, | |||||
TensorShape{1, OC / 8, 1, 1, 8}); | |||||
args_nchw44.emplace_back( | |||||
param_nchw44, TensorShape{N, IC / 4, H, W, 4}, | |||||
TensorShape{OC / 4, IC / group / 4, FS, FS, 4, 4}, | |||||
TensorShape{1, OC / 4, 1, 1, 4}); | |||||
} | |||||
} | |||||
} | |||||
}; | |||||
std::vector<DType> data_type_fp16 = { | |||||
dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; | |||||
std::vector<DType> data_type_fp32 = { | |||||
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; | |||||
bench_case(1, 32, 32, 300, 300, 3, 1); | |||||
bench_case(1, 32, 32, 400, 400, 3, 1); | |||||
bench_case(1, 32, 32, 100, 100, 3, 1); | |||||
bench_case(1, 32, 32, 80, 80, 3, 1); | |||||
bench_case(1, 32, 64, 200, 200, 3, 1); | |||||
bench_case(1, 32, 64, 128, 128, 3, 1); | |||||
bench_case(1, 32, 64, 100, 100, 3, 1); | |||||
bench_case(1, 32, 64, 80, 80, 3, 1); | |||||
bench_case(1, 32, 128, 200, 200, 3, 1); | |||||
bench_case(1, 32, 128, 128, 128, 3, 1); | |||||
bench_case(1, 32, 128, 100, 100, 3, 1); | |||||
bench_case(1, 32, 128, 80, 80, 3, 1); | |||||
bench_case(1, 64, 32, 7, 7, 3, 1); | |||||
bench_case(1, 64, 64, 7, 7, 3, 1); | |||||
bench_case(1, 64, 128, 7, 7, 3, 1); | |||||
bench_case(1, 64, 256, 7, 7, 3, 1); | |||||
bench_case(1, 64, 512, 7, 7, 3, 1); | |||||
bench_case(1, 64, 1024, 7, 7, 3, 1); | |||||
bench_case(1, 64, 32, 14, 14, 3, 1); | |||||
bench_case(1, 64, 64, 14, 14, 3, 1); | |||||
bench_case(1, 64, 128, 14, 14, 3, 1); | |||||
bench_case(1, 64, 256, 14, 14, 3, 1); | |||||
bench_case(1, 64, 512, 14, 14, 3, 1); | |||||
bench_case(1, 64, 1024, 14, 14, 3, 1); | |||||
bench_case(1, 128, 128, 14, 14, 3, 1); | |||||
bench_case(1, 128, 256, 14, 14, 3, 1); | |||||
bench_case(1, 512, 512, 14, 14, 3, 1); | |||||
bench_case(1, 256, 512, 14, 14, 3, 1); | |||||
bench_case(1, 512, 1024, 14, 14, 3, 1); | |||||
bench_case(1, 1024, 1024, 14, 14, 3, 1); | |||||
std::string algo_name_nchw88 = "IM2COLMATMUL:AARCH64_F16_MK8_16X12X1:96"; | |||||
std::string algo_name_nchw44 = "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1:96"; | |||||
benchmark_with_contrast( | |||||
args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44, | |||||
algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); | |||||
} | |||||
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, | ||||
BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { | BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { | ||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
@@ -362,6 +362,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { | |||||
#endif | #endif | ||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_MK8_FP16) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = get_nchw88_conv_bias_args( | |||||
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); | |||||
auto args1 = get_nchw88_conv_bias_args( | |||||
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 2, 3); | |||||
args.insert(args.begin(), args1.begin(), args1.begin()); | |||||
args1 = get_nchw88_conv_bias_args( | |||||
{2, 3, 4, 5, 6, 7, 9}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 3, 4); | |||||
args.insert(args.begin(), args1.begin(), args1.begin()); | |||||
NormalRNG rng(1); | |||||
#define cb(name) \ | |||||
checker_conv_bias_common( \ | |||||
args, handle(), &rng, 0.03, dtype::Float16{}, dtype::Float16{}, \ | |||||
dtype::Float16{}, dtype::Float16{}, name); | |||||
#if MEGDNN_AARCH64 | |||||
cb("IM2COLMATMUL:AARCH64_F16_MK8_16X12X1"); | |||||
#endif | |||||
#undef cb | |||||
} | |||||
#endif | #endif | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
@@ -161,6 +161,24 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { | |||||
run(); | run(); | ||||
} | } | ||||
TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) { | |||||
using Mode = ElemwiseForward::Param::Mode; | |||||
Checker<ElemwiseForward> checker(handle()); | |||||
checker.set_epsilon(1e-3); | |||||
checker.set_dtype(0, dtype::Float16()); | |||||
checker.set_param(Mode::SIGMOID); | |||||
for (size_t n : {1, 2, 3}) { | |||||
for (size_t ic : {8, 16, 24, 32}) { | |||||
for (size_t ih : {5, 10, 15, 20, 21, 37}) { | |||||
for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) { | |||||
checker.exec({{n, ic, ih, iw}, {}}); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { | ||||
using Mode = ElemwiseForward::Param::Mode; | using Mode = ElemwiseForward::Param::Mode; | ||||
Checker<ElemwiseForward> checker(handle()); | Checker<ElemwiseForward> checker(handle()); | ||||
@@ -98,6 +98,9 @@ TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) { | |||||
BENCHMARK_CASES_INT(shape, dtype::Int16()); | BENCHMARK_CASES_INT(shape, dtype::Int16()); | ||||
BENCHMARK_CASES_INT(shape, dtype::Int8()); | BENCHMARK_CASES_INT(shape, dtype::Int8()); | ||||
BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); | BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
BENCHMARK_CASES_FLOAT(shape, dtype::Float16()); | |||||
#endif | |||||
#undef BENCHMARK_CASES_INT | #undef BENCHMARK_CASES_INT | ||||
#undef BENCHMARK_CASES_FLOAT | #undef BENCHMARK_CASES_FLOAT | ||||
#undef RUN | #undef RUN | ||||
@@ -1580,17 +1580,19 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | ||||
std::vector<size_t> kernel_vec, | std::vector<size_t> kernel_vec, | ||||
std::vector<param::ConvBias::NonlineMode> nlmode_vec, | std::vector<param::ConvBias::NonlineMode> nlmode_vec, | ||||
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) { | |||||
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride, int pad) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
using NLMode = param::ConvBias::NonlineMode; | using NLMode = param::ConvBias::NonlineMode; | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, | auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, | ||||
size_t stride, size_t group, NLMode nlmode, | |||||
size_t stride, int pad, size_t group, NLMode nlmode, | |||||
megdnn::BiasMode bias_mode) { | megdnn::BiasMode bias_mode) { | ||||
constexpr int pack_c = 8; | constexpr int pack_c = 8; | ||||
const size_t pad = kernel / 2; | |||||
if (pad == -1) { | |||||
pad = kernel / 2; | |||||
} | |||||
auto oc_per_group = oc / group; | auto oc_per_group = oc / group; | ||||
auto ic_per_group = ic / group; | auto ic_per_group = ic / group; | ||||
@@ -1651,8 +1653,8 @@ std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args( | |||||
if (kernel < h || kernel < w) { | if (kernel < h || kernel < w) { | ||||
continue; | continue; | ||||
} | } | ||||
pack(n, oc, ic, h, w, kernel, stride, group, | |||||
nlmode, bias); | |||||
pack(n, oc, ic, h, w, kernel, stride, pad, | |||||
group, nlmode, bias); | |||||
} | } | ||||
} | } | ||||
return args; | return args; | ||||