@@ -60,19 +60,24 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
 public:
     //! relayout type of this opr
     enum class LayoutType {
-        NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
-        NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
-        NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
-        CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
-        NCHW_TO_NCHW4,    //!< from nchw layout to nchw4 layout
-        NCHW4_TO_NCHW,    //!< from nchw4 layout to nchw layout
-        NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
-        NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout
+        NCHW4_TO_NCHW32,              //!< from nchw4 layout to nchw32 layout
+        NCHW32_TO_NCHW4,              //!< from nchw32 layout to nchw4 layout
+        NCHW4_TO_CHWN4,               //!< from nchw4 layout to chwn4 layout
+        CHWN4_TO_NCHW4,               //!< from chwn4 layout to nchw4 layout
+        NCHW_TO_NCHW4,                //!< from nchw layout to nchw4 layout
+        NCHW_TO_NCHW4_IC_SMALL_CONV,  ///< from nchw layout to nchw4 whose
+                                      ///< channel size less than 4
+        NCHW4_TO_NCHW,                //!< from nchw4 layout to nchw layout
+        NCHW_TO_NCHW88,               //!< from nchw layout to nchw88 layout
+        NCHW88_TO_NCHW,               //!< from nchw88 layout to nchw layout

         WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
                                      //!< layout
         WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw4 layout
+        WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,  //!< weight from nchw layout
+                                                   //!< to nchw4 layout whose
+                                                   //! channel size less than 4

         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
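
The two *_IC_SMALL_CONV entries cover tensors whose channel count is below 4. Such a tensor cannot be repacked into NCHW4 by a plain reshape; the missing channels have to be padded out to a full 4-channel block. A minimal standalone sketch of that data movement, with a flat float buffer standing in for the real tensor type (function name, buffer layout and the zero-padding policy are illustrative assumptions, not MegEngine's implementation):

#include <cstddef>
#include <vector>

// Repack an NCHW buffer into NCHW4: dst shape is (N, ceil(C/4), H, W, 4),
// with channel slots beyond C left as zero.
std::vector<float> nchw_to_nchw4(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W) {
    const size_t Cb = (C + 4 - 1) / 4;  // number of 4-channel blocks
    std::vector<float> dst(N * Cb * H * W * 4, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t s = ((n * C + c) * H + h) * W + w;
                    size_t d = (((n * Cb + c / 4) * H + h) * W + w) * 4 + c % 4;
                    dst[d] = src[s];
                }
    return dst;
}

For a 3-channel input the result has one 4-channel block whose last slot stays zero, which is exactly the case the new enum values route to a padding-aware relayout.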
|
|
@@ -177,11 +182,21 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[2];
             dst[4] = inp_shape[4];
         } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4){
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+                           RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 ||
+                   layout_type() == RelayoutPlaceholder::LayoutType::
+                                            NCHW_TO_NCHW4_IC_SMALL_CONV) {
+            if (layout_type() ==
+                RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            } else {
+                mgb_assert(layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   NCHW_TO_NCHW4_IC_SMALL_CONV);
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
+            }
             dst.ndim = 5;
             dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
+            dst[1] = (inp_shape[1] + 4 - 1) / 4;
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 4;
|
|
@@ -194,11 +209,23 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
         } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                                            WEIGHT_NCHW_TO_NCHW4_DENSE) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+                                            WEIGHT_NCHW_TO_NCHW4_DENSE ||
+                   layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) {
+            if (layout_type() ==
+                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            } else {
+                mgb_assert(layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV);
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
+            }
+
             dst.ndim = 5;
             dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
+            dst[1] = (inp_shape[1] + 4 - 1) / 4;
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 4;
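
In both shape-inference branches above, the only arithmetic change is in the channel dimension: dst[1] goes from inp_shape[1] / 4 to (inp_shape[1] + 4 - 1) / 4, a round-up division, so an input with fewer than 4 channels still maps to one 4-channel block instead of zero. A tiny self-contained check of that formula (values chosen for illustration):

#include <cassert>
#include <cstddef>

int main() {
    auto blocks = [](size_t c) { return (c + 4 - 1) / 4; };
    assert(blocks(3) == 1);    // IC_SMALL case: 3 channels -> 1 padded block
    assert(blocks(4) == 1);    // unchanged when C is divisible by 4
    assert(blocks(64) == 16);
    return 0;
}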
|
|
@@ -427,6 +454,23 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+
+    reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::RelayoutFormat::make(
+                x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
+        return y.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::RelayoutFormat::make(
+                x, megdnn::param::RelayoutFormat::Mode::
+                        NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
+        return y.node();
+    };
+
     reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
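
The new entries plug into the lookup table that translate_pass already keeps: each LayoutType maps to a lambda that rewrites the placeholder into concrete operators. The IC_SMALL variants go through a single opr::RelayoutFormat, which can introduce the channel padding that the reshape-based NCHW_TO_NCHW4 path cannot. A stripped-down sketch of that dispatch pattern using only standard-library types (VarNode is stubbed here; the real lambdas build graph operators, and hashing an enum class key needs C++14):

#include <functional>
#include <unordered_map>

struct VarNode {};  // stand-in for mgb::cg::VarNode
enum class LayoutType { NCHW_TO_NCHW4, NCHW_TO_NCHW4_IC_SMALL_CONV };

int main() {
    std::unordered_map<LayoutType, std::function<VarNode*(VarNode*)>> reformat;
    reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
            [](VarNode* inp) -> VarNode* {
        // the real pass wraps `inp` in an opr::RelayoutFormat and returns
        // the new output var; the stub just forwards the input
        return inp;
    };
    VarNode x;
    VarNode* y = reformat.at(LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV)(&x);
    (void)y;
    return 0;
}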
|
|
@@ -1367,29 +1411,40 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     RelayoutMode weight_to_nchw4_mode_group =
             RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;

-    auto trans_nchw4 = [weight_to_nchw4_mode_dense,
-                        weight_to_nchw4_mode_group](
+    struct ConvMode {
+        RelayoutMode weight;
+        RelayoutMode src;
+    };
+
+    auto trans_nchw4 =
+            [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group,
+             src_to_nchw4_mode](
                     const megdnn::param::Convolution::Sparse conv_mode,
-                    const VarNode* filter) -> RelayoutMode {
+                    const VarNode* filter) -> ConvMode {
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
             mgb_assert(filter->shape().ndim == 4,
                        "The origin filter is not NCHW mode");
             size_t IC = filter->shape()[1];
-            mgb_assert(IC % 4 == 0,
-                       "The input channel should be divisible by 4");
-            return weight_to_nchw4_mode_dense;
+            if (IC < 4) {
+                return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
+                        RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
+            } else {
+                return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
+            }
         } else {
             mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
             mgb_assert(filter->shape().ndim == 5,
                        "The origin filter if not NCHW mode");
             size_t IC = filter->shape()[2];
             mgb_assert(IC % 4 == 0,
-                       "The input channel should be divisible by 4");
-            return weight_to_nchw4_mode_group;
+                       "The input channel should be divisible by 4 for group "
+                       "conv");
+            return {weight_to_nchw4_mode_group, src_to_nchw4_mode};
         }
     };
-    auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
-            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+    auto replace_conv_opr = [trans_nchw4, conv_format](
+                                    OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
         if (conv_opr.param().format !=
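
Returning a small ConvMode struct instead of a single RelayoutMode lets the caller pick matching weight and source conversions in one call: for dense convolutions with IC < 4 both sides switch to the padding-aware IC_SMALL variants, otherwise the original modes are kept. A reduced standalone sketch of just that selection logic (enum values trimmed to the ones involved; not the full trans_nchw4 lambda):

#include <cassert>
#include <cstddef>

enum class RelayoutMode {
    NCHW_TO_NCHW4,
    NCHW_TO_NCHW4_IC_SMALL_CONV,
    WEIGHT_NCHW_TO_NCHW4_DENSE,
    WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
};

struct ConvMode {
    RelayoutMode weight;
    RelayoutMode src;
};

// Dense-filter branch: small input channel counts use the padding-aware
// relayouts, everything else keeps the pre-existing modes.
ConvMode select_dense(size_t ic) {
    if (ic < 4) {
        return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
                RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
    }
    return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE,
            RelayoutMode::NCHW_TO_NCHW4};
}

int main() {
    assert(select_dense(3).src == RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV);
    assert(select_dense(16).weight == RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE);
    return 0;
}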
|
|
@@ -1397,18 +1452,19 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         }
+        auto conv_mode =
+                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
         VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
         // src: NCHW --> NCWH4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src =
+                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_src = new_src.node();
         }
         // weight: NCHW --> NCHW4
-        auto weight_mode =
-                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
-        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        auto new_filter =
+                RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
         conv_filter = new_filter.node();
         // format: NCHW --> NCHW4
         auto new_param = conv_opr.param();
|
|
@@ -1499,8 +1555,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     };
     auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
                                   src_to_nchw4_mode](
-            OperatorNodeBase* opr,
-            const VarNodeArray& new_inp) {
+                                         OperatorNodeBase* opr,
+                                         const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         if (conv_bias_opr.param().format !=
|
|
@@ -1511,17 +1567,18 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){

         // what should be converted: src, weight
         VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
+        auto conv_mode =
+                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
         // src: NCHW --> NCHW4
-        if (new_inp[0]->shape().ndim !=5) {
+        if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src =
+                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_bias_src = new_src.node();
         }
         // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
-        auto weight_mode =
-                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
-        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        auto new_filter =
+                RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
         conv_bias_filter = new_filter.node();
         // format: NCHW --> NCHW4
         auto new_param = conv_bias_opr.param();
|
|
@@ -1538,8 +1595,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // bias: NCHW --> NCHW4
         VarNode* conv_bias_bias = new_inp[2];
         if (new_inp[2]->shape().ndim == 4) {
-            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                      src_to_nchw4_mode);
+            auto new_bias =
+                    RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
             conv_bias_bias = new_bias.node();
         }
         if (new_inp.size() == 3) {
|
|
@@ -1554,8 +1611,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // z_inp: NCHW --> NCHW4
         VarNode* z_inp = new_inp[3];
         if (new_inp[3]->shape().ndim == 4) {
-            auto new_z = RelayoutPlaceholder::make(new_inp[3],
-                                                   src_to_nchw4_mode);
+            auto new_z =
+                    RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
             z_inp = new_z.node();
         }
         auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
|
|
|