|
|
@@ -2050,23 +2050,27 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
//! Result of the transform test: whether the conv can be transformed to nchwxx |
|
|
|
//! (trans_type), the filter relayout mode (relayout_mod) and the conv format (conv_format) |
|
|
|
using RelayoutMode = RelayoutPlaceholder::LayoutType; |
|
|
|
using TestTransResult = std::pair<TransType, RelayoutMode>; |
|
|
|
megdnn::param::ConvolutionV0::Format conv_dot_format = |
|
|
|
megdnn::param::ConvBias::Format::NCHW44_DOT; |
|
|
|
//! Result of testing whether a convolution can be converted by this pass.
//! Filled in by test_trans_nchw44_dot and read by replace_conv_opr and
//! replace_conv_bias_opr.
struct TestTransResult { |
|
|
|
//! Which transformation applies; TRANS_NONE means the conv can not be
//! transformed to nchwxx.
TransType trans_type; |
|
|
|
//! Layout type handed to RelayoutPlaceholder::make when relayouting the
//! filter.
RelayoutMode relayout_mod; |
|
|
|
//! Format assigned to the new operator's param (NCHW44_DOT or NCHW44 in the
//! visible code).
megdnn::param::ConvolutionV0::Format conv_format; |
|
|
|
}; |
|
|
|
constexpr size_t pack_c_size = 4_z; |
|
|
|
auto test_trans_nchw44_dot = |
|
|
|
[](const megdnn::param::Convolution::Sparse conv_mode, |
|
|
|
const VarNode* filter) -> TestTransResult { |
|
|
|
TestTransResult ret{TransType::TRANS_NONE, {}}; |
|
|
|
TestTransResult ret{TransType::TRANS_NONE, {}, {}}; |
|
|
|
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { |
|
|
|
size_t IC = filter->shape()[1]; |
|
|
|
size_t OC = filter->shape()[0]; |
|
|
|
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { |
|
|
|
ret.first = TransType::TRANS_PURE_NCHWXX; |
|
|
|
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; |
|
|
|
ret.trans_type = TransType::TRANS_PURE_NCHWXX; |
|
|
|
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; |
|
|
|
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; |
|
|
|
} else if (IC < pack_c_size && OC % pack_c_size == 0) { |
|
|
|
ret.first = TransType::TRANS_HYBIRD_NCHWXX; |
|
|
|
ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; |
|
|
|
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; |
|
|
|
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; |
|
|
|
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; |
|
|
|
} |
|
|
|
} else { |
|
|
|
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); |
|
|
@@ -2074,15 +2078,18 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
size_t ocpg = filter->shape()[1]; |
|
|
|
size_t icpg = filter->shape()[2]; |
|
|
|
if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) { |
|
|
|
ret.first = TransType::TRANS_NONE; |
|
|
|
ret.trans_type = TransType::TRANS_PURE_NCHWXX; |
|
|
|
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN; |
|
|
|
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; |
|
|
|
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { |
|
|
|
ret.first = TransType::TRANS_PURE_NCHWXX; |
|
|
|
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; |
|
|
|
ret.trans_type = TransType::TRANS_PURE_NCHWXX; |
|
|
|
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; |
|
|
|
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; |
|
|
|
} |
|
|
|
} |
|
|
|
return ret; |
|
|
|
}; |
|
|
|
auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format]( |
|
|
|
auto replace_conv_opr = [test_trans_nchw44_dot]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
@@ -2094,7 +2101,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
auto is_trans = |
|
|
|
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]); |
|
|
|
//! can not trans to nchwxx |
|
|
|
if (is_trans.first == TransType::TRANS_NONE) { |
|
|
|
if (is_trans.trans_type == TransType::TRANS_NONE) { |
|
|
|
mgb_assert(new_inp[1]->shape().ndim == 4 || |
|
|
|
new_inp[1]->shape().ndim == 5, |
|
|
|
"The origin filter is not NCHW mode"); |
|
|
@@ -2108,14 +2115,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, |
|
|
|
opr->config()); |
|
|
|
return new_opr; |
|
|
|
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { |
|
|
|
} else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) { |
|
|
|
//! filter trans to nchwxx mode |
|
|
|
mgb_assert(new_inp[1]->shape().ndim == 4 || |
|
|
|
new_inp[1]->shape().ndim == 5, |
|
|
|
"The origin filter is not NCHW mode"); |
|
|
|
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; |
|
|
|
auto new_filter = |
|
|
|
RelayoutPlaceholder::make(new_inp[1], is_trans.second); |
|
|
|
auto new_filter = RelayoutPlaceholder::make(new_inp[1], |
|
|
|
is_trans.relayout_mod); |
|
|
|
conv_filter = new_filter.node(); |
|
|
|
//! src trans to nchwxx mode |
|
|
|
if (new_inp[0]->shape().ndim != 5) { |
|
|
@@ -2125,7 +2132,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
conv_src = new_src.node(); |
|
|
|
} |
|
|
|
auto new_param = conv_opr.param(); |
|
|
|
new_param.format = conv_dot_format; |
|
|
|
new_param.format = is_trans.conv_format; |
|
|
|
mgb_assert(conv_src->shape().ndim == 5 && |
|
|
|
conv_filter->shape().ndim >= 6, |
|
|
|
"The conv src dim is not trans to nchwxx"); |
|
|
@@ -2137,16 +2144,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
"The conv dst dim is not trans to nchwxx"); |
|
|
|
return new_opr; |
|
|
|
} else { |
|
|
|
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); |
|
|
|
mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX); |
|
|
|
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; |
|
|
|
auto new_filter = |
|
|
|
RelayoutPlaceholder::make(new_inp[1], is_trans.second); |
|
|
|
auto new_filter = RelayoutPlaceholder::make(new_inp[1], |
|
|
|
is_trans.relayout_mod); |
|
|
|
conv_filter = new_filter.node(); |
|
|
|
mgb_assert(conv_src->shape().ndim == 4 && |
|
|
|
conv_filter->shape().ndim == 5, |
|
|
|
"The src and filter is OK"); |
|
|
|
auto new_param = conv_opr.param(); |
|
|
|
new_param.format = conv_dot_format; |
|
|
|
new_param.format = is_trans.conv_format; |
|
|
|
auto new_conv_opr = opr::Convolution::make( |
|
|
|
conv_src, conv_filter, new_param, |
|
|
|
conv_opr.execution_policy(), conv_opr.config()); |
|
|
@@ -2157,7 +2164,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format]( |
|
|
|
auto replace_conv_bias_opr = [test_trans_nchw44_dot]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
@@ -2168,7 +2175,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
auto is_trans = |
|
|
|
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]); |
|
|
|
//! can not trans to nchwxx |
|
|
|
if (is_trans.first == TransType::TRANS_NONE) { |
|
|
|
if (is_trans.trans_type == TransType::TRANS_NONE) { |
|
|
|
mgb_assert(new_inp[1]->shape().ndim == 4 || |
|
|
|
new_inp[1]->shape().ndim == 5, |
|
|
|
"The origin filter is not NCHW mode"); |
|
|
@@ -2188,15 +2195,15 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, |
|
|
|
opr->config()); |
|
|
|
return new_opr; |
|
|
|
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { |
|
|
|
} else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) { |
|
|
|
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], |
|
|
|
*conv_bias_bias = new_inp[2]; |
|
|
|
//! filter trans to nchwxx mode |
|
|
|
mgb_assert(new_inp[1]->shape().ndim == 4 || |
|
|
|
new_inp[1]->shape().ndim == 5, |
|
|
|
"The origin filter is not NCHW mode"); |
|
|
|
auto new_filter = |
|
|
|
RelayoutPlaceholder::make(new_inp[1], is_trans.second); |
|
|
|
auto new_filter = RelayoutPlaceholder::make(new_inp[1], |
|
|
|
is_trans.relayout_mod); |
|
|
|
conv_bias_filter = new_filter.node(); |
|
|
|
//! src trans to nchwxx mode |
|
|
|
if (new_inp[0]->shape().ndim != 5) { |
|
|
@@ -2213,7 +2220,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
} |
|
|
|
|
|
|
|
auto new_param = conv_bias_opr.param(); |
|
|
|
new_param.format = conv_dot_format; |
|
|
|
new_param.format = is_trans.conv_format; |
|
|
|
mgb_assert(conv_bias_src->shape().ndim == 5 && |
|
|
|
conv_bias_filter->shape().ndim >= 6, |
|
|
|
"The conv_bias src dim is not trans to nchwxx"); |
|
|
@@ -2225,11 +2232,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
"The conv_bias dst dim is not trans to nchwxx"); |
|
|
|
return new_opr; |
|
|
|
} else { |
|
|
|
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); |
|
|
|
mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX); |
|
|
|
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], |
|
|
|
*conv_bias_bias = new_inp[2]; |
|
|
|
auto new_filter = |
|
|
|
RelayoutPlaceholder::make(new_inp[1], is_trans.second); |
|
|
|
auto new_filter = RelayoutPlaceholder::make(new_inp[1], |
|
|
|
is_trans.relayout_mod); |
|
|
|
conv_bias_filter = new_filter.node(); |
|
|
|
//! bias trans to nchwxx mode, bias may be scale |
|
|
|
if (new_inp[2]->shape().ndim == 4) { |
|
|
@@ -2242,7 +2249,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { |
|
|
|
mgb_assert((conv_bias_bias->shape().ndim == 5) || |
|
|
|
conv_bias_bias->shape().is_scalar()); |
|
|
|
auto new_param = conv_bias_opr.param(); |
|
|
|
new_param.format = conv_dot_format; |
|
|
|
new_param.format = is_trans.conv_format; |
|
|
|
auto new_conv_bias_opr = opr::ConvBias::make( |
|
|
|
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, |
|
|
|
conv_bias_opr.execution_policy(), conv_bias_opr.config()); |
|
|
|