@@ -63,10 +63,15 @@ public:
         NCHW32_TO_NCHW4,    //!< from nchw32 layout to nchw4 layout
         NCHW4_TO_CHWN4,     //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,     //!< from chwn4 layout to nchw4 layout
+        NCHW_TO_NCHW4,      //!< from nchw layout to nchw4 layout
+        NCHW4_TO_NCHW,      //!< from nchw4 layout to nchw layout
         NCHW_TO_NCHW88,     //!< from nchw layout to nchw88 layout
-        NCHW_TO_NCHW44,     //!< from nchw layout to nchw44 layout
         NCHW88_TO_NCHW,     //!< from nchw88 layout to nchw layout
-        NCHW44_TO_NCHW,     //!< from nchw44 layout to nchw layout
+        WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
+                                     //!< layout
+        WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
+                                     //!< nchw4 layout
         WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
                                      //!< layout
@@ -167,6 +172,42 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[2];
             dst[4] = inp_shape[4];
         } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
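+            // (N, C, H, W) -> (N, C/4, H, W, 4)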
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
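+            // (N, C/4, H, W, 4) -> (N, C, H, W)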
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] * 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW4_DENSE) {
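+            // dense conv filter: (OC, IC, FH, FW) -> (OC, IC/4, FH, FW, 4)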
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW4_GROUP) {
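+            // group conv filter:
+            // (G, OC/G, IC/G, FH, FW) -> (G, OC/G, IC/G/4, FH, FW, 4)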
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1];
+            dst[2] = inp_shape[2] / 4;
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+        } else if (layout_type() ==
                    RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
             dst.ndim = 5;
@@ -226,23 +267,6 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[3];
             dst[3] = inp_shape[1];
             dst[4] = 8;
-        } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
-            dst.ndim = 5;
-            dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
-            dst[2] = inp_shape[2];
-            dst[3] = inp_shape[3];
-            dst[4] = 4;
-        } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
-            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
-            dst.ndim = 4;
-            dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] * 4;
-            dst[2] = inp_shape[2];
-            dst[3] = inp_shape[3];
         } else if (layout_type() == RelayoutPlaceholder::LayoutType::
                                             WEIGHT_NCHW_TO_NCHW44_DENSE) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
@@ -394,6 +418,66 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
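+    // NCHW -> NCHW4: reshape (N, C, H, W) to (N, C/4, 4, H, W), then move
+    // the packed quadruple innermost with Dimshuffle {0, 1, 3, 4, 2} to
+    // obtain (N, C/4, H, W, 4)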
+    reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
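+    // NCHW4 -> NCHW: the inverse transform; Dimshuffle {0, 1, 4, 2, 3}
+    // brings the pack axis back next to the channel axis before the merge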
+    reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
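+    // dense conv filter: same pattern as NCHW_TO_NCHW4, packing along the
+    // input-channel axis: (OC, IC, FH, FW) -> (OC, IC/4, FH, FW, 4)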
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
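+    // group conv filter: the packed input-channel axis sits one dim deeper,
+    // (G, OC/G, IC/G, FH, FW) -> (G, OC/G, IC/G/4, FH, FW, 4)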
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
     reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
@@ -492,34 +576,6 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
-    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
-        auto x = SymbolVar(inp);
-        auto xshp = opr::GetVarShape::make(x);
-        auto cv = [&x](int v) { return x.make_scalar(v); };
-        auto sub = [&xshp, &cv](int idx) {
-            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
-        };
-        auto tshp0 = opr::Concat::make(
-                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
-             tshp1 = opr::Concat::make(
-                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
-        auto y0 = opr::Reshape::make(x, tshp0);
-        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
-        auto y2 = opr::Reshape::make(y1, tshp1);
-        return y2.node();
-    };
-    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
-        auto x = SymbolVar(inp);
-        auto xshp = opr::GetVarShape::make(x);
-        auto cv = [&x](int v) { return x.make_scalar(v); };
-        auto sub = [&xshp, &cv](int idx) {
-            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
-        };
-        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
-        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
-        auto y1 = opr::Reshape::make(y0, tshp0);
-        return y1.node();
-    };
     reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
             [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
@@ -1239,6 +1295,293 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
     return ret;
 }
 
+/* ================ EnableNCHW4Pass ================ */
+VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
+                                                VarNode* orig_var) const {
+    if (!orig_var->shape().eq_shape(new_var->shape())) {
+        return RelayoutPlaceholder::make(
+                       new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
+                .node();
+    }
+    return new_var;
+}
+
+std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
+    auto ret = std::make_unique<EnableNCHW4Pass>();
+    ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
+    using RelayoutMode = RelayoutPlaceholder::LayoutType;
+    megdnn::param::Convolution::Format conv_format =
+            megdnn::param::Convolution::Format::NCHW4;
+    megdnn::param::ConvBias::Format conv_bias_format =
+            megdnn::param::ConvBias::Format::NCHW4;
+    megdnn::param::BatchConvBias::Format batch_conv_bias_format =
+            megdnn::param::BatchConvBias::Format::NCHW4;
+    RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
+    RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
+    RelayoutMode weight_to_nchw4_mode_dense =
+            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
+    RelayoutMode weight_to_nchw4_mode_group =
+            RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
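+    // choose the filter relayout from the conv sparse mode: DENSE filters
+    // are 4-dim (OC, IC, FH, FW), GROUP filters are 5-dim (G, OC/G, IC/G,
+    // FH, FW); both need the input channel divisible by 4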
+    auto trans_nchw4 = [weight_to_nchw4_mode_dense,
+                        weight_to_nchw4_mode_group](
+                               const megdnn::param::Convolution::Sparse conv_mode,
+                               const VarNode* filter) -> RelayoutMode {
+        if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
+            mgb_assert(filter->shape().ndim == 4,
+                       "The origin filter is not NCHW mode");
+            size_t IC = filter->shape()[1];
+            mgb_assert(IC % 4 == 0,
+                       "The input channel should be divisible by 4");
+            return weight_to_nchw4_mode_dense;
+        } else {
+            mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
+            mgb_assert(filter->shape().ndim == 5,
+                       "The origin filter is not NCHW mode");
+            size_t IC = filter->shape()[2];
+            mgb_assert(IC % 4 == 0,
+                       "The input channel should be divisible by 4");
+            return weight_to_nchw4_mode_group;
+        }
+    };
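+    // Convolution: relayout src (if it is still 4-dim NCHW) and the filter
+    // to NCHW4, then rebuild the opr with its format param set to NCHW4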
+    auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
+            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
+        mgb_assert(conv_opr.param().format ==
+                           megdnn::param::Convolution::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            conv_src = new_src.node();
+        }
+        // weight: NCHW --> NCHW4
+        auto weight_mode =
+                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        conv_filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = conv_opr.param();
+        new_param.format = conv_format;
+        // dst
+        auto new_conv_opr = opr::Convolution::make(
+                conv_src, conv_filter, new_param,
+                conv_opr.execution_policy(), conv_opr.config());
+        OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
+        mgb_assert(new_conv_opr.shape().ndim == 5,
+                   "The conv dst dim is not trans to nchw4");
+        return new_opr;
+    };
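+    // BatchConvBias: same flow, but its batched filter is always 5-dim
+    // (B, OC, IC, FH, FW), so the group-style filter relayout is reused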
+    auto replace_batch_conv_bias_opr = [batch_conv_bias_format,
+                                        src_to_nchw4_mode](
+                                               OperatorNodeBase* opr,
+                                               const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& batch_conv_bias_opr =
+                opr->cast_final_safe<opr::BatchConvBiasForward>();
+        mgb_assert(batch_conv_bias_opr.param().format ==
+                           megdnn::param::BatchConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        // what should be converted: src, weight
+        VarNode *src = new_inp[0], *filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            src = new_src.node();
+        }
+        // weight: BNCHW --> BNCHW4
+        // only dense mode is supported; the 5-dim batched filter is relaid
+        // out like a group conv filter
+        auto weight_mode =
+                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = batch_conv_bias_opr.param();
+        new_param.format = batch_conv_bias_format;
+        if (new_inp.size() == 2) {
+            auto dst = opr::BatchConvBias::make(src, filter, new_param,
+                               batch_conv_bias_opr.execution_policy(),
+                               batch_conv_bias_opr.config());
+            OperatorNodeBase* new_opr = dst.node()->owner_opr();
+            mgb_assert(dst.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // bias: NCHW --> NCHW4
+        VarNode* bias = new_inp[2];
+        if (new_inp[2]->shape().ndim == 4) {
+            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
+                                                      src_to_nchw4_mode);
+            bias = new_bias.node();
+        }
+        if (new_inp.size() == 3) {
+            auto dst = opr::BatchConvBias::make(src, filter, bias, new_param,
+                               batch_conv_bias_opr.execution_policy(),
+                               batch_conv_bias_opr.config());
+            OperatorNodeBase* new_opr = dst.node()->owner_opr();
+            mgb_assert(dst.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // z_inp: NCHW --> NCHW4
+        VarNode* z_inp = new_inp[3];
+        if (new_inp[3]->shape().ndim == 4) {
+            auto new_z = RelayoutPlaceholder::make(new_inp[3],
+                                                   src_to_nchw4_mode);
+            z_inp = new_z.node();
+        }
+        auto dst = opr::BatchConvBias::make(src, filter, bias, z_inp,
+                           new_param, batch_conv_bias_opr.execution_policy(),
+                           batch_conv_bias_opr.config());
+        OperatorNodeBase* new_opr = dst.node()->owner_opr();
+        mgb_assert(dst.shape().ndim == 5,
+                   "The conv_bias dst dim is not trans to nchw4");
+        return new_opr;
+    };
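+    // ConvBias: like Convolution above, additionally relaying out the bias
+    // and the optional z (residual) input when they are still 4-dim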
+    auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
+                                  src_to_nchw4_mode](
+                                         OperatorNodeBase* opr,
+                                         const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
+        mgb_assert(conv_bias_opr.param().format ==
+                           megdnn::param::ConvBias::Format::NCHW,
+                   "ConvertFormat Pass only support converting NCHW to NCHW4");
+        // what should be converted: src, weight
+        VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
+        // src: NCHW --> NCHW4
+        if (new_inp[0]->shape().ndim != 5) {
+            mgb_assert(new_inp[0]->shape().ndim == 4);
+            auto new_src = RelayoutPlaceholder::make(new_inp[0],
+                                                     src_to_nchw4_mode);
+            conv_bias_src = new_src.node();
+        }
+        // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
+        auto weight_mode =
+                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
+        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        conv_bias_filter = new_filter.node();
+        // format: NCHW --> NCHW4
+        auto new_param = conv_bias_opr.param();
+        new_param.format = conv_bias_format;
+        if (new_inp.size() == 2) {
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // bias: NCHW --> NCHW4
+        VarNode* conv_bias_bias = new_inp[2];
+        if (new_inp[2]->shape().ndim == 4) {
+            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
+                                                      src_to_nchw4_mode);
+            conv_bias_bias = new_bias.node();
+        }
+        if (new_inp.size() == 3) {
+            auto new_conv_bias_opr = opr::ConvBias::make(
+                    conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
+                    conv_bias_opr.execution_policy(), conv_bias_opr.config());
+            OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+            mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                       "The conv_bias dst dim is not trans to nchw4");
+            return new_opr;
+        }
+        // z_inp: NCHW --> NCHW4
+        VarNode* z_inp = new_inp[3];
+        if (new_inp[3]->shape().ndim == 4) {
+            auto new_z = RelayoutPlaceholder::make(new_inp[3],
+                                                   src_to_nchw4_mode);
+            z_inp = new_z.node();
+        }
+        auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
+                conv_bias_filter, conv_bias_bias, z_inp, new_param,
+                conv_bias_opr.execution_policy(), conv_bias_opr.config());
+        OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
+        mgb_assert(new_conv_bias_opr.shape().ndim == 5,
+                   "The conv_bias dst dim is not trans to nchw4");
+        return new_opr;
+    };
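+    // format-agnostic oprs: convert the remaining 4-dim inputs only when
+    // at least one input has already been relaid out to a 5-dim layout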
+    auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool has_inp_changed = false;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (new_inp[i]->shape().ndim == 5) {
+                has_inp_changed = true;
+                break;
+            }
+        }
+        if (has_inp_changed) {
+            auto temp_inp = new_inp;
+            for (size_t i = 0; i < opr->input().size(); i++) {
+                if (new_inp[i]->shape().ndim == 4) {
+                    auto new_var = RelayoutPlaceholder::make(
+                            new_inp[i], src_to_nchw4_mode);
+                    temp_inp[i] = new_var.node();
+                } else {
+                    mgb_assert((new_inp[i]->shape().ndim == 5) ||
+                               new_inp[i]->shape().is_scalar());
+                }
+            }
+            return serialization::copy_opr_shallow(*opr, temp_inp,
+                                                   opr->config());
+        } else {
+            return serialization::copy_opr_shallow(*opr, new_inp,
+                                                   opr->config());
+        }
+    };
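+    // fallback for oprs without an NCHW4 implementation: any input whose
+    // shape was changed by the pass is relaid out back to NCHW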
+    auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        VarNodeArray temp_inp = new_inp;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
+                mgb_assert(opr->input(i)->shape().ndim == 4);
+                mgb_assert(new_inp[i]->shape().ndim == 5);
+                auto new_var =
+                        RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
+                temp_inp[i] = new_var.node();
+            }
+        }
+        return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
+    };
+    auto&& replace_func = ret->m_opr_replace_func;
+    //! supported in nchw4
+    replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
+    replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
+    replace_func[opr::BatchConvBias::typeinfo()] =
+            replace_batch_conv_bias_opr;
+    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
+    //! not supported in nchw4
+    replace_func[opr::PoolingForward::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ConvolutionBackwardData::typeinfo()] =
+            relayout_inp_to_nchw;
+    replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::WarpPerspectiveForward::typeinfo()] =
+            relayout_inp_to_nchw;
+    replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
+    return ret;
+}
+
 /* ================ EnableNchwxxPass =============== */
 VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
@@ -1251,7 +1594,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     } else if (m_pack_c_size == 4) {
         return RelayoutPlaceholder::make(
                        new_var,
-                       RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                       RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
                 .node();
     }
 }
@@ -1287,8 +1630,8 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
         weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
         hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
-        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
-        src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
+        src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
         conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
         conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
         pooling_format = megdnn::param::Pooling::Format::NCHW44;