@@ -40,7 +40,8 @@ size_t ConvBiasImpl::AlgoConv1x1::get_oc_tile_size_heuristic( | |||
size_t OC = param.filter_meta.ocpg; | |||
if (OH * OW >= 56 * 56 || OC >= 64) | |||
return m_oc_block_size; | |||
return div_ceil(OC, param.nr_threads); | |||
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); | |||
return round_up<size_t>(oc_block_size_one_thread, 24); | |||
} | |||
size_t ConvBiasImpl::AlgoConv1x1::get_workspace( | |||
@@ -180,8 +181,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | |||
//! only support nchw format | |||
if (opr->param().format != param::ConvBias::Format::NCHW) | |||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||
opr->param().format != param::ConvBias::Format::NCHW44) | |||
return false; | |||
size_t FH = param.filter_meta.spatial[0], | |||
@@ -218,8 +219,12 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
MatrixMulImpl::KernSizeParam matmul_param = | |||
get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | |||
bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
return matmulusable && | |||
if(opr->param().format == param::ConvBias::Format::NCHW44) | |||
matmul_param.format = param::MatrixMul::Format::MK4; | |||
bool matmul_usable = m_matmul_algo->usable(matmul_param); | |||
return matmul_usable && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
@@ -71,33 +71,32 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||
const ConvBiasImpl::NCBKernSizeParam& param, | |||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | |||
param::ConvBias::Format format) { | |||
MEGDNN_MARK_USED_VAR(format); | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, _packmode>>(); \ | |||
} \ | |||
} \ | |||
size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||
_postprocess_mode, _packmode>>(pack_size); \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, \ | |||
_postprocess_mode, _packmode>>(); \ | |||
} \ | |||
} \ | |||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||
midout_iv(_midout_tag)) { \ | |||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||
return std::make_unique< \ | |||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||
DTypeTrait<_i_bias_type>::ctype, \ | |||
DTypeTrait<_i_dst_type>::ctype, \ | |||
_postprocess_mode, _packmode>>(pack_size); \ | |||
} \ | |||
} \ | |||
MIDOUT_END() | |||
switch (pack_mode) { | |||
@@ -88,6 +88,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
megdnn::PostprocessMode postprocess_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode> | |||
class Conv1x1Strategy : public Conv1x1StrategyBase { | |||
public: | |||
explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {} | |||
void packA(WorkspaceBundle& whole_bundle, | |||
WorkspaceBundle& matmul_bundle, | |||
size_t oc_tile_size, | |||
@@ -133,6 +135,9 @@ public: | |||
src_ctype* a_panel = reinterpret_cast<src_ctype*>( | |||
reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | |||
bytes_offset_of_a_panel); | |||
matmul_kern_param.LDA *= m_pack_size; | |||
matmul_kern_param.A_ptr = const_cast<src_ctype*>( | |||
ncb_param.filter<src_ctype>(group_id) + | |||
numbers_offset_of_filter); | |||
@@ -165,6 +170,8 @@ public: | |||
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | |||
get_matmul_kern_param(param, OH * OW, OC); | |||
matmul_kern_param.LDB *= m_pack_size; | |||
rep(batch, BATCH) { | |||
rep(g, GROUP) { | |||
if (SH == 2 && SW == 2) | |||
@@ -273,6 +280,8 @@ public: | |||
matmul_kern_param.C_ptr = matmul_dst; | |||
matmul_kern_param.LDC *= m_pack_size; | |||
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | |||
matmul_kern(matmul_kern_param); | |||
@@ -291,11 +300,14 @@ public: | |||
else | |||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | |||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | |||
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | |||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | |||
oc_end - oc_start, OH, OW); | |||
(oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size); | |||
} | |||
private: | |||
size_t m_pack_size = 1; | |||
}; | |||
class Conv1x1Factory { | |||