@@ -6,7 +6,8 @@
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
 */

 #include "megbrain/gopt/inference.h"
@@ -63,7 +64,10 @@ public:
         NCHW4_TO_CHWN4,  //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,  //!< from chwn4 layout to nchw4 layout
         NCHW_TO_NCHW88,  //!< from nchw layout to nchw88 layout
+        NCHW_TO_NCHW44,  //!< from nchw layout to nchw44 layout
         NCHW88_TO_NCHW,  //!< from nchw88 layout to nchw layout
+        NCHW44_TO_NCHW,  //!< from nchw44 layout to nchw layout
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
         WEIGHT_NCHW_TO_NCHW88_GROUP,  //!< group weight from nchw layout to
@@ -73,6 +77,16 @@ public:
         //!< the weight layout of input is nchw output is nchw88, special for
         //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
         WEIGHT_HYBIRD_NCHW_NCHW88,
+        WEIGHT_NCHW_TO_NCHW44_DENSE,  //!< weight from nchw layout to nchw44
+                                      //!< layout
+        WEIGHT_NCHW_TO_NCHW44_GROUP,  //!< group weight from nchw layout to
+                                      //!< nchw44 layout
+        WEIGHT_NCHW_TO_NCHW44_CHAN,   //!< channel wise weight from nchw layout
+                                      //!< to nchw44 layout
+        //!< the weight layout of input is nchw output is nchw44, special for
+        //!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
+        WEIGHT_HYBIRD_NCHW_NCHW44,
     };

     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
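The shape-inference and translate-pass hunks below implement these new modes. For orientation: NCHW44 packs channels in blocks of four, with the block-internal channel index innermost in memory. A minimal standalone sketch of the element mapping (illustrative only; the helper name is invented and nothing here is MegBrain API):

#include <cassert>
#include <cstddef>
#include <vector>

// Pack a dense (N, C, H, W) buffer into (N, C/4, H, W, 4); requires C % 4 == 0.
std::vector<float> pack_nchw44(const std::vector<float>& src, size_t N,
                               size_t C, size_t H, size_t W) {
    assert(C % 4 == 0 && src.size() == N * C * H * W);
    std::vector<float> dst(src.size());
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w) {
                    size_t s = ((n * C + c) * H + h) * W + w;
                    size_t d = (((n * (C / 4) + c / 4) * H + h) * W + w) * 4 +
                               c % 4;
                    dst[d] = src[s];  // channel c lands in block c/4, slot c%4
                }
    return dst;
}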
@@ -203,10 +217,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[3];
             dst[4] = inp_shape[4];
             dst[5] = 8;
-        } else {
-            mgb_assert(
-                    layout_type() ==
-                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88);
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
             mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
             dst.ndim = 5;
             dst[0] = inp_shape[0] / 8;
@@ -214,6 +226,68 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[3];
             dst[3] = inp_shape[1];
             dst[4] = 8;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW44) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] * 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_DENSE) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
+                       inp_shape[1] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = 4;
+            dst[5] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_GROUP) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
+                       inp_shape[2] % 4 == 0);
+            dst.ndim = 7;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[1] / 4;
+            dst[2] = inp_shape[2] / 4;
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+            dst[6] = 4;
+        } else if (layout_type() == RelayoutPlaceholder::LayoutType::
+                                            WEIGHT_NCHW_TO_NCHW44_CHAN) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
+                       inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
+            dst.ndim = 6;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[1];
+            dst[2] = inp_shape[2];
+            dst[3] = inp_shape[3];
+            dst[4] = inp_shape[4];
+            dst[5] = 4;
+        } else {
+            mgb_assert(
+                    layout_type() ==
+                    RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0] / 4;
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1];
+            dst[4] = 4;
         }
     }
     return true;
 };
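A quick sanity check of the new WEIGHT_HYBIRD_NCHW_NCHW44 branch against the enum comment: an OIHW weight {64, 2, 3, 3} (IC == 2, not divisible by 4) maps to {O/4, H, W, I, 4} = {16, 3, 3, 2, 4}. A standalone sketch of that rule (hypothetical helper, not part of the pass):

#include <array>
#include <cassert>
#include <cstddef>

// dst = {O/4, H, W, I, 4}, mirroring the WEIGHT_HYBIRD_NCHW_NCHW44 branch.
std::array<size_t, 5> hybrid_nchw44_weight_shape(std::array<size_t, 4> oihw) {
    assert(oihw[0] % 4 == 0);  // OC must split into blocks of 4
    return {oihw[0] / 4, oihw[2], oihw[3], oihw[1], 4};
}

int main() {
    auto d = hybrid_nchw44_weight_shape({64, 2, 3, 3});
    assert((d == std::array<size_t, 5>{16, 3, 3, 2, 4}));
}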
@@ -418,6 +492,104 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NCHW44] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::NCHW44_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
+                                        cv(4), sub(3), sub(4)},
+                                       0),
+             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
+                                        sub(4), cv(4), cv(4)},
+                                       0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };
+    reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                     {sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
+             tshp1 = opr::Concat::make(
+                     {sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
+        auto y2 = opr::Reshape::make(y1, tshp1);
+        return y2.node();
+    };

     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
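Each lambda above uses the same Reshape -> Dimshuffle -> Reshape idiom: the first reshape splits a factor-of-4 block out of the channel (or filter) axis, the dimshuffle moves that block innermost, and the final reshape pins the symbolic output shape. As a standalone illustration of the {0, 1, 3, 4, 2} permutation used by NCHW_TO_NCHW44 (plain C++ sketch, not MegBrain API):

#include <cstddef>
#include <vector>

// dimshuffle {0, 1, 3, 4, 2}: (d0, d1, d2, d3, d4) -> (d0, d1, d3, d4, d2),
// i.e. the 4-wide channel block (axis 2 after the first reshape) goes innermost.
std::vector<float> dimshuffle_01342(const std::vector<float>& in, size_t d0,
                                    size_t d1, size_t d2, size_t d3,
                                    size_t d4) {
    std::vector<float> out(in.size());
    for (size_t a = 0; a < d0; ++a)
        for (size_t b = 0; b < d1; ++b)
            for (size_t c = 0; c < d2; ++c)
                for (size_t d = 0; d < d3; ++d)
                    for (size_t e = 0; e < d4; ++e) {
                        size_t s = (((a * d1 + b) * d2 + c) * d3 + d) * d4 + e;
                        size_t t = (((a * d1 + b) * d3 + d) * d4 + e) * d2 + c;
                        out[t] = in[s];
                    }
    return out;
}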
@@ -1071,16 +1243,24 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
 VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
                                                  VarNode* orig_var) const {
     if (!orig_var->shape().eq_shape(new_var->shape())) {
-        return RelayoutPlaceholder::make(
-                       new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
-                .node();
+        if (m_pack_c_size == 8) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
+                    .node();
+        } else if (m_pack_c_size == 4) {
+            return RelayoutPlaceholder::make(
+                           new_var,
+                           RelayoutPlaceholder::LayoutType::NCHW44_TO_NCHW)
+                    .node();
+        }
     }
     return new_var;
 }

 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         size_t pack_c_size) {
-    auto ret = std::make_unique<EnableNchwxxPass>();
+    auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter
     //! trans mode
@@ -1102,8 +1282,18 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
-    mgb_assert(pack_c_size == static_cast<size_t>(8),
-               "The ConvertFormatPass to nchwxx only support NCHW88 now !");
+    if (pack_c_size == 4) {
+        weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
+        weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
+        weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
+        hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
+        src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW44;
+        src_to_nchw_mode = RelayoutMode::NCHW44_TO_NCHW;
+        conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
+        conv_format = megdnn::param::ConvolutionV0::Format::NCHW44;
+        pooling_format = megdnn::param::Pooling::Format::NCHW44;
+        convter_pass_name = "conv_format_nchw44";
+    }
     auto test_trans_nchwxx =
             [pack_c_size, weight_to_nchwxx_mode_dense,
              weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan,
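With the factory now branching on pack_c_size (and forwarding it to the constructor), both packings come from the same entry point. A hedged usage sketch, assuming only the signatures visible in this diff:

    auto nchw88 = EnableNchwxxPass::make_nchwxx_converter(8);  // existing NCHW88 path
    auto nchw44 = EnableNchwxxPass::make_nchwxx_converter(4);  // new NCHW44 path

Note that the old assert restricting pack_c_size to 8 is gone: any value other than 4 now silently keeps the NCHW88 defaults above, so callers presumably pass exactly 8 or 4.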
@@ -1297,7 +1487,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         auto new_param = conv_bias_opr.param();
         new_param.format = conv_bias_format;
         auto new_conv_bias_opr = opr::ConvBias::make(
-                conv_bias_src, conv_bias_filter, new_param,
+                conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
                 conv_bias_opr.execution_policy(), conv_bias_opr.config());
         OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
         mgb_assert(new_conv_bias_opr.shape().ndim == 5,
@@ -1330,6 +1520,51 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         }
     };

+    auto replace_concat_opr = [=](OperatorNodeBase* opr,
+                                  const VarNodeArray& new_inp) {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool has_inp_changed = false;
+        bool can_exec_ncwxx = true;
+        for (size_t i = 0; i < opr->input().size(); i++) {
+            if (new_inp[i]->shape().ndim == 5) {
+                has_inp_changed = true;
+                break;
+            } else if (new_inp[i]->shape().ndim == 4) {
+                if (new_inp[i]->shape()[1] % pack_c_size != 0) {
+                    can_exec_ncwxx = false;
+                }
+            }
+        }
+        if (has_inp_changed) {
+            auto temp_inp = new_inp;
+            if (can_exec_ncwxx) {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 4) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchwxx_mode);
+                        temp_inp[i] = new_var.node();
+                    } else {
+                        mgb_assert((new_inp[i]->shape().ndim == 5) ||
+                                   new_inp[i]->shape().is_scalar());
+                    }
+                }
+            } else {
+                for (size_t i = 0; i < opr->input().size(); i++) {
+                    if (new_inp[i]->shape().ndim == 5) {
+                        auto new_var = RelayoutPlaceholder::make(
+                                new_inp[i], src_to_nchw_mode);
+                        temp_inp[i] = new_var.node();
+                    }
+                }
+            }
+            return serialization::copy_opr_shallow(*opr, temp_inp,
+                                                   opr->config());
+        } else {
+            return serialization::copy_opr_shallow(*opr, new_inp,
+                                                   opr->config());
+        }
+    };
+
     auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
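The new replace_concat_opr keeps a Concat in the packed layout only when every 4-d input it inspects has a channel count divisible by the pack size; otherwise it relayouts the already-packed 5-d inputs back to NCHW. The gist of that check, as a standalone sketch (types and names invented for illustration):

#include <cstddef>
#include <vector>

struct InpShape {
    size_t ndim;
    size_t channels;  // meaningful only when ndim == 4
};

// True if all still-unpacked (4-d) inputs can be relayouted to NCHWxx.
bool concat_can_stay_packed(const std::vector<InpShape>& inputs,
                            size_t pack_c_size) {
    for (const auto& s : inputs)
        if (s.ndim == 4 && s.channels % pack_c_size != 0)
            return false;
    return true;
}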
@@ -1382,6 +1617,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
+    replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
     replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
     replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
@@ -1390,13 +1626,10 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     replace_func[opr::ConvolutionBackwardData::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
-    replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
|