@@ -40,7 +40,8 @@ size_t ConvBiasImpl::AlgoConv1x1::get_oc_tile_size_heuristic( | |||||
size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
if (OH * OW >= 56 * 56 || OC >= 64) | if (OH * OW >= 56 * 56 || OC >= 64) | ||||
return m_oc_block_size; | return m_oc_block_size; | ||||
return div_ceil(OC, param.nr_threads); | |||||
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); | |||||
return round_up<size_t>(oc_block_size_one_thread, 24); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoConv1x1::get_workspace( | size_t ConvBiasImpl::AlgoConv1x1::get_workspace( | ||||
@@ -180,8 +181,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { | ||||
//! only support nchw format | |||||
if (opr->param().format != param::ConvBias::Format::NCHW) | |||||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||||
opr->param().format != param::ConvBias::Format::NCHW44) | |||||
return false; | return false; | ||||
size_t FH = param.filter_meta.spatial[0], | size_t FH = param.filter_meta.spatial[0], | ||||
@@ -218,8 +219,12 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
MatrixMulImpl::KernSizeParam matmul_param = | MatrixMulImpl::KernSizeParam matmul_param = | ||||
get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param)); | ||||
bool matmulusable = m_matmul_algo->usable(matmul_param); | |||||
return matmulusable && | |||||
if(opr->param().format == param::ConvBias::Format::NCHW44) | |||||
matmul_param.format = param::MatrixMul::Format::MK4; | |||||
bool matmul_usable = m_matmul_algo->usable(matmul_param); | |||||
return matmul_usable && | |||||
(param.filter_meta.dilation[0] == | (param.filter_meta.dilation[0] == | ||||
param.filter_meta.dilation[1] && | param.filter_meta.dilation[1] && | ||||
param.filter_meta.dilation[0] == 1) && | param.filter_meta.dilation[0] == 1) && | ||||
@@ -71,33 +71,32 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( | |||||
const ConvBiasImpl::NCBKernSizeParam& param, | const ConvBiasImpl::NCBKernSizeParam& param, | ||||
MatrixMulImpl::AlgoBase::PackMode pack_mode, | MatrixMulImpl::AlgoBase::PackMode pack_mode, | ||||
param::ConvBias::Format format) { | param::ConvBias::Format format) { | ||||
MEGDNN_MARK_USED_VAR(format); | |||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, _packmode>>(); \ | |||||
} \ | |||||
} \ | |||||
size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4; | |||||
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ | |||||
_postprocess_mode, _packmode>>(pack_size); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, \ | |||||
_postprocess_mode, _packmode>>(); \ | |||||
} \ | |||||
} \ | |||||
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ | |||||
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ | |||||
midout_iv(_midout_tag)) { \ | |||||
if (param.filter_type.enumv() == param.src_type.enumv() && \ | |||||
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ | |||||
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ | |||||
return std::make_unique< \ | |||||
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ | |||||
DTypeTrait<_i_bias_type>::ctype, \ | |||||
DTypeTrait<_i_dst_type>::ctype, \ | |||||
_postprocess_mode, _packmode>>(pack_size); \ | |||||
} \ | |||||
} \ | |||||
MIDOUT_END() | MIDOUT_END() | ||||
switch (pack_mode) { | switch (pack_mode) { | ||||
@@ -88,6 +88,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||||
megdnn::PostprocessMode postprocess_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode> | megdnn::PostprocessMode postprocess_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode> | ||||
class Conv1x1Strategy : public Conv1x1StrategyBase { | class Conv1x1Strategy : public Conv1x1StrategyBase { | ||||
public: | public: | ||||
explicit Conv1x1Strategy(size_t pack_size = 1) : m_pack_size(pack_size) {} | |||||
void packA(WorkspaceBundle& whole_bundle, | void packA(WorkspaceBundle& whole_bundle, | ||||
WorkspaceBundle& matmul_bundle, | WorkspaceBundle& matmul_bundle, | ||||
size_t oc_tile_size, | size_t oc_tile_size, | ||||
@@ -133,6 +135,9 @@ public: | |||||
src_ctype* a_panel = reinterpret_cast<src_ctype*>( | src_ctype* a_panel = reinterpret_cast<src_ctype*>( | ||||
reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | reinterpret_cast<int8_t*>(whole_bundle.get(0)) + | ||||
bytes_offset_of_a_panel); | bytes_offset_of_a_panel); | ||||
matmul_kern_param.LDA *= m_pack_size; | |||||
matmul_kern_param.A_ptr = const_cast<src_ctype*>( | matmul_kern_param.A_ptr = const_cast<src_ctype*>( | ||||
ncb_param.filter<src_ctype>(group_id) + | ncb_param.filter<src_ctype>(group_id) + | ||||
numbers_offset_of_filter); | numbers_offset_of_filter); | ||||
@@ -165,6 +170,8 @@ public: | |||||
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) = | ||||
get_matmul_kern_param(param, OH * OW, OC); | get_matmul_kern_param(param, OH * OW, OC); | ||||
matmul_kern_param.LDB *= m_pack_size; | |||||
rep(batch, BATCH) { | rep(batch, BATCH) { | ||||
rep(g, GROUP) { | rep(g, GROUP) { | ||||
if (SH == 2 && SW == 2) | if (SH == 2 && SW == 2) | ||||
@@ -273,6 +280,8 @@ public: | |||||
matmul_kern_param.C_ptr = matmul_dst; | matmul_kern_param.C_ptr = matmul_dst; | ||||
matmul_kern_param.LDC *= m_pack_size; | |||||
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | ||||
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | auto matmul_kern = matmul_algo->get_kern(matmul_kern_param); | ||||
matmul_kern(matmul_kern_param); | matmul_kern(matmul_kern_param); | ||||
@@ -291,11 +300,14 @@ public: | |||||
else | else | ||||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | ||||
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | ||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | ||||
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode, | ||||
param.nonlineMode, param.bias_type, param.dst_type, 1_z, | param.nonlineMode, param.bias_type, param.dst_type, 1_z, | ||||
oc_end - oc_start, OH, OW); | |||||
(oc_end - oc_start) / m_pack_size, OH, OW, m_pack_size); | |||||
} | } | ||||
private: | |||||
size_t m_pack_size = 1; | |||||
}; | }; | ||||
class Conv1x1Factory { | class Conv1x1Factory { | ||||