@@ -1862,7 +1862,8 @@ static inline bool nchw_nchwxx_valid(
     auto& src_node = new_inp[0];
     auto& filter_node = new_inp[1];
     auto dst_node = opr.output(0);
-    if (filter_node->shape().ndim != 4) {
+    //! already transformed or have fuse Z
+    if (filter_node->shape().ndim != 4 || new_inp.size() == 4) {
         return false;
     }
     megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
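
Note: the guard added above bails out when the filter is no longer 4-D (the
operator was already transformed) or when a fourth input is present (src,
filter, bias, z — i.e. a fused z term). A minimal standalone sketch of the
same decision, using plain STL stand-ins rather than the MegBrain VarNode API:

    #include <cstddef>
    #include <vector>

    struct Var {
        std::size_t ndim;  // stand-in for VarNode::shape().ndim
    };

    // Mirrors the early-out in nchw_nchwxx_valid: reject filters that are
    // not 4-D and operators that carry a fused z input (4 inputs total).
    bool early_reject(const std::vector<Var*>& new_inp) {
        return new_inp[1]->ndim != 4 || new_inp.size() == 4;
    }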
|
|
@@ -1884,7 +1885,8 @@ static inline bool nchw_nchwxx_valid(
     megdnn::ConvBiasForward::BiasMode bias_mode =
             megdnn::ConvBiasForward::BiasMode::NO_BIAS;
-    if (std::is_same<OprType, opr::ConvBiasForward>::value) {
+    if (std::is_same<OprType, opr::ConvBiasForward>::value &&
+        new_inp.size() > 2) {
         TensorShape bias_shape = new_inp[2]->shape();
         if (bias_shape.ndim == 5) {
             bias_shape = nchwxx_shape_2_nchw_shape(bias_shape);
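
Note: nchwxx_shape_2_nchw_shape canonicalizes an already-packed 5-D bias back
to NCHW before the shape checks run; the new `new_inp.size() > 2` condition
simply skips all of this when no bias input exists. A hypothetical standalone
model of that shape fold (the real helper works on TensorShape; only the
arithmetic is shown):

    #include <array>
    #include <cstddef>

    // (N, C/pack, H, W, pack) -> (N, C, H, W); e.g. for NCHW44 a packed
    // per-channel bias (1, 8, 1, 1, 4) folds back to (1, 32, 1, 1).
    std::array<std::size_t, 4> packed_to_nchw(
            const std::array<std::size_t, 5>& s) {
        return {s[0], s[1] * s[4], s[2], s[3]};
    }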
|
|
@@ -2067,6 +2069,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                                   pack_c_size](OperatorNodeBase* opr,
                                                const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
|
|
@@ -2092,7 +2096,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                 temp_inp[0] = new_src.node();
             }
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias =
                         RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode);
                 temp_inp[2] = new_bias.node();
|
|
@@ -2102,7 +2106,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             return new_opr;
         } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
|
|
@@ -2117,21 +2121,34 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                                                          src_to_nchwxx_mode);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
 
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
|
|
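Note: since conv_bias_bias may now remain nullptr, the construction above
dispatches between the three-input and two-input opr::ConvBias::make
overloads (both visible in this patch) instead of passing a null VarNode.
The shape of that control flow, modeled standalone with std::optional and
stand-in types rather than the MegBrain SymbolVar API:

    #include <optional>

    struct Sym { int id; };

    Sym make_conv_bias(Sym s, Sym f, Sym b) { return {s.id + f.id + b.id}; }
    Sym make_conv_bias(Sym s, Sym f) { return {s.id + f.id}; }

    // Pick the overload by bias presence, as the patch does with
    // `if (conv_bias_bias) { ... } else { ... }`.
    Sym build(Sym src, Sym filter, std::optional<Sym> bias) {
        return bias ? make_conv_bias(src, filter, *bias)
                    : make_conv_bias(src, filter);
    }
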
@@ -2139,25 +2156,37 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         } else {
             mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter =
                     RelayoutPlaceholder::make(new_inp[1], is_trans.second);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                          src_to_nchwxx_mode);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], src_to_nchwxx_mode);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = conv_bias_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");
|
|
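Note: the dropped mgb_assert on conv_bias_bias would dereference a null
pointer whenever the operator has no bias, which is exactly the case this
patch introduces. Had the layout check been worth keeping, a guarded form
(hypothetical, not part of the patch) would be needed:

    if (conv_bias_bias) {
        mgb_assert(conv_bias_bias->shape().ndim == 5 ||
                   conv_bias_bias->shape().is_scalar());
    }
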
@@ -2275,6 +2304,10 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Argmax::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ImmutableTensor::typeinfo()] = relayout_inp_to_nchw;
 }
 
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
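
Note: the newly registered operators are format-agnostic, so they all share
the relayout_inp_to_nchw fallback, which converts any nchwxx input back to
NCHW before the operator runs. A standalone model of this kind of type-keyed
dispatch table (names are stand-ins for the gopt machinery):

    #include <functional>
    #include <string>
    #include <unordered_map>

    using Handler = std::function<void()>;

    int main() {
        std::unordered_map<std::string, Handler> replace_func;
        Handler relayout_inp_to_nchw = [] { /* relayout inputs, rebuild opr */ };
        for (const char* name : {"Reshape", "AxisAddRemove", "Argmax",
                                 "Broadcast", "ImmutableTensor"}) {
            replace_func[name] = relayout_inp_to_nchw;
        }
    }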
|
|
@@ -2459,6 +2492,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                                             OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
+        mgb_assert(opr->input().size() <= 3,
+                   "nchwxx-dot does not support conv_bias fuse Z right now");
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
|
|
@@ -2489,7 +2524,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             }
 
             //! the bias is nchwxx
-            if (temp_inp[2]->shape().ndim == 5) {
+            if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
                 auto new_bias = RelayoutPlaceholder::make(
                         new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
                 temp_inp[2] = new_bias.node();
|
|
@@ -2499,7 +2534,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
             return new_opr;
         } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             //! filter trans to nchwxx mode
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
                                new_inp[1]->shape().ndim == 5,
|
|
@@ -2514,21 +2549,34 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                         new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
                 conv_bias_src = new_src.node();
             }
-            //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            //! bias trans to nchwxx mode
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
 
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
             mgb_assert(conv_bias_src->shape().ndim == 5 &&
                                conv_bias_filter->shape().ndim >= 6,
                        "The conv_bias src dim is not trans to nchwxx");
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv_bias dst dim is not trans to nchwxx");
|
|
@@ -2536,25 +2584,37 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         } else {
             mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
             VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
-                    *conv_bias_bias = new_inp[2];
+                    *conv_bias_bias = nullptr;
             auto new_filter = RelayoutPlaceholder::make(new_inp[1],
                                                         is_trans.relayout_mod);
             conv_bias_filter = new_filter.node();
             //! bias trans to nchwxx mode, bias may be scale
-            if (new_inp[2]->shape().ndim == 4) {
-                auto new_bias = RelayoutPlaceholder::make(
-                        new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
-                conv_bias_bias = new_bias.node();
+            if (new_inp.size() > 2) {
+                if (new_inp[2]->shape().ndim == 4) {
+                    auto new_bias = RelayoutPlaceholder::make(
+                            new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
+                    conv_bias_bias = new_bias.node();
+                } else {
+                    mgb_assert(new_inp[2]->shape().ndim == 5);
+                    conv_bias_bias = new_inp[2];
+                }
             }
             mgb_assert(conv_bias_src->shape().ndim == 4 &&
                        conv_bias_filter->shape().ndim == 5);
-            mgb_assert((conv_bias_bias->shape().ndim == 5) ||
-                       conv_bias_bias->shape().is_scalar());
             auto new_param = conv_bias_opr.param();
             new_param.format = is_trans.conv_format;
-            auto new_conv_bias_opr = opr::ConvBias::make(
-                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
-                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            SymbolVar new_conv_bias_opr;
+            if (conv_bias_bias) {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, conv_bias_bias,
+                        new_param, conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            } else {
+                new_conv_bias_opr = opr::ConvBias::make(
+                        conv_bias_src, conv_bias_filter, new_param,
+                        conv_bias_opr.execution_policy(),
+                        conv_bias_opr.config());
+            }
             OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
             mgb_assert(new_conv_bias_opr.shape().ndim == 5,
                        "The conv dst dim is not trans to nchwxx");