diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index dd9390e5..bbddd96e 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -1950,1960 +1950,4 @@ void ParamMergePass::apply(OptState& opt_state) const { opr::MultipleDeviceTensorWithFormatHolder>(opt_state); } -/* ================ TensorReformatPass =============== */ -/*! - * \brief relayout placeholder opr - * - * RelayoutPlaceholder oprs act as the placeholders of the ComputingGraph - * during graph opt pass `TensorReformatPass`. These oprs are introduced - * into a ComputingGraph for conveniently discovering further optimize - * opportunities (such as fuse consecutive relayouts, translate into - * optimized implementations). They are canonized to have a shape infer, so - * the ouput's shape can be correctly deduced during the opt pass. - * - * Note that the oprs in the ComputingGraph are only used as intermediate - * representations before being translated to MegBrain oprs, so the - * oprs should not get involved in any actual computing. - */ -MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, - cg::SingleCNOperatorNodeBase) // { -public: - //! 
relayout type of this opr - enum class LayoutType { - NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout - NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout - NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout - CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout - NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout - NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout - WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 - //!< layout - WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to - //!< nchw88 layout - WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout - //!< to nchw88 layout - //!< the weight layout of input is nchw output is nchw88, special for - //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} - WEIGHT_HYBIRD_NCHW_NCHW88, - }; - - RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type); - - /*! - * \param src_var the input var - * \param layout_type tensor layout transform type of this relayout - * placeholder as described in LayoutType - */ - static SymbolVar make(VarNode* src_var, LayoutType layout_type); - - LayoutType layout_type() const { return m_layout_type; } - -private: - void init_output_static_infer_desc() override; - void scn_do_execute() override; - void init_output_comp_node() override; - const LayoutType m_layout_type; -}; - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder); - -TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder( - VarNode* src_var, LayoutType layout_type) - : Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}), - m_layout_type{layout_type} { - add_input({src_var}); - add_equivalence_component>(m_layout_type); - add_output(None)->dtype(src_var->dtype()); -} - -void TensorReformatPass::RelayoutPlaceholder::scn_do_execute() { - mgb_throw(InternalError, "RelayoutPlaceholder opr can not be executed"); -} - -void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() { - 
output(0)->comp_node(input(0)->comp_node()); -} - -void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { - using namespace cg::static_infer; - auto&& mgr = owner_graph()->static_infer_manager(); - DepVal deps; - for (auto i : input()) - deps.push_back({i, DepType::SHAPE}); - auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { - TensorShape inp_shape = inp.val[0].shape(); - dst = inp_shape; - if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] * 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] / 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[1]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[0]; - dst[4] = inp_shape[4]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst[0] = inp_shape[3]; - dst[1] = inp_shape[0]; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = inp_shape[4]; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 8; - } else if (layout_type() == - RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = 
inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_DENSE) { - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 && - inp_shape[1] % 8 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 8; - dst[5] = 8; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_GROUP) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 && - inp_shape[2] % 8 == 0); - dst.ndim = 7; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2] / 8; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 8; - dst[6] = 8; - } else if (layout_type() == RelayoutPlaceholder::LayoutType:: - WEIGHT_NCHW_TO_NCHW88_CHAN) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 && - inp_shape[2] == 1 && inp_shape[0] % 8 == 0); - dst.ndim = 6; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[1]; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4]; - dst[5] = 8; - } else { - mgb_assert( - layout_type() == - RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88); - mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0); - dst.ndim = 5; - dst[0] = inp_shape[0] / 8; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[1]; - dst[4] = 8; - } - return true; - }; - mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); -} - -SymbolVar TensorReformatPass::RelayoutPlaceholder::make( - VarNode* src_var, LayoutType layout_type) { - return src_var->owner_graph() - ->insert_opr( - std::make_unique(src_var, layout_type)) - ->output(0); -} - -void TensorReformatPass::insert_pass(OptState& opt) const { - opt.set_var_replace_check_flag(m_var_replace_check_flag); - auto rewriter = opt.graph().make_rewriter(); - VarNodeArray new_inp_cache; - auto on_opr = [this, &opt, &rewriter, - &new_inp_cache](OperatorNodeBase* opr) { - 
auto it = m_opr_replace_func.find(opr->dyn_typeinfo()); - if (it != m_opr_replace_func.end()) { - auto& new_inp = new_inp_cache; - new_inp.clear(); - new_inp.reserve(opr->input().size()); - for (auto&& inp : opr->input()) { - new_inp.push_back(rewriter.get_var(inp)); - } - auto new_opr = (it->second)(opr, new_inp); - auto &&out0 = opr->output(), &&out1 = new_opr->output(); - mgb_assert(out0.size() == out1.size(), - "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu " - "dst.size=%zu", - opr->cname(), opr->dyn_typeinfo()->name, - new_opr->cname(), new_opr->dyn_typeinfo()->name, - out0.size(), out1.size()); - for (size_t i = 0; i < out0.size(); ++i) { - if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - mgb_assert(!out1[i]->contain_flag( - VarNode::Flag::VOLATILE_CONTENT)); - auto src = out0[i]; - auto dst = out1[i]; - if (opt.graph().endpoint_contain(src)) { - // additional process on endpoint var node - dst = on_graph_endpoint_var(dst, src); - } - rewriter.replace_var(src, dst, nullptr); - } - } - } else { - rewriter.auto_replace_outputs(opr); - } - }; - opt.graph().iter(on_opr); - rewriter.apply_inplace(); -} - -void TensorReformatPass::translate_pass(OptState& opt) const { - ThinHashMap> - reformat; - using LayoutType = RelayoutPlaceholder::LayoutType; - reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return 
x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - 
}; - auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp0); - return y1.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8, - cv(8), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3), - sub(4), cv(8), cv(8)}, - 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(1), sub(2), sub(3), 
sub(4), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0), - tshp1 = opr::Concat::make( - {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - auto rewriter = opt.graph().make_rewriter(); - auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) { - if (opr->same_type()) { - auto ph = try_cast_as_op(opr); - auto new_inp = rewriter.get_var(opr->input(0)); - mgb_assert(reformat.count(ph->layout_type()), - "no replace rule can be found for layout_type(%u)", - static_cast(ph->layout_type())); - auto new_var = reformat[ph->layout_type()](new_inp); - rewriter.replace_var(opr->output(0), new_var, - mgb_cstr_log("replace relayout placeholder")); - return; - } - rewriter.auto_replace_outputs(opr); - }; - opt.graph().iter(on_opr); - rewriter.apply_inplace(); -} - -void TensorReformatPass::apply(OptState& opt) const { - insert_pass(opt); - translate_pass(opt); -} - -/* ================ EnableTensorCorePass =============== */ -VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var, - VarNode* orig_var) const { - if (!orig_var->shape().eq_shape(new_var->shape())) { - return RelayoutPlaceholder::make( - new_var, - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) - .node(); - } - return new_var; -} - -std::unique_ptr -EnableTensorCorePass::make_tensorcore_converter() { - // replace rule for 
conv bias opr - auto replace_conv_bias_opr = [](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - using Param = megdnn::param::ConvBias; - using Format = Param::Format; - using Sparse = Param::Sparse; - mgb_assert(opr->input().size() == new_inp.size()); - auto& conv_bias = opr->cast_final_safe(); - if (conv_bias.param().format != Format::NCHW4 || - conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) { - size_t nr_inps = opr->input().size(); - bool shape_has_changed = false; - for (size_t i = 0; i < nr_inps; ++i) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - shape_has_changed = true; - } - } - MGB_MARK_USED_VAR(shape_has_changed); - mgb_assert( - !shape_has_changed, - "EnableTensorCorePass assumes that the shape of inputs of" - "ConvBias operators whose output dtype is not QuantizedS8 " - "can not be changed in this opt pass"); - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()), - "EnableTensorCorePass assumes that filter tensor of " - "conv_bias operator can not be changed by other operators"); - VarNode* orig_filter = opr->input(1); - auto is_nchw4 = [](TensorShape shape) -> bool { - return shape.ndim == 5 && shape[4] == 4; - }; - auto is_nchw32 = [](TensorShape shape) -> bool { - return shape.ndim == 5 && shape[4] == 32; - }; - bool can_replace_nchw32 = false; - VarNode *src = nullptr, *weight = nullptr, *bias = nullptr, - *z_inp = nullptr; - // process src tensor - if (is_nchw4(new_inp[0]->shape())) { // new input is NCHW4 layout - size_t group = 1, icpg, ocpg; - if (conv_bias.param().sparse == Sparse::DENSE) { - icpg = orig_filter->shape()[1] * 4; - ocpg = orig_filter->shape()[0]; - } else { - mgb_assert(conv_bias.param().sparse == Sparse::GROUP); - group = orig_filter->shape()[0]; - icpg = orig_filter->shape()[2]; - ocpg = orig_filter->shape()[1]; - if (icpg == 1 && ocpg == 1) { // channel wise conv - group *= 4; - } else { - 
icpg *= 4; - } - } - // nchw32 layout need that input width and height are larger than 3 - size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; - if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && - iw >= 3) { - auto symvar = RelayoutPlaceholder::make( - new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); - src = symvar.node(); - can_replace_nchw32 = true; - } else { - src = new_inp[0]; - } - } else { // new input is NCHW32 layout - mgb_assert(is_nchw32(new_inp[0]->shape())); - size_t group = 1, ocpg; - if (conv_bias.param().sparse == Sparse::DENSE) { - ocpg = orig_filter->shape()[0]; - } else { - mgb_assert(conv_bias.param().sparse == Sparse::GROUP); - size_t icpg = orig_filter->shape()[2]; - ocpg = orig_filter->shape()[1]; - if (icpg == 1 && ocpg == 1) { - group *= 4; - } else { - icpg *= 4; - } - } - size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; - if (group == 1 && ocpg % 32 == 0 && ih >= 3 && iw >= 3) { - can_replace_nchw32 = true; - src = new_inp[0]; - } else { - auto symvar = RelayoutPlaceholder::make( - new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); - src = symvar.node(); - } - } - // process filter tensor - if (can_replace_nchw32) { - auto symvar = RelayoutPlaceholder::make( - new_inp[1], - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); - weight = symvar.node(); - } else { - weight = new_inp[1]; - } - if (new_inp.size() == 2) { - if (can_replace_nchw32) { - auto param = conv_bias.param(); - param.format = Format::NCHW32; - auto new_opr = opr::ConvBiasForward::make( - src, weight, param, conv_bias.execution_policy(), - conv_bias.config()); - return new_opr.node()->owner_opr(); - } else { - VarNodeArray inps{src, weight}; - auto new_opr = serialization::copy_opr_shallow(*opr, inps, - opr->config()); - return new_opr; - } - } - auto process_inp = [&](VarNode* inp) -> VarNode* { - if (can_replace_nchw32) { - if (is_nchw4(inp->shape())) { - auto symvar = RelayoutPlaceholder::make( - 
inp, - RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); - return symvar.node(); - } else { - mgb_assert(is_nchw32(inp->shape())); - return inp; - } - } else { - if (is_nchw4(inp->shape())) { - return inp; - } else { - mgb_assert(is_nchw32(inp->shape())); - auto symvar = RelayoutPlaceholder::make( - inp, - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); - return symvar.node(); - } - } - }; - // process bias tensor - bias = process_inp(new_inp[2]); - if (new_inp.size() == 3) { - if (can_replace_nchw32) { - auto param = conv_bias.param(); - param.format = Format::NCHW32; - auto new_opr = opr::ConvBiasForward::make( - src, weight, bias, param, conv_bias.execution_policy(), - conv_bias.config()); - return new_opr.node()->owner_opr(); - } else { - VarNodeArray inps{src, weight, bias}; - auto new_opr = serialization::copy_opr_shallow(*opr, inps, - opr->config()); - return new_opr; - } - } - // process z_inp tensor - z_inp = process_inp(new_inp[3]); - if (can_replace_nchw32) { - auto param = conv_bias.param(); - param.format = Format::NCHW32; - auto new_opr = opr::ConvBiasForward::make( - src, weight, bias, z_inp, param, - conv_bias.execution_policy(), conv_bias.config()); - return new_opr.node()->owner_opr(); - } - VarNodeArray inps{src, weight, bias, z_inp}; - auto new_opr = - serialization::copy_opr_shallow(*opr, inps, opr->config()); - return new_opr; - }; - // replace rule for elemwise like opr - // for oprs support NCHW4 and NCHW32 layout - auto replace_elemwise_like_opr = [](OperatorNodeBase* opr, - const VarNodeArray new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - size_t nr_inps = new_inp.size(); - size_t nr_shape_changed = 0; - for (size_t i = 0; i < nr_inps; ++i) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - nr_shape_changed++; - } - } - if (nr_shape_changed) { - auto inps = new_inp; - if (nr_shape_changed >= - nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32 - for (size_t i = 0; i < nr_inps; ++i) { - if 
(opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_NCHW32); - inps[i] = symvar.node(); - } - } - } else { // NCHW32 < NCHW4 -> use NCHW4 - for (size_t i = 0; i < nr_inps; ++i) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW32_TO_NCHW4); - inps[i] = symvar.node(); - } - } - } - return serialization::copy_opr_shallow(*opr, inps, opr->config()); - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - // for oprs only supports NCHW4 layout - auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr, - const VarNodeArray new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - VarNodeArray inps = new_inp; - for (size_t i = 0; i < opr->input().size(); ++i) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - mgb_assert(opr->input(i)->shape().ndim == 5 && - opr->input(i)->shape()[4] == 4); - mgb_assert(new_inp[i]->shape().ndim == 5 && - new_inp[i]->shape()[4] == 32); - auto symvar = RelayoutPlaceholder::make( - new_inp[i], - RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); - inps[i] = symvar.node(); - } - } - auto new_opr = - serialization::copy_opr_shallow(*opr, inps, opr->config()); - return new_opr; - }; - auto replace_non_nchw4_opr = [](OperatorNodeBase* opr, - const VarNodeArray new_inp) { - size_t nr_inps = opr->input().size(); - bool shape_has_changed = false; - for (size_t i = 0; i < nr_inps; ++i) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - shape_has_changed = true; - } - } - mgb_assert(!shape_has_changed, - "EnableTensorCorePass assumes that inputs' shape of " - "non-nchw4 operators " - "can not be changed in this opt " - "pass"); - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - - }; - auto replace_warp_affine_opr = - [replace_inps_to_nchw4, 
replace_non_nchw4_opr]( - OperatorNodeBase* opr, const VarNodeArray new_inp) { - using Param = opr::WarpAffineForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = opr->cast_final_safe(); - if (warp.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_warp_perspective_opr = - [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, const VarNodeArray new_inp) { - using Param = opr::WarpPerspectiveForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = - opr->cast_final_safe(); - if (warp.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::ResizeForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& resize = opr->cast_final_safe(); - if (resize.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_pooling_opr = [replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::PoolingForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& pooling = opr->cast_final_safe(); - if (pooling.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - size_t nr_inps = opr->input().size(); - MGB_MARK_USED_VAR(nr_inps); - mgb_assert(nr_inps == 1); - if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) { - mgb_assert(opr->input(0)->shape().ndim == 5 && - opr->input(0)->shape()[4] == 4); - mgb_assert(new_inp[0]->shape().ndim == 5 && - 
new_inp[0]->shape()[4] == 32); - auto new_param = pooling.param(); - new_param.format = Format::NCHW32; - auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param, - opr->config()); - return new_pooling.node()->owner_opr(); - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - auto ret = std::make_unique(); - ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); - auto&& replace_func = ret->m_opr_replace_func; - replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr; - - // elemwise like - replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; - replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; - replace_func[opr::ElemwiseMultiType::typeinfo()] = - replace_elemwise_like_opr; - replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr; - - // format aware - replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; - replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; - replace_func[opr::WarpPerspectiveForward::typeinfo()] = - replace_warp_perspective_opr; - replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; - - // to nchw4 - replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; - return ret; -} - -/* ================ EnableCHWN4Pass =============== */ -VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var, - VarNode* /* orig_var */) const { - if (m_varshape_changed.count(new_var)) { - return RelayoutPlaceholder::make( - new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) - .node(); - } - return new_var; -} - -std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { - auto ret = std::make_unique(); - 
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); - auto&& replace_func = ret->m_opr_replace_func; - auto&& varshape_changed = ret->m_varshape_changed; - // replace rule for conv bias opr - auto replace_conv_bias_opr = [&varshape_changed]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - using Param = megdnn::param::ConvBias; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& conv_bias = opr->cast_final_safe(); - if (conv_bias.param().format != Format::NCHW4 || - conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) { - size_t nr_inps = new_inp.size(); - bool shape_has_changed = false; - for (size_t i = 0; i < nr_inps; ++i) { - if (varshape_changed.count(new_inp[i])) { - shape_has_changed = true; - break; - } - } - mgb_assert( - !shape_has_changed, - "EnableCHWN4Pass assumes that the shape of inputs of" - "ConvBias operators whose output dtype is not QuantizedS8 " - "can not be changed in this opt pass"); - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - mgb_assert(varshape_changed.count(new_inp[1]) == 0, - "EnableCHWN4Pass assumes that filter tensor of " - "conv_bias operator can not be changed by other operators"); - VarNode *src = nullptr, *weight = nullptr, *bias = nullptr, - *z_inp = nullptr; - // process src tensor - if (varshape_changed.count(new_inp[0]) == - 0) { // new input is NCHW4 layout - // currently not support group conv - auto symvar = RelayoutPlaceholder::make( - new_inp[0], - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); - src = symvar.node(); - } else { // new input is NCHW32 layout - src = new_inp[0]; - } - // process weight tensor - { - auto symvar = RelayoutPlaceholder::make( - new_inp[1], - RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); - weight = symvar.node(); - } - if (new_inp.size() == 2) { - auto param = conv_bias.param(); - param.format = Format::CHWN4; - auto new_opr = opr::ConvBiasForward::make( - src, weight, param, 
conv_bias.execution_policy(), - conv_bias.config()); - varshape_changed.insert(new_opr.node()); - return new_opr.node()->owner_opr(); - } - auto process_inp = [&](VarNode* inp) -> VarNode* { - if (varshape_changed.count(inp) == 0) { - auto symvar = RelayoutPlaceholder::make( - inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); - return symvar.node(); - } else { - return inp; - } - }; - // process bias tensor - bias = process_inp(new_inp[2]); - if (new_inp.size() == 3) { - auto param = conv_bias.param(); - param.format = Format::CHWN4; - auto new_opr = opr::ConvBiasForward::make( - src, weight, bias, param, conv_bias.execution_policy(), - conv_bias.config()); - varshape_changed.insert(new_opr.node()); - return new_opr.node()->owner_opr(); - } - // process z_inp tensor - z_inp = process_inp(new_inp[3]); - auto param = conv_bias.param(); - param.format = Format::CHWN4; - auto new_opr = opr::ConvBiasForward::make( - src, weight, bias, z_inp, param, conv_bias.execution_policy(), - conv_bias.config()); - varshape_changed.insert(new_opr.node()); - return new_opr.node()->owner_opr(); - }; - // replace rule for elemwise like opr - // for oprs support NCHW4 and CHWN4 layout - auto replace_elemwise_like_opr = [&varshape_changed]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - size_t nr_inps = new_inp.size(); - size_t nr_shape_changed = 0; - for (size_t i = 0; i < nr_inps; ++i) { - if (varshape_changed.count(new_inp[i])) { - nr_shape_changed++; - } - } - if (nr_shape_changed) { - auto inps = new_inp; - if (nr_shape_changed >= nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4 - for (size_t i = 0; i < nr_inps; ++i) { - if (varshape_changed.count(new_inp[i]) == 0) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - NCHW4_TO_CHWN4); - inps[i] = symvar.node(); - } - } - auto new_opr = serialization::copy_opr_shallow(*opr, inps, - opr->config()); - 
varshape_changed.insert(new_opr->output(0)); - return new_opr; - } else { // CHWN4 < NCHW4 -> use NCHW4 - for (size_t i = 0; i < nr_inps; ++i) { - if (varshape_changed.count(new_inp[i])) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], RelayoutPlaceholder::LayoutType:: - CHWN4_TO_NCHW4); - inps[i] = symvar.node(); - } - } - return serialization::copy_opr_shallow(*opr, inps, - opr->config()); - } - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - // for oprs only supports NCHW4 layout - auto replace_inps_to_nchw4 = [&varshape_changed]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - VarNodeArray inps = new_inp; - for (size_t i = 0; i < opr->input().size(); ++i) { - if (varshape_changed.count(new_inp[i])) { - auto symvar = RelayoutPlaceholder::make( - new_inp[i], - RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4); - inps[i] = symvar.node(); - } - } - auto new_opr = - serialization::copy_opr_shallow(*opr, inps, opr->config()); - return new_opr; - }; - auto replace_non_nchw4_opr = [&varshape_changed]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - size_t nr_inps = opr->input().size(); - bool shape_has_changed = false; - for (size_t i = 0; i < nr_inps; ++i) { - if (varshape_changed.count(new_inp[i])) { - shape_has_changed = true; - } - } - mgb_assert(!shape_has_changed, - "EnableCHWN4Pass assumes that inputs' shape of " - "non-nchw4 operators " - "can not be changed in this opt " - "pass"); - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - - }; - // capture by copy to avoid use after return - auto replace_warp_affine_opr = - [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, const VarNodeArray new_inp) { - using Param = opr::WarpAffineForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = opr->cast_final_safe(); - if (warp.param().format != 
Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_warp_perspective_opr = - [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, const VarNodeArray new_inp) { - using Param = opr::WarpPerspectiveForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& warp = - opr->cast_final_safe(); - if (warp.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::ResizeForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& resize = opr->cast_final_safe(); - if (resize.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - return replace_inps_to_nchw4(opr, new_inp); - }; - auto replace_pooling_opr = [&varshape_changed, replace_non_nchw4_opr]( - OperatorNodeBase* opr, - const VarNodeArray new_inp) { - using Param = opr::PoolingForward::Param; - using Format = Param::Format; - mgb_assert(opr->input().size() == new_inp.size()); - auto& pooling = opr->cast_final_safe(); - if (pooling.param().format != Format::NCHW4) { - return replace_non_nchw4_opr(opr, new_inp); - } - size_t nr_inps = opr->input().size(); - MGB_MARK_USED_VAR(nr_inps); - mgb_assert(nr_inps == 1); - if (varshape_changed.count(new_inp[0])) { - auto new_param = pooling.param(); - new_param.format = Format::CHWN4; - auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param, - opr->config()); - varshape_changed.insert(new_pooling.node()); - return new_pooling.node()->owner_opr(); - } - return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); - }; - replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr; - - // elemwise 
like - replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; - replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; - replace_func[opr::ElemwiseMultiType::typeinfo()] = - replace_elemwise_like_opr; - replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr; - - // format aware - replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; - replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; - replace_func[opr::WarpPerspectiveForward::typeinfo()] = - replace_warp_perspective_opr; - replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; - - // to nchw4 - replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; - replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4; - return ret; -} - -/* ================ EnableNchwxxPass =============== */ -VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, - VarNode* orig_var) const { - if (!orig_var->shape().eq_shape(new_var->shape())) { - return RelayoutPlaceholder::make( - new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) - .node(); - } - return new_var; -} - -std::unique_ptr EnableNchwxxPass::make_nchwxx_converter( - size_t pack_c_size) { - auto ret = std::make_unique(); - ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); - //! First is whether the conv can trans to nchwxx, second is the filter - //! 
trans mode - using RelayoutMode = RelayoutPlaceholder::LayoutType; - using TestFilterResult = std::pair; - RelayoutMode weight_to_nchwxx_mode_dense = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE; - RelayoutMode weight_to_nchwxx_mode_group = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP; - RelayoutMode weight_to_nchwxx_mode_chan = - RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN; - RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88; - RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88; - RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW; - megdnn::param::ConvBias::Format conv_bias_format = - megdnn::param::ConvBias::Format::NCHW88; - megdnn::param::Convolution::Format conv_format = - megdnn::param::ConvolutionV0::Format::NCHW88; - megdnn::param::Pooling::Format pooling_format = - megdnn::param::Pooling::Format::NCHW88; - std::string convter_pass_name = "conv_format_nchw88"; - mgb_assert(pack_c_size == static_cast(8), - "The ConvertFormatPass to nchwxx only support NCHW88 now !"); - auto test_trans_nchwxx = - [pack_c_size, weight_to_nchwxx_mode_dense, - weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan, - hybrid_nchw_nchwxx]( - const megdnn::param::Convolution::Sparse conv_mode, - const VarNode* filter) -> TestFilterResult { - TestFilterResult ret{TransType::TRANS_NONE, {}}; - if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { - size_t IC = filter->shape()[1]; - size_t OC = filter->shape()[0]; - if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { - ret.first = TransType::TRANS_PURE_NCHWXX; - ret.second = weight_to_nchwxx_mode_dense; - } else if (IC < pack_c_size && OC % pack_c_size == 0) { - ret.first = TransType::TRANS_HYBIRD_NCHWXX; - ret.second = hybrid_nchw_nchwxx; - } - } else { - mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); - size_t group = filter->shape()[0]; - size_t ocpg = filter->shape()[1]; - size_t icpg = filter->shape()[2]; - if (icpg == 1 && ocpg == 1 && (group % 
pack_c_size == 0)) { - ret.first = TransType::TRANS_PURE_NCHWXX; - ret.second = weight_to_nchwxx_mode_chan; - } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { - ret.first = TransType::TRANS_PURE_NCHWXX; - ret.second = weight_to_nchwxx_mode_group; - } - } - return ret; - }; - auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode, - src_to_nchw_mode](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - auto& conv_opr = opr->cast_final_safe(); - mgb_assert(conv_opr.param().format == - megdnn::param::Convolution::Format::NCHW, - "ConvertFormat Pass only support converting NCHW to NCHWXX"); - auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]); - //! can not trans to nchwxx - if (is_trans.first == TransType::TRANS_NONE) { - mgb_assert(new_inp[1]->shape().ndim == 4 || - new_inp[1]->shape().ndim == 5, - "The origin filter is not NCHW mode"); - VarNodeArray temp_inp = new_inp; - //! if src is nchwxx, should RelayoutPlaceholder to nchw - if (temp_inp[0]->shape().ndim == 5) { - auto new_src = - RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); - temp_inp[0] = new_src.node(); - } - auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, - opr->config()); - return new_opr; - } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { - //! filter trans to nchwxx mode - mgb_assert(new_inp[1]->shape().ndim == 4 || - new_inp[1]->shape().ndim == 5, - "The origin filter is not NCHW mode"); - VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; - auto new_filter = - RelayoutPlaceholder::make(new_inp[1], is_trans.second); - conv_filter = new_filter.node(); - //! 
src trans to nchwxx mode - if (new_inp[0]->shape().ndim != 5) { - mgb_assert(new_inp[0]->shape().ndim == 4); - auto new_src = RelayoutPlaceholder::make(new_inp[0], - src_to_nchwxx_mode); - conv_src = new_src.node(); - } - auto new_param = conv_opr.param(); - new_param.format = conv_format; - mgb_assert(conv_src->shape().ndim == 5 && - conv_filter->shape().ndim >= 6, - "The conv src dim is not trans to nchwxx"); - auto new_conv_opr = opr::Convolution::make( - conv_src, conv_filter, new_param, - conv_opr.execution_policy(), conv_opr.config()); - OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr(); - mgb_assert(new_conv_opr.shape().ndim == 5, - "The conv dst dim is not trans to nchwxx"); - return new_opr; - } else { - mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); - VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; - auto new_filter = - RelayoutPlaceholder::make(new_inp[1], is_trans.second); - conv_filter = new_filter.node(); - mgb_assert(conv_src->shape().ndim == 4 && - conv_filter->shape().ndim == 5, - "The src and filter is OK"); - auto new_param = conv_opr.param(); - new_param.format = conv_format; - auto new_conv_opr = opr::Convolution::make( - conv_src, conv_filter, new_param, - conv_opr.execution_policy(), conv_opr.config()); - OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr(); - mgb_assert(new_conv_opr.shape().ndim == 5, - "The conv dst dim is not trans to nchwxx"); - return new_opr; - } - }; - - auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format, - src_to_nchwxx_mode, src_to_nchw_mode]( - OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - auto& conv_bias_opr = opr->cast_final_safe(); - mgb_assert(conv_bias_opr.param().format == - megdnn::param::ConvBias::Format::NCHW, - "ConvertFormat Pass only support converting NCHW to NCHWXX"); - auto is_trans = - test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]); - //! 
can not trans to nchwxx - if (is_trans.first == TransType::TRANS_NONE) { - mgb_assert(new_inp[1]->shape().ndim == 4 || - new_inp[1]->shape().ndim == 5, - "The origin filter is not NCHW mode"); - VarNodeArray temp_inp = new_inp; - //! if src is nchwxx, should RelayoutPlaceholder to nchw - if (temp_inp[0]->shape().ndim == 5) { - auto new_src = - RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); - temp_inp[0] = new_src.node(); - } - //! the bias is nchwxx - if (temp_inp[2]->shape().ndim == 5) { - auto new_bias = - RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode); - temp_inp[2] = new_bias.node(); - } - auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, - opr->config()); - return new_opr; - } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { - VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], - *conv_bias_bias = new_inp[2]; - //! filter trans to nchwxx mode - mgb_assert(new_inp[1]->shape().ndim == 4 || - new_inp[1]->shape().ndim == 5, - "The origin filter is not NCHW mode"); - auto new_filter = - RelayoutPlaceholder::make(new_inp[1], is_trans.second); - conv_bias_filter = new_filter.node(); - //! src trans to nchwxx mode - if (new_inp[0]->shape().ndim != 5) { - mgb_assert(new_inp[0]->shape().ndim == 4); - auto new_src = RelayoutPlaceholder::make(new_inp[0], - src_to_nchwxx_mode); - conv_bias_src = new_src.node(); - } - //! 
bias trans to nchwxx mode, bias may be scale - if (new_inp[2]->shape().ndim == 4) { - auto new_bias = RelayoutPlaceholder::make(new_inp[2], - src_to_nchwxx_mode); - conv_bias_bias = new_bias.node(); - } - - auto new_param = conv_bias_opr.param(); - new_param.format = conv_bias_format; - mgb_assert(conv_bias_src->shape().ndim == 5 && - conv_bias_filter->shape().ndim >= 6, - "The conv_bias src dim is not trans to nchwxx"); - auto new_conv_bias_opr = opr::ConvBias::make( - conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, - conv_bias_opr.execution_policy(), conv_bias_opr.config()); - OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert(new_conv_bias_opr.shape().ndim == 5, - "The conv_bias dst dim is not trans to nchwxx"); - return new_opr; - } else { - mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); - VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], - *conv_bias_bias = new_inp[2]; - auto new_filter = - RelayoutPlaceholder::make(new_inp[1], is_trans.second); - conv_bias_filter = new_filter.node(); - //! 
bias trans to nchwxx mode, bias may be scale - if (new_inp[2]->shape().ndim == 4) { - auto new_bias = RelayoutPlaceholder::make(new_inp[2], - src_to_nchwxx_mode); - conv_bias_bias = new_bias.node(); - } - mgb_assert(conv_bias_src->shape().ndim == 4 && - conv_bias_filter->shape().ndim == 5); - mgb_assert((conv_bias_bias->shape().ndim == 5) || - conv_bias_bias->shape().is_scalar()); - auto new_param = conv_bias_opr.param(); - new_param.format = conv_bias_format; - auto new_conv_bias_opr = opr::ConvBias::make( - conv_bias_src, conv_bias_filter, new_param, - conv_bias_opr.execution_policy(), conv_bias_opr.config()); - OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); - mgb_assert(new_conv_bias_opr.shape().ndim == 5, - "The conv dst dim is not trans to nchwxx"); - return new_opr; - } - }; - - auto replace_pooling_opr = [=](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - auto& pooling_opr = opr->cast_final_safe(); - mgb_assert(pooling_opr.param().format == - megdnn::param::Pooling::Format::NCHW, - "ConvertFormat Pass only support converting NCHW to NCHWxx"); - VarNode* inp = new_inp[0]; - //! 
if input is nchwxx - if (inp->shape().ndim == 5) { - auto new_param = pooling_opr.param(); - new_param.format = pooling_format; - auto new_pooling_opr = - opr::PoolingForward::make(inp, new_param, opr->config()); - mgb_assert(new_pooling_opr.shape().ndim == 5, - "The pooling dst dim is not trans to nchwxx"); - return new_pooling_opr.node()->owner_opr(); - } else { - auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - return new_opr; - } - }; - - auto replace_elemwise_opr = [=](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - bool has_inp_changed = false; - for (size_t i = 0; i < opr->input().size(); i++) { - if (new_inp[i]->shape().ndim == 5) { - has_inp_changed = true; - break; - } - } - if (has_inp_changed) { - auto temp_inp = new_inp; - for (size_t i = 0; i < opr->input().size(); i++) { - if (new_inp[i]->shape().ndim == 4) { - auto new_var = RelayoutPlaceholder::make( - new_inp[i], src_to_nchwxx_mode); - temp_inp[i] = new_var.node(); - } else { - mgb_assert((new_inp[i]->shape().ndim == 5) || - new_inp[i]->shape().is_scalar()); - } - } - return serialization::copy_opr_shallow(*opr, temp_inp, - opr->config()); - } else { - return serialization::copy_opr_shallow(*opr, new_inp, - opr->config()); - } - }; - - auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, - const VarNodeArray& new_inp) { - mgb_assert(opr->input().size() == new_inp.size()); - VarNodeArray temp_inp = new_inp; - for (size_t i = 0; i < opr->input().size(); i++) { - if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { - mgb_assert(opr->input(i)->shape().ndim == 4); - mgb_assert(new_inp[i]->shape().ndim == 5); - auto new_var = - RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); - temp_inp[i] = new_var.node(); - } - } - return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); - }; - - ret->set_name(convter_pass_name); - auto&& replace_func = ret->m_opr_replace_func; - //! 
supportted nchwxx - replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; - replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; - replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; - replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; - replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; - replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; - replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; - //! not support yet - replace_func[opr::ConvolutionBackwardData::typeinfo()] = - relayout_inp_to_nchw; - replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; - replace_func[opr::WarpPerspectiveForward::typeinfo()] = - relayout_inp_to_nchw; - replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; - return ret; -} - -/* ==================== ShuffleShuffleRemovePass ================= */ -class ShuffleShuffleRemovePass::Impl { - using TensorFormat = opr::ConvBias::Param::Format; - - OptState& m_opt_state; - ThinHashMap, - thin_function> - m_reformat; - - class AbstractShuffleOpr; - - void detect_shuffle_operations(); - void do_replace(); - -public: - Impl(OptState& opt_state) : m_opt_state{opt_state} { - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = 
opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); - auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0); - auto y0 = 
opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, - 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp0 = opr::Concat::make( - {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, - 0), - tshp1 = opr::Concat::make( - {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); - auto y0 = opr::Reshape::make(x, tshp0); - auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); - auto y2 = opr::Reshape::make(y1, tshp1); - return y2.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] = - [](VarNode* inp) -> VarNode* { - megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - - }; - - m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] = - [](VarNode* inp) -> VarNode* { - 
megdnn::param::RelayoutFormat param; - param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; - auto reformat = opr::RelayoutFormat::make(inp, param); - return reformat.node(); - }; - - m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] = - [](VarNode* inp) -> VarNode* { - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = opr::Concat::make( - {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); - auto y0 = opr::Reshape::make(x, tshp); - auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2}); - return y1.node(); - - }; - - m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] = - [](VarNode* inp) -> VarNode* { - mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); - auto x = SymbolVar(inp); - auto xshp = opr::GetVarShape::make(x); - - auto cv = [&x](int v) { return x.make_scalar(v); }; - auto sub = [&xshp, &cv](int idx) { - return opr::IndexAt::make(xshp, {{0, cv(idx)}}); - }; - auto tshp = - opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0); - auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2}); - auto y1 = opr::Reshape::make(y0, tshp); - return y1.node(); - }; - detect_shuffle_operations(); - do_replace(); - } -}; - -/*! 
- * \brief abstract operator representation of shuffle operation - */ -MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, - cg::SingleCNOperatorNodeBase) // { -public: - AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - - static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, - TensorFormat out_format); - - TensorFormat inp_format() const { return m_inp_format; } - - TensorFormat out_format() const { return m_out_format; } - -private: - void init_output_static_infer_desc() override; - void scn_do_execute() override; - const TensorFormat m_inp_format; - const TensorFormat m_out_format; -}; - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); - -void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() { - mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed"); -} - -void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr:: - init_output_static_infer_desc() { - using namespace cg::static_infer; - auto&& mgr = owner_graph()->static_infer_manager(); - DepVal deps; - for (auto i : input()) - deps.push_back({i, DepType::SHAPE}); - auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { - TensorShape inp_shape = inp.val[0].shape(); - if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::NCHW32) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst = inp_shape; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] * 8; - } else if (m_inp_format == TensorFormat::NCHW32 && - m_out_format == TensorFormat::NCHW4) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); - dst = inp_shape; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 8; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = inp_shape[4] / 8; - } else if (m_inp_format == TensorFormat::NCHW && - m_out_format == TensorFormat::NCHW4) { - mgb_assert(inp_shape.ndim 
== 4); - dst.ndim = 5; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] / 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - dst[4] = 4; - } else if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::NCHW) { - mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); - dst.ndim = 4; - dst[0] = inp_shape[0]; - dst[1] = inp_shape[1] * 4; - dst[2] = inp_shape[2]; - dst[3] = inp_shape[3]; - } else if (m_inp_format == TensorFormat::NCHW4 && - m_out_format == TensorFormat::CHWN4) { - dst.ndim = 5; - dst[0] = inp_shape[1]; - dst[1] = inp_shape[2]; - dst[2] = inp_shape[3]; - dst[3] = inp_shape[0]; - dst[4] = inp_shape[4]; - } else if (m_inp_format == TensorFormat::CHWN4 && - m_out_format == TensorFormat::NCHW4) { - dst.ndim = 5; - dst[0] = inp_shape[3]; - dst[1] = inp_shape[0]; - dst[2] = inp_shape[1]; - dst[3] = inp_shape[2]; - dst[4] = inp_shape[4]; - } else { - mgb_throw(InternalError, - "Unsupported input format and output format."); - } - return true; - }; - mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); -} - -ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr( - VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) - : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}), - m_inp_format{inp_format}, - m_out_format{out_format} { - add_input({inpvar}); - add_equivalence_component>(m_inp_format); - add_equivalence_component>(m_out_format); - add_output(None)->dtype(inpvar->dtype()); -} - -SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make( - VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) { - return inpvar->owner_graph() - ->insert_opr(std::make_unique( - inpvar, inp_format, out_format)) - ->output(0); -} - -void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { - auto rewriter = m_opt_state.graph().make_rewriter(); - auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; - auto try_reshape_shuffle = [&rewriter, - 
&uniq_reader_check](OperatorNodeBase* opr) { - // check shuffle - auto shuffle = try_cast_as_op(opr); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 5) - return false; - bool is_nchw2nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 3 && param.pattern[3] == 4 && - param.pattern[4] == 2 && - opr->output(0)->shape()[4] == 4; - if (!is_nchw2nchw4) - return false; - if (!uniq_reader_check(shuffle->input(0))) - return false; - - // check reshape - auto reshape = try_cast_as_op(opr->input(0)->owner_opr()); - if (reshape == nullptr) - return false; - auto inp_var = rewriter.get_var(reshape->input(0)); - auto abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW, TensorFormat::NCHW4); - rewriter.replace_var( - opr->output(0), abstract_shuffle.node(), - mgb_cstr_log("replace reformat(nchw -> nchw4) to " - "AbstractShuffleOpr(nchw -> nchw4).")); - return true; - }; - - auto try_reshape_shuffle_reshape = [&rewriter, &uniq_reader_check]( - OperatorNodeBase* opr) { - // check reshape - auto reshape1 = try_cast_as_op(opr); - if (reshape1 == nullptr) - return false; - if (!uniq_reader_check(reshape1->input(0))) - return false; - - // check shuffle - auto shuffle = - try_cast_as_op(opr->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 6) - return false; - bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 3 && param.pattern[3] == 4 && - param.pattern[4] == 2 && param.pattern[5] == 5 && - shuffle->input(0)->shape()[5] == 4 && - shuffle->input(0)->shape()[2] == 8; - bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 4 && param.pattern[3] == 2 && - param.pattern[4] == 3 && param.pattern[5] == 5 && - shuffle->input(0)->shape()[4] == 8 && - shuffle->input(0)->shape()[5] == 4; - if (!is_nchw42nchw32 && 
!is_nchw322nchw4) - return false; - if (!uniq_reader_check(shuffle->input(0))) - return false; - - // check reshape - auto reshape2 = - try_cast_as_op(shuffle->input(0)->owner_opr()); - if (reshape2 == nullptr) - return false; - auto inp_var = rewriter.get_var(reshape2->input(0)); - TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4 - : TensorFormat::NCHW32, - out_format = is_nchw42nchw32 ? TensorFormat::NCHW32 - : TensorFormat::NCHW4; - auto abstract_shuffle = - AbstractShuffleOpr::make(inp_var, inp_format, out_format); - std::string reformat_type = - is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4"; - rewriter.replace_var(opr->output(0), abstract_shuffle.node(), - mgb_cstr_log(ssprintf("replace reformat(%s) to " - "AbstractShuffleOpr(%s).", - reformat_type.c_str(), - reformat_type.c_str()) - .c_str())); - return true; - }; - - auto try_shuffle_reshape = [&rewriter, - &uniq_reader_check](OperatorNodeBase* opr) { - // check reshape - auto reshape = try_cast_as_op(opr); - if (reshape == nullptr) - return false; - if (!uniq_reader_check(reshape->input(0))) - return false; - - // check shuffle - auto shuffle = - try_cast_as_op(opr->input(0)->owner_opr()); - if (shuffle == nullptr) - return false; - auto&& param = shuffle->param(); - if (param.pattern_len != 5) - return false; - bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && - param.pattern[2] == 4 && param.pattern[3] == 2 && - param.pattern[4] == 3 && - shuffle->input(0)->shape()[4] == 4; - if (!is_nchw42nchw) - return false; - auto inp_var = rewriter.get_var(shuffle->input(0)); - auto abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW4, TensorFormat::NCHW); - rewriter.replace_var( - opr->output(0), abstract_shuffle.node(), - mgb_cstr_log("replace reformat(nchw4 -> nchw) to " - "AbstractShuffleOpr(nchw4 -> nchw).")); - return true; - }; - - auto try_relayout_format = [&rewriter](OperatorNodeBase* opr) { - // check relayout format - auto reformat = 
try_cast_as_op(opr); - if (reformat == nullptr) - return false; - auto&& param = reformat->param(); - if (param.mode != opr::RelayoutFormat::Param::Mode::CHWN4_NCHW4 && - param.mode != opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) - return false; - auto inp_var = rewriter.get_var(reformat->input(0)); - cg::SymbolVar abstract_shuffle; - if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) { - abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4); - } else { - abstract_shuffle = AbstractShuffleOpr::make( - inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4); - } - rewriter.replace_var( - opr->output(0), abstract_shuffle.node(), - mgb_cstr_log("replace reformat(nchw4 -> nchw) to " - "AbstractShuffleOpr(nchw4 -> nchw).")); - return true; - }; - - auto on_opr = [&try_reshape_shuffle, &try_shuffle_reshape, - &try_reshape_shuffle_reshape, &try_relayout_format, - &rewriter, &uniq_reader_check](OperatorNodeBase* opr) { - if (!try_reshape_shuffle_reshape(opr) && !try_reshape_shuffle(opr) && - !try_shuffle_reshape(opr) && !try_relayout_format(opr)) { - auto new_opr = rewriter.auto_replace_outputs(opr); - uniq_reader_check.update_on_opr_auto_replace(opr, new_opr); - } - }; - m_opt_state.graph().iter(on_opr); - rewriter.apply_inplace(); -} - -void ShuffleShuffleRemovePass::Impl::do_replace() { - auto rewriter = m_opt_state.graph().make_rewriter(); - auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; - ThinHashMap var2endpoint; - ThinHashSet trt_opr_inps; - SmallVector topo_order; - - auto cb = [&topo_order, &trt_opr_inps](OperatorNodeBase* opr) { - topo_order.push_back(opr); - MGB_MARK_USED_VAR(trt_opr_inps); -#if MGB_ENABLE_TENSOR_RT - if (opr->same_type()) { - for (auto&& inp : opr->input()) - trt_opr_inps.insert(inp); - } -#endif - }; - m_opt_state.graph().iter(cb); - - for (auto&& opr : reverse_adaptor(topo_order)) { - if (opr->same_type() || - opr->same_type()) { - auto find = 
var2endpoint.find(opr->output(0)); - if (find != var2endpoint.end()) { - if (uniq_reader_check(opr->output(0))) { - var2endpoint[opr->input(0)] = find->second; - } else { - var2endpoint[opr->input(0)] = opr->output(0); - } - } else { - var2endpoint[opr->input(0)] = opr->output(0); - } - } - } - - auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps, - &var2endpoint](OperatorNodeBase* opr) { - MGB_MARK_USED_VAR(trt_opr_inps); - bool cond_opr = opr->same_type() || - opr->same_type(); - if (cond_opr) { - bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0); - if (!cond_endpoint) - return; - auto cur = opr; - auto var = opr->output(0), inp_var = opr->input(0); - bool force_folding_typecvt = false; - bool first_shuffle = false; - // initialize inp_format and out_format - TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format; - megdnn::DType inp_dtype = cur->input(0)->dtype(), - out_dtype = cur->output(0)->dtype(); - SmallVector out_dtype_vec; - while (cond_opr) { - if (cur->same_type()) { - auto shuffle = try_cast_as_op(cur); - inp_format = shuffle->inp_format(); - if (!first_shuffle) { - out_format = shuffle->out_format(); - first_shuffle = true; - } - } else { - mgb_assert(cur->same_type()); - out_dtype_vec.push_back(cur->output(0)->dtype()); - } - inp_var = cur->input(0); - bool cond_reader = uniq_reader_check(inp_var); - if (!cond_reader) - break; - cur = cur->input(0)->owner_opr(); - cond_opr = cur->same_type() || - cur->same_type(); - } - std::reverse(out_dtype_vec.begin(), out_dtype_vec.end()); -#if MGB_ENABLE_TENSOR_RT - force_folding_typecvt = - inp_var->owner_opr()->same_type() || - trt_opr_inps.count(var); -#endif - auto new_var = rewriter.get_var(inp_var); - if (inp_format != out_format) { - new_var = m_reformat[std::make_pair(inp_format, out_format)]( - new_var); - } - if (force_folding_typecvt) { - inp_dtype = inp_var->dtype(); - if (inp_dtype != out_dtype) { - auto type_cvt = opr::TypeCvt::make(new_var, 
out_dtype); - new_var = type_cvt.node(); - } - } else { - if (out_dtype_vec.back() != var->dtype()) - out_dtype_vec.push_back(var->dtype()); - for (auto&& dtype : out_dtype_vec) { - auto type_cvt = opr::TypeCvt::make(new_var, dtype); - new_var = type_cvt.node(); - } - } - rewriter.replace_var( - var, new_var, - mgb_cstr_log("replace Dimshuffle and TypeCvt chain")); - } else { - auto new_opr = rewriter.auto_replace_outputs(opr); - uniq_reader_check.update_on_opr_auto_replace(opr, new_opr); - } - }; - m_opt_state.graph().iter(on_opr); - rewriter.apply_inplace(); -} - -const char* ShuffleShuffleRemovePass::name() const { - return mgb_cstr_log("shuffle shuffle remove pass"); -} - -void ShuffleShuffleRemovePass::apply(OptState& opt) const { - opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE | - VarReplaceCheckFlag::CHECK_DTYPE); - Impl{opt}; -} - -void gopt::reformat_to_chwn4_transform_dest_vars_inplace( - mgb::cg::VarNodeArray& dest_vars) { - gopt::GraphOptimizer optimizer; - optimizer.add_pass(); - optimizer.add_pass(); - optimizer.add_pass(EnableCHWN4Pass::make_chwn4_converter()); - optimizer.add_pass(); - optimizer.add_pass(); - optimizer.add_pass(); - optimizer.apply_inplace(dest_vars); -} - // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp new file mode 100644 index 00000000..e8c475e9 --- /dev/null +++ b/src/gopt/impl/tensor_reformat.cpp @@ -0,0 +1,1997 @@ +/** + * \file src/gopt/impl/tensor_reformat.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "megbrain/gopt/inference.h" +#include "megbrain/gopt/gtrans.h" +#include "megbrain/gopt/basic_arith.h" +#include "megbrain/graph/event.h" +#include "megbrain/opr/dnn/batch_norm.h" +#include "megbrain/opr/dnn/local.h" +#include "megbrain/utils/shared_set.h" +#include "megbrain/serialization/opr_shallow_copy.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/blas.h" +#include "megbrain/opr/misc.h" +#include "megbrain/opr/utility.h" +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/imgproc.h" +#include "megbrain/opr/nn_int.h" + +#include "megdnn/tensor_format.h" + +#if MGB_ENABLE_TENSOR_RT +#include "megbrain/tensorrt/tensorrt_opr.h" +#endif + +#include "megbrain/gopt/misc.h" + +using namespace mgb; +using namespace gopt; + +/* ================ TensorReformatPass =============== */ +/*! + * \brief relayout placeholder opr + * + * RelayoutPlaceholder oprs act as the placeholders of the ComputingGraph + * during graph opt pass `TensorReformatPass`. These oprs are introduced + * into a ComputingGraph for conveniently discovering further optimization + * opportunities (such as fusing consecutive relayouts, translating into + * optimized implementations). They are canonized to have a shape infer, so + * the output's shape can be correctly deduced during the opt pass. + * + * Note that the oprs in the ComputingGraph are only used as intermediate + * representations before being translated to MegBrain oprs, so the + * oprs should not get involved in any actual computing. + */ +MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder, + cg::SingleCNOperatorNodeBase) // { +public: + //! 
relayout type of this opr + enum class LayoutType { + NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout + NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout + NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout + CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout + NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout + NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout + WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88 + //!< layout + WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to + //!< nchw88 layout + WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout + //!< to nchw88 layout + //!< the weight layout of input is nchw output is nchw88, special for + //!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8} + WEIGHT_HYBIRD_NCHW_NCHW88, + }; + + RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type); + + /*! + * \param src_var the input var + * \param layout_type tensor layout transform type of this relayout + * placeholder as described in LayoutType + */ + static SymbolVar make(VarNode* src_var, LayoutType layout_type); + + LayoutType layout_type() const { return m_layout_type; } + +private: + void init_output_static_infer_desc() override; + void scn_do_execute() override; + void init_output_comp_node() override; + const LayoutType m_layout_type; +}; + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder); + +TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder( + VarNode* src_var, LayoutType layout_type) + : Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}), + m_layout_type{layout_type} { + add_input({src_var}); + add_equivalence_component>(m_layout_type); + add_output(None)->dtype(src_var->dtype()); +} + +void TensorReformatPass::RelayoutPlaceholder::scn_do_execute() { + mgb_throw(InternalError, "RelayoutPlaceholder opr can not be executed"); +} + +void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() { + 
output(0)->comp_node(input(0)->comp_node()); +} + +void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + DepVal deps; + for (auto i : input()) + deps.push_back({i, DepType::SHAPE}); + auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { + TensorShape inp_shape = inp.val[0].shape(); + dst = inp_shape; + if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] / 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4] * 8; + } else if (layout_type() == + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] * 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4] / 8; + } else if (layout_type() == + RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); + dst[0] = inp_shape[1]; + dst[1] = inp_shape[2]; + dst[2] = inp_shape[3]; + dst[3] = inp_shape[0]; + dst[4] = inp_shape[4]; + } else if (layout_type() == + RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); + dst[0] = inp_shape[3]; + dst[1] = inp_shape[0]; + dst[2] = inp_shape[1]; + dst[3] = inp_shape[2]; + dst[4] = inp_shape[4]; + } else if (layout_type() == + RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) { + mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0); + dst.ndim = 5; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] / 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = 8; + } else if (layout_type() == + RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8); + dst.ndim = 4; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] * 8; + dst[2] = 
inp_shape[2]; + dst[3] = inp_shape[3]; + } else if (layout_type() == RelayoutPlaceholder::LayoutType:: + WEIGHT_NCHW_TO_NCHW88_DENSE) { + mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 && + inp_shape[1] % 8 == 0); + dst.ndim = 6; + dst[0] = inp_shape[0] / 8; + dst[1] = inp_shape[1] / 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = 8; + dst[5] = 8; + } else if (layout_type() == RelayoutPlaceholder::LayoutType:: + WEIGHT_NCHW_TO_NCHW88_GROUP) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 && + inp_shape[2] % 8 == 0); + dst.ndim = 7; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] / 8; + dst[2] = inp_shape[2] / 8; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4]; + dst[5] = 8; + dst[6] = 8; + } else if (layout_type() == RelayoutPlaceholder::LayoutType:: + WEIGHT_NCHW_TO_NCHW88_CHAN) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 && + inp_shape[2] == 1 && inp_shape[0] % 8 == 0); + dst.ndim = 6; + dst[0] = inp_shape[0] / 8; + dst[1] = inp_shape[1]; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4]; + dst[5] = 8; + } else { + mgb_assert( + layout_type() == + RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88); + mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0); + dst.ndim = 5; + dst[0] = inp_shape[0] / 8; + dst[1] = inp_shape[2]; + dst[2] = inp_shape[3]; + dst[3] = inp_shape[1]; + dst[4] = 8; + } + return true; + }; + mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); +} + +SymbolVar TensorReformatPass::RelayoutPlaceholder::make( + VarNode* src_var, LayoutType layout_type) { + return src_var->owner_graph() + ->insert_opr( + std::make_unique(src_var, layout_type)) + ->output(0); +} + +void TensorReformatPass::insert_pass(OptState& opt) const { + opt.set_var_replace_check_flag(m_var_replace_check_flag); + auto rewriter = opt.graph().make_rewriter(); + VarNodeArray new_inp_cache; + auto on_opr = [this, &opt, &rewriter, + &new_inp_cache](OperatorNodeBase* opr) { + 
auto it = m_opr_replace_func.find(opr->dyn_typeinfo()); + if (it != m_opr_replace_func.end()) { + auto& new_inp = new_inp_cache; + new_inp.clear(); + new_inp.reserve(opr->input().size()); + for (auto&& inp : opr->input()) { + new_inp.push_back(rewriter.get_var(inp)); + } + auto new_opr = (it->second)(opr, new_inp); + auto &&out0 = opr->output(), &&out1 = new_opr->output(); + mgb_assert(out0.size() == out1.size(), + "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu " + "dst.size=%zu", + opr->cname(), opr->dyn_typeinfo()->name, + new_opr->cname(), new_opr->dyn_typeinfo()->name, + out0.size(), out1.size()); + for (size_t i = 0; i < out0.size(); ++i) { + if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { + mgb_assert(!out1[i]->contain_flag( + VarNode::Flag::VOLATILE_CONTENT)); + auto src = out0[i]; + auto dst = out1[i]; + if (opt.graph().endpoint_contain(src)) { + // additional process on endpoint var node + dst = on_graph_endpoint_var(dst, src); + } + rewriter.replace_var(src, dst, nullptr); + } + } + } else { + rewriter.auto_replace_outputs(opr); + } + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); +} + +void TensorReformatPass::translate_pass(OptState& opt) const { + ThinHashMap> + reformat; + using LayoutType = RelayoutPlaceholder::LayoutType; + reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* { + megdnn::param::RelayoutFormat param; + param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; + auto reformat = opr::RelayoutFormat::make(inp, param); + return reformat.node(); + }; + reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* { + megdnn::param::RelayoutFormat param; + param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; + auto reformat = opr::RelayoutFormat::make(inp, param); + return reformat.node(); + }; + reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return 
x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + 
}; + auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp0); + return y1.node(); + }; + reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0), + tshp1 = opr::Concat::make( + {sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8, + cv(8), sub(3), sub(4)}, + 0), + tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3), + sub(4), cv(8), cv(8)}, + 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0), + tshp1 = opr::Concat::make( + {sub(0) / 8, sub(1), sub(2), sub(3), 
sub(4), cv(8)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0), + tshp1 = opr::Concat::make( + {sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + + auto rewriter = opt.graph().make_rewriter(); + auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) { + if (opr->same_type()) { + auto ph = try_cast_as_op(opr); + auto new_inp = rewriter.get_var(opr->input(0)); + mgb_assert(reformat.count(ph->layout_type()), + "no replace rule can be found for layout_type(%u)", + static_cast(ph->layout_type())); + auto new_var = reformat[ph->layout_type()](new_inp); + rewriter.replace_var(opr->output(0), new_var, + mgb_cstr_log("replace relayout placeholder")); + return; + } + rewriter.auto_replace_outputs(opr); + }; + opt.graph().iter(on_opr); + rewriter.apply_inplace(); +} + +void TensorReformatPass::apply(OptState& opt) const { + insert_pass(opt); + translate_pass(opt); +} + +/* ================ EnableTensorCorePass =============== */ +VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var, + VarNode* orig_var) const { + if (!orig_var->shape().eq_shape(new_var->shape())) { + return RelayoutPlaceholder::make( + new_var, + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) + .node(); + } + return new_var; +} + +std::unique_ptr +EnableTensorCorePass::make_tensorcore_converter() { + // replace rule for 
conv bias opr + auto replace_conv_bias_opr = [](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = megdnn::param::ConvBias; + using Format = Param::Format; + using Sparse = Param::Sparse; + mgb_assert(opr->input().size() == new_inp.size()); + auto& conv_bias = opr->cast_final_safe(); + if (conv_bias.param().format != Format::NCHW4 || + conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) { + size_t nr_inps = opr->input().size(); + bool shape_has_changed = false; + for (size_t i = 0; i < nr_inps; ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + shape_has_changed = true; + } + } + MGB_MARK_USED_VAR(shape_has_changed); + mgb_assert( + !shape_has_changed, + "EnableTensorCorePass assumes that the shape of inputs of" + "ConvBias operators whose output dtype is not QuantizedS8 " + "can not be changed in this opt pass"); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()), + "EnableTensorCorePass assumes that filter tensor of " + "conv_bias operator can not be changed by other operators"); + VarNode* orig_filter = opr->input(1); + auto is_nchw4 = [](TensorShape shape) -> bool { + return shape.ndim == 5 && shape[4] == 4; + }; + auto is_nchw32 = [](TensorShape shape) -> bool { + return shape.ndim == 5 && shape[4] == 32; + }; + bool can_replace_nchw32 = false; + VarNode *src = nullptr, *weight = nullptr, *bias = nullptr, + *z_inp = nullptr; + // process src tensor + if (is_nchw4(new_inp[0]->shape())) { // new input is NCHW4 layout + size_t group = 1, icpg, ocpg; + if (conv_bias.param().sparse == Sparse::DENSE) { + icpg = orig_filter->shape()[1] * 4; + ocpg = orig_filter->shape()[0]; + } else { + mgb_assert(conv_bias.param().sparse == Sparse::GROUP); + group = orig_filter->shape()[0]; + icpg = orig_filter->shape()[2]; + ocpg = orig_filter->shape()[1]; + if (icpg == 1 && ocpg == 1) { // channel wise conv + group *= 4; + } else { + 
icpg *= 4; + } + } + // nchw32 layout requires the input height and width to be at least 3 + size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; + if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && + iw >= 3) { + auto symvar = RelayoutPlaceholder::make( + new_inp[0], + RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + src = symvar.node(); + can_replace_nchw32 = true; + } else { + src = new_inp[0]; + } + } else { // new input is NCHW32 layout + mgb_assert(is_nchw32(new_inp[0]->shape())); + size_t group = 1, ocpg; + if (conv_bias.param().sparse == Sparse::DENSE) { + ocpg = orig_filter->shape()[0]; + } else { + mgb_assert(conv_bias.param().sparse == Sparse::GROUP); + size_t icpg = orig_filter->shape()[2]; + ocpg = orig_filter->shape()[1]; + if (icpg == 1 && ocpg == 1) { + group *= 4; + } else { + icpg *= 4; + } + } + size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; + if (group == 1 && ocpg % 32 == 0 && ih >= 3 && iw >= 3) { + can_replace_nchw32 = true; + src = new_inp[0]; + } else { + auto symvar = RelayoutPlaceholder::make( + new_inp[0], + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + src = symvar.node(); + } + } + // process filter tensor + if (can_replace_nchw32) { + auto symvar = RelayoutPlaceholder::make( + new_inp[1], + RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + weight = symvar.node(); + } else { + weight = new_inp[1]; + } + if (new_inp.size() == 2) { + if (can_replace_nchw32) { + auto param = conv_bias.param(); + param.format = Format::NCHW32; + auto new_opr = opr::ConvBiasForward::make( + src, weight, param, conv_bias.execution_policy(), + conv_bias.config()); + return new_opr.node()->owner_opr(); + } else { + VarNodeArray inps{src, weight}; + auto new_opr = serialization::copy_opr_shallow(*opr, inps, + opr->config()); + return new_opr; + } + } + auto process_inp = [&](VarNode* inp) -> VarNode* { + if (can_replace_nchw32) { + if (is_nchw4(inp->shape())) { + auto symvar = RelayoutPlaceholder::make( + 
inp, + RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); + return symvar.node(); + } else { + mgb_assert(is_nchw32(inp->shape())); + return inp; + } + } else { + if (is_nchw4(inp->shape())) { + return inp; + } else { + mgb_assert(is_nchw32(inp->shape())); + auto symvar = RelayoutPlaceholder::make( + inp, + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + return symvar.node(); + } + } + }; + // process bias tensor + bias = process_inp(new_inp[2]); + if (new_inp.size() == 3) { + if (can_replace_nchw32) { + auto param = conv_bias.param(); + param.format = Format::NCHW32; + auto new_opr = opr::ConvBiasForward::make( + src, weight, bias, param, conv_bias.execution_policy(), + conv_bias.config()); + return new_opr.node()->owner_opr(); + } else { + VarNodeArray inps{src, weight, bias}; + auto new_opr = serialization::copy_opr_shallow(*opr, inps, + opr->config()); + return new_opr; + } + } + // process z_inp tensor + z_inp = process_inp(new_inp[3]); + if (can_replace_nchw32) { + auto param = conv_bias.param(); + param.format = Format::NCHW32; + auto new_opr = opr::ConvBiasForward::make( + src, weight, bias, z_inp, param, + conv_bias.execution_policy(), conv_bias.config()); + return new_opr.node()->owner_opr(); + } + VarNodeArray inps{src, weight, bias, z_inp}; + auto new_opr = + serialization::copy_opr_shallow(*opr, inps, opr->config()); + return new_opr; + }; + // replace rule for elemwise like opr + // for oprs support NCHW4 and NCHW32 layout + auto replace_elemwise_like_opr = [](OperatorNodeBase* opr, + const VarNodeArray new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + size_t nr_inps = new_inp.size(); + size_t nr_shape_changed = 0; + for (size_t i = 0; i < nr_inps; ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + nr_shape_changed++; + } + } + if (nr_shape_changed) { + auto inps = new_inp; + if (nr_shape_changed >= + nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32 + for (size_t i = 0; i < nr_inps; ++i) { + if 
(opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + auto symvar = RelayoutPlaceholder::make( + new_inp[i], RelayoutPlaceholder::LayoutType:: + NCHW4_TO_NCHW32); + inps[i] = symvar.node(); + } + } + } else { // NCHW32 < NCHW4 -> use NCHW4 + for (size_t i = 0; i < nr_inps; ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + auto symvar = RelayoutPlaceholder::make( + new_inp[i], RelayoutPlaceholder::LayoutType:: + NCHW32_TO_NCHW4); + inps[i] = symvar.node(); + } + } + } + return serialization::copy_opr_shallow(*opr, inps, opr->config()); + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + // for oprs only supports NCHW4 layout + auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr, + const VarNodeArray new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray inps = new_inp; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + mgb_assert(opr->input(i)->shape().ndim == 5 && + opr->input(i)->shape()[4] == 4); + mgb_assert(new_inp[i]->shape().ndim == 5 && + new_inp[i]->shape()[4] == 32); + auto symvar = RelayoutPlaceholder::make( + new_inp[i], + RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); + inps[i] = symvar.node(); + } + } + auto new_opr = + serialization::copy_opr_shallow(*opr, inps, opr->config()); + return new_opr; + }; + auto replace_non_nchw4_opr = [](OperatorNodeBase* opr, + const VarNodeArray new_inp) { + size_t nr_inps = opr->input().size(); + bool shape_has_changed = false; + for (size_t i = 0; i < nr_inps; ++i) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + shape_has_changed = true; + } + } + mgb_assert(!shape_has_changed, + "EnableTensorCorePass assumes that inputs' shape of " + "non-nchw4 operators " + "can not be changed in this opt " + "pass"); + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + + }; + auto replace_warp_affine_opr = + [replace_inps_to_nchw4, 
replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpAffineForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe(); + if (warp.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_warp_perspective_opr = + [replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpPerspectiveForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = + opr->cast_final_safe(); + if (warp.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + using Param = opr::ResizeForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& resize = opr->cast_final_safe(); + if (resize.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_pooling_opr = [replace_non_nchw4_opr]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + using Param = opr::PoolingForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling = opr->cast_final_safe(); + if (pooling.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + size_t nr_inps = opr->input().size(); + MGB_MARK_USED_VAR(nr_inps); + mgb_assert(nr_inps == 1); + if (!opr->input(0)->shape().eq_shape(new_inp[0]->shape())) { + mgb_assert(opr->input(0)->shape().ndim == 5 && + opr->input(0)->shape()[4] == 4); + mgb_assert(new_inp[0]->shape().ndim == 5 && + 
new_inp[0]->shape()[4] == 32); + auto new_param = pooling.param(); + new_param.format = Format::NCHW32; + auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param, + opr->config()); + return new_pooling.node()->owner_opr(); + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + auto ret = std::make_unique(); + ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); + auto&& replace_func = ret->m_opr_replace_func; + replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr; + + // elemwise like + replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::ElemwiseMultiType::typeinfo()] = + replace_elemwise_like_opr; + replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr; + + // format aware + replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; + replace_func[opr::WarpPerspectiveForward::typeinfo()] = + replace_warp_perspective_opr; + replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; + + // to nchw4 + replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; + return ret; +} + +/* ================ EnableCHWN4Pass =============== */ +VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var, + VarNode* /* orig_var */) const { + if (m_varshape_changed.count(new_var)) { + return RelayoutPlaceholder::make( + new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) + .node(); + } + return new_var; +} + +std::unique_ptr EnableCHWN4Pass::make_chwn4_converter() { + auto ret = std::make_unique(); + 
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); + auto&& replace_func = ret->m_opr_replace_func; + auto&& varshape_changed = ret->m_varshape_changed; + // replace rule for conv bias opr + auto replace_conv_bias_opr = [&varshape_changed]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + using Param = megdnn::param::ConvBias; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& conv_bias = opr->cast_final_safe(); + if (conv_bias.param().format != Format::NCHW4 || + conv_bias.output(0)->dtype().enumv() != DTypeEnum::QuantizedS8) { + size_t nr_inps = new_inp.size(); + bool shape_has_changed = false; + for (size_t i = 0; i < nr_inps; ++i) { + if (varshape_changed.count(new_inp[i])) { + shape_has_changed = true; + break; + } + } + mgb_assert( + !shape_has_changed, + "EnableCHWN4Pass assumes that the shape of inputs of" + "ConvBias operators whose output dtype is not QuantizedS8 " + "can not be changed in this opt pass"); + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + mgb_assert(varshape_changed.count(new_inp[1]) == 0, + "EnableCHWN4Pass assumes that filter tensor of " + "conv_bias operator can not be changed by other operators"); + VarNode *src = nullptr, *weight = nullptr, *bias = nullptr, + *z_inp = nullptr; + // process src tensor + if (varshape_changed.count(new_inp[0]) == + 0) { // new input is NCHW4 layout + // group conv is currently not supported + auto symvar = RelayoutPlaceholder::make( + new_inp[0], + RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + src = symvar.node(); + } else { // new input is already in CHWN4 layout + src = new_inp[0]; + } + // process weight tensor + { + auto symvar = RelayoutPlaceholder::make( + new_inp[1], + RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + weight = symvar.node(); + } + if (new_inp.size() == 2) { + auto param = conv_bias.param(); + param.format = Format::CHWN4; + auto new_opr = opr::ConvBiasForward::make( + src, weight, param, 
conv_bias.execution_policy(), + conv_bias.config()); + varshape_changed.insert(new_opr.node()); + return new_opr.node()->owner_opr(); + } + auto process_inp = [&](VarNode* inp) -> VarNode* { + if (varshape_changed.count(inp) == 0) { + auto symvar = RelayoutPlaceholder::make( + inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4); + return symvar.node(); + } else { + return inp; + } + }; + // process bias tensor + bias = process_inp(new_inp[2]); + if (new_inp.size() == 3) { + auto param = conv_bias.param(); + param.format = Format::CHWN4; + auto new_opr = opr::ConvBiasForward::make( + src, weight, bias, param, conv_bias.execution_policy(), + conv_bias.config()); + varshape_changed.insert(new_opr.node()); + return new_opr.node()->owner_opr(); + } + // process z_inp tensor + z_inp = process_inp(new_inp[3]); + auto param = conv_bias.param(); + param.format = Format::CHWN4; + auto new_opr = opr::ConvBiasForward::make( + src, weight, bias, z_inp, param, conv_bias.execution_policy(), + conv_bias.config()); + varshape_changed.insert(new_opr.node()); + return new_opr.node()->owner_opr(); + }; + // replace rule for elemwise like opr + // for oprs support NCHW4 and CHWN4 layout + auto replace_elemwise_like_opr = [&varshape_changed]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + size_t nr_inps = new_inp.size(); + size_t nr_shape_changed = 0; + for (size_t i = 0; i < nr_inps; ++i) { + if (varshape_changed.count(new_inp[i])) { + nr_shape_changed++; + } + } + if (nr_shape_changed) { + auto inps = new_inp; + if (nr_shape_changed >= nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4 + for (size_t i = 0; i < nr_inps; ++i) { + if (varshape_changed.count(new_inp[i]) == 0) { + auto symvar = RelayoutPlaceholder::make( + new_inp[i], RelayoutPlaceholder::LayoutType:: + NCHW4_TO_CHWN4); + inps[i] = symvar.node(); + } + } + auto new_opr = serialization::copy_opr_shallow(*opr, inps, + opr->config()); + 
varshape_changed.insert(new_opr->output(0)); + return new_opr; + } else { // CHWN4 < NCHW4 -> use NCHW4 + for (size_t i = 0; i < nr_inps; ++i) { + if (varshape_changed.count(new_inp[i])) { + auto symvar = RelayoutPlaceholder::make( + new_inp[i], RelayoutPlaceholder::LayoutType:: + CHWN4_TO_NCHW4); + inps[i] = symvar.node(); + } + } + return serialization::copy_opr_shallow(*opr, inps, + opr->config()); + } + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + // for oprs only supports NCHW4 layout + auto replace_inps_to_nchw4 = [&varshape_changed]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray inps = new_inp; + for (size_t i = 0; i < opr->input().size(); ++i) { + if (varshape_changed.count(new_inp[i])) { + auto symvar = RelayoutPlaceholder::make( + new_inp[i], + RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4); + inps[i] = symvar.node(); + } + } + auto new_opr = + serialization::copy_opr_shallow(*opr, inps, opr->config()); + return new_opr; + }; + auto replace_non_nchw4_opr = [&varshape_changed]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + size_t nr_inps = opr->input().size(); + bool shape_has_changed = false; + for (size_t i = 0; i < nr_inps; ++i) { + if (varshape_changed.count(new_inp[i])) { + shape_has_changed = true; + } + } + mgb_assert(!shape_has_changed, + "EnableCHWN4Pass assumes that inputs' shape of " + "non-nchw4 operators " + "can not be changed in this opt " + "pass"); + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + + }; + // capture by copy to avoid use after return + auto replace_warp_affine_opr = + [replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpAffineForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = opr->cast_final_safe(); + if (warp.param().format != 
Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_warp_perspective_opr = + [replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, const VarNodeArray new_inp) { + using Param = opr::WarpPerspectiveForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& warp = + opr->cast_final_safe(); + if (warp.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + using Param = opr::ResizeForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& resize = opr->cast_final_safe(); + if (resize.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + return replace_inps_to_nchw4(opr, new_inp); + }; + auto replace_pooling_opr = [&varshape_changed, replace_non_nchw4_opr]( + OperatorNodeBase* opr, + const VarNodeArray new_inp) { + using Param = opr::PoolingForward::Param; + using Format = Param::Format; + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling = opr->cast_final_safe(); + if (pooling.param().format != Format::NCHW4) { + return replace_non_nchw4_opr(opr, new_inp); + } + size_t nr_inps = opr->input().size(); + MGB_MARK_USED_VAR(nr_inps); + mgb_assert(nr_inps == 1); + if (varshape_changed.count(new_inp[0])) { + auto new_param = pooling.param(); + new_param.format = Format::CHWN4; + auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param, + opr->config()); + varshape_changed.insert(new_pooling.node()); + return new_pooling.node()->owner_opr(); + } + return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); + }; + replace_func[opr::ConvBiasForward::typeinfo()] = replace_conv_bias_opr; + + // elemwise 
like + replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; + replace_func[opr::ElemwiseMultiType::typeinfo()] = + replace_elemwise_like_opr; + replace_func[opr::PowC::typeinfo()] = replace_elemwise_like_opr; + + // format aware + replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::WarpAffineForward::typeinfo()] = replace_warp_affine_opr; + replace_func[opr::WarpPerspectiveForward::typeinfo()] = + replace_warp_perspective_opr; + replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; + + // to nchw4 + replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Concat::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Reshape::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4; + replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4; + return ret; +} + +/* ================ EnableNchwxxPass =============== */ +VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, + VarNode* orig_var) const { + if (!orig_var->shape().eq_shape(new_var->shape())) { + return RelayoutPlaceholder::make( + new_var, RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) + .node(); + } + return new_var; +} + +std::unique_ptr EnableNchwxxPass::make_nchwxx_converter( + size_t pack_c_size) { + auto ret = std::make_unique(); + ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); + //! First is whether the conv can trans to nchwxx, second is the filter + //! 
trans mode + using RelayoutMode = RelayoutPlaceholder::LayoutType; + using TestFilterResult = std::pair; + RelayoutMode weight_to_nchwxx_mode_dense = + RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE; + RelayoutMode weight_to_nchwxx_mode_group = + RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP; + RelayoutMode weight_to_nchwxx_mode_chan = + RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN; + RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88; + RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88; + RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW; + megdnn::param::ConvBias::Format conv_bias_format = + megdnn::param::ConvBias::Format::NCHW88; + megdnn::param::Convolution::Format conv_format = + megdnn::param::ConvolutionV0::Format::NCHW88; + megdnn::param::Pooling::Format pooling_format = + megdnn::param::Pooling::Format::NCHW88; + std::string convter_pass_name = "conv_format_nchw88"; + mgb_assert(pack_c_size == static_cast(8), + "The ConvertFormatPass to nchwxx only support NCHW88 now !"); + auto test_trans_nchwxx = + [pack_c_size, weight_to_nchwxx_mode_dense, + weight_to_nchwxx_mode_group, weight_to_nchwxx_mode_chan, + hybrid_nchw_nchwxx]( + const megdnn::param::Convolution::Sparse conv_mode, + const VarNode* filter) -> TestFilterResult { + TestFilterResult ret{TransType::TRANS_NONE, {}}; + if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { + size_t IC = filter->shape()[1]; + size_t OC = filter->shape()[0]; + if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { + ret.first = TransType::TRANS_PURE_NCHWXX; + ret.second = weight_to_nchwxx_mode_dense; + } else if (IC < pack_c_size && OC % pack_c_size == 0) { + ret.first = TransType::TRANS_HYBIRD_NCHWXX; + ret.second = hybrid_nchw_nchwxx; + } + } else { + mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); + size_t group = filter->shape()[0]; + size_t ocpg = filter->shape()[1]; + size_t icpg = filter->shape()[2]; + if (icpg == 1 && ocpg == 1 && (group % 
pack_c_size == 0)) { + ret.first = TransType::TRANS_PURE_NCHWXX; + ret.second = weight_to_nchwxx_mode_chan; + } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { + ret.first = TransType::TRANS_PURE_NCHWXX; + ret.second = weight_to_nchwxx_mode_group; + } + } + return ret; + }; + auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode, + src_to_nchw_mode](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& conv_opr = opr->cast_final_safe(); + mgb_assert(conv_opr.param().format == + megdnn::param::Convolution::Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHWXX"); + auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]); + //! can not trans to nchwxx + if (is_trans.first == TransType::TRANS_NONE) { + mgb_assert(new_inp[1]->shape().ndim == 4 || + new_inp[1]->shape().ndim == 5, + "The origin filter is not NCHW mode"); + VarNodeArray temp_inp = new_inp; + //! if src is nchwxx, should RelayoutPlaceholder to nchw + if (temp_inp[0]->shape().ndim == 5) { + auto new_src = + RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); + temp_inp[0] = new_src.node(); + } + auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, + opr->config()); + return new_opr; + } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { + //! filter trans to nchwxx mode + mgb_assert(new_inp[1]->shape().ndim == 4 || + new_inp[1]->shape().ndim == 5, + "The origin filter is not NCHW mode"); + VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; + auto new_filter = + RelayoutPlaceholder::make(new_inp[1], is_trans.second); + conv_filter = new_filter.node(); + //! 
src trans to nchwxx mode + if (new_inp[0]->shape().ndim != 5) { + mgb_assert(new_inp[0]->shape().ndim == 4); + auto new_src = RelayoutPlaceholder::make(new_inp[0], + src_to_nchwxx_mode); + conv_src = new_src.node(); + } + auto new_param = conv_opr.param(); + new_param.format = conv_format; + mgb_assert(conv_src->shape().ndim == 5 && + conv_filter->shape().ndim >= 6, + "The conv src dim is not trans to nchwxx"); + auto new_conv_opr = opr::Convolution::make( + conv_src, conv_filter, new_param, + conv_opr.execution_policy(), conv_opr.config()); + OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr(); + mgb_assert(new_conv_opr.shape().ndim == 5, + "The conv dst dim is not trans to nchwxx"); + return new_opr; + } else { + mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); + VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; + auto new_filter = + RelayoutPlaceholder::make(new_inp[1], is_trans.second); + conv_filter = new_filter.node(); + mgb_assert(conv_src->shape().ndim == 4 && + conv_filter->shape().ndim == 5, + "The src or filter dim is not valid for hybrid nchwxx"); + auto new_param = conv_opr.param(); + new_param.format = conv_format; + auto new_conv_opr = opr::Convolution::make( + conv_src, conv_filter, new_param, + conv_opr.execution_policy(), conv_opr.config()); + OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr(); + mgb_assert(new_conv_opr.shape().ndim == 5, + "The conv dst dim is not trans to nchwxx"); + return new_opr; + } + }; + + auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format, + src_to_nchwxx_mode, src_to_nchw_mode]( + OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& conv_bias_opr = opr->cast_final_safe(); + mgb_assert(conv_bias_opr.param().format == + megdnn::param::ConvBias::Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHWXX"); + auto is_trans = + test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]); + //! 
can not trans to nchwxx + if (is_trans.first == TransType::TRANS_NONE) { + mgb_assert(new_inp[1]->shape().ndim == 4 || + new_inp[1]->shape().ndim == 5, + "The origin filter is not NCHW mode"); + VarNodeArray temp_inp = new_inp; + //! if src is nchwxx, should RelayoutPlaceholder to nchw + if (temp_inp[0]->shape().ndim == 5) { + auto new_src = + RelayoutPlaceholder::make(new_inp[0], src_to_nchw_mode); + temp_inp[0] = new_src.node(); + } + //! the bias is nchwxx + if (temp_inp[2]->shape().ndim == 5) { + auto new_bias = + RelayoutPlaceholder::make(new_inp[2], src_to_nchw_mode); + temp_inp[2] = new_bias.node(); + } + auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, + opr->config()); + return new_opr; + } else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { + VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], + *conv_bias_bias = new_inp[2]; + //! filter trans to nchwxx mode + mgb_assert(new_inp[1]->shape().ndim == 4 || + new_inp[1]->shape().ndim == 5, + "The origin filter is not NCHW mode"); + auto new_filter = + RelayoutPlaceholder::make(new_inp[1], is_trans.second); + conv_bias_filter = new_filter.node(); + //! src trans to nchwxx mode + if (new_inp[0]->shape().ndim != 5) { + mgb_assert(new_inp[0]->shape().ndim == 4); + auto new_src = RelayoutPlaceholder::make(new_inp[0], + src_to_nchwxx_mode); + conv_bias_src = new_src.node(); + } + //! 
bias trans to nchwxx mode, bias may be scale + if (new_inp[2]->shape().ndim == 4) { + auto new_bias = RelayoutPlaceholder::make(new_inp[2], + src_to_nchwxx_mode); + conv_bias_bias = new_bias.node(); + } + + auto new_param = conv_bias_opr.param(); + new_param.format = conv_bias_format; + mgb_assert(conv_bias_src->shape().ndim == 5 && + conv_bias_filter->shape().ndim >= 6, + "The conv_bias src dim is not trans to nchwxx"); + auto new_conv_bias_opr = opr::ConvBias::make( + conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, + conv_bias_opr.execution_policy(), conv_bias_opr.config()); + OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); + mgb_assert(new_conv_bias_opr.shape().ndim == 5, + "The conv_bias dst dim is not trans to nchwxx"); + return new_opr; + } else { + mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); + VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], + *conv_bias_bias = new_inp[2]; + auto new_filter = + RelayoutPlaceholder::make(new_inp[1], is_trans.second); + conv_bias_filter = new_filter.node(); + //! 
bias trans to nchwxx mode, bias may be scalar + if (new_inp[2]->shape().ndim == 4) { + auto new_bias = RelayoutPlaceholder::make(new_inp[2], + src_to_nchwxx_mode); + conv_bias_bias = new_bias.node(); + } + mgb_assert(conv_bias_src->shape().ndim == 4 && + conv_bias_filter->shape().ndim == 5); + mgb_assert((conv_bias_bias->shape().ndim == 5) || + conv_bias_bias->shape().is_scalar()); + auto new_param = conv_bias_opr.param(); + new_param.format = conv_bias_format; + auto new_conv_bias_opr = opr::ConvBias::make( + conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, + conv_bias_opr.execution_policy(), conv_bias_opr.config()); + OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr(); + mgb_assert(new_conv_bias_opr.shape().ndim == 5, + "The conv dst dim is not trans to nchwxx"); + return new_opr; + } + }; + + auto replace_pooling_opr = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + auto& pooling_opr = opr->cast_final_safe(); + mgb_assert(pooling_opr.param().format == + megdnn::param::Pooling::Format::NCHW, + "ConvertFormat Pass only support converting NCHW to NCHWxx"); + VarNode* inp = new_inp[0]; + //! 
if input is nchwxx + if (inp->shape().ndim == 5) { + auto new_param = pooling_opr.param(); + new_param.format = pooling_format; + auto new_pooling_opr = + opr::PoolingForward::make(inp, new_param, opr->config()); + mgb_assert(new_pooling_opr.shape().ndim == 5, + "The pooling dst dim is not trans to nchwxx"); + return new_pooling_opr.node()->owner_opr(); + } else { + auto new_opr = serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + return new_opr; + } + }; + + auto replace_elemwise_opr = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + bool has_inp_changed = false; + for (size_t i = 0; i < opr->input().size(); i++) { + if (new_inp[i]->shape().ndim == 5) { + has_inp_changed = true; + break; + } + } + if (has_inp_changed) { + auto temp_inp = new_inp; + for (size_t i = 0; i < opr->input().size(); i++) { + if (new_inp[i]->shape().ndim == 4) { + auto new_var = RelayoutPlaceholder::make( + new_inp[i], src_to_nchwxx_mode); + temp_inp[i] = new_var.node(); + } else { + mgb_assert((new_inp[i]->shape().ndim == 5) || + new_inp[i]->shape().is_scalar()); + } + } + return serialization::copy_opr_shallow(*opr, temp_inp, + opr->config()); + } else { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } + }; + + auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, + const VarNodeArray& new_inp) { + mgb_assert(opr->input().size() == new_inp.size()); + VarNodeArray temp_inp = new_inp; + for (size_t i = 0; i < opr->input().size(); i++) { + if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { + mgb_assert(opr->input(i)->shape().ndim == 4); + mgb_assert(new_inp[i]->shape().ndim == 5); + auto new_var = + RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); + temp_inp[i] = new_var.node(); + } + } + return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); + }; + + ret->set_name(convter_pass_name); + auto&& replace_func = ret->m_opr_replace_func; + //! 
supportted nchwxx + replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; + replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; + replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; + replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr; + replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr; + replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr; + replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr; + //! not support yet + replace_func[opr::ConvolutionBackwardData::typeinfo()] = + relayout_inp_to_nchw; + replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::Concat::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::GetVarShape::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::Dimshuffle::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::Reduce::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::ResizeForward::typeinfo()] = relayout_inp_to_nchw; + replace_func[opr::WarpPerspectiveForward::typeinfo()] = + relayout_inp_to_nchw; + replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; + return ret; +} + +/* ==================== ShuffleShuffleRemovePass ================= */ +class ShuffleShuffleRemovePass::Impl { + using TensorFormat = opr::ConvBias::Param::Format; + + OptState& m_opt_state; + ThinHashMap, + thin_function> + m_reformat; + + class AbstractShuffleOpr; + + void detect_shuffle_operations(); + void do_replace(); + +public: + Impl(OptState& opt_state) : m_opt_state{opt_state} { + m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = 
opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = opr::Concat::make( + {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); + auto y0 = opr::Reshape::make(x, tshp); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); + return y1.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = opr::Concat::make( + {sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0); + auto y0 = opr::Reshape::make(x, tshp); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2}); + return y1.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] = + [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = + opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0); + auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp); + return y1.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] = + [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = + opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0); + auto y0 = 
opr::Dimshuffle::make(x, {0, 1, 4, 2, 3}); + auto y1 = opr::Reshape::make(y0, tshp); + return y1.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] = + [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, + 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] = + [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp0 = opr::Concat::make( + {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, + 0), + tshp1 = opr::Concat::make( + {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0); + auto y0 = opr::Reshape::make(x, tshp0); + auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5}); + auto y2 = opr::Reshape::make(y1, tshp1); + return y2.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] = + [](VarNode* inp) -> VarNode* { + megdnn::param::RelayoutFormat param; + param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4; + auto reformat = opr::RelayoutFormat::make(inp, param); + return reformat.node(); + + }; + + m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] = + [](VarNode* inp) -> VarNode* { + 
megdnn::param::RelayoutFormat param; + param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4; + auto reformat = opr::RelayoutFormat::make(inp, param); + return reformat.node(); + }; + + m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] = + [](VarNode* inp) -> VarNode* { + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = opr::Concat::make( + {sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0); + auto y0 = opr::Reshape::make(x, tshp); + auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2}); + return y1.node(); + + }; + + m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] = + [](VarNode* inp) -> VarNode* { + mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4); + auto x = SymbolVar(inp); + auto xshp = opr::GetVarShape::make(x); + + auto cv = [&x](int v) { return x.make_scalar(v); }; + auto sub = [&xshp, &cv](int idx) { + return opr::IndexAt::make(xshp, {{0, cv(idx)}}); + }; + auto tshp = + opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0); + auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2}); + auto y1 = opr::Reshape::make(y0, tshp); + return y1.node(); + }; + detect_shuffle_operations(); + do_replace(); + } +}; + +/*! 
+ * \brief abstract operator representation of shuffle operation + */ +MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr, + cg::SingleCNOperatorNodeBase) // { +public: + AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); + + static SymbolVar make(VarNode* inpvar, TensorFormat inp_format, + TensorFormat out_format); + + TensorFormat inp_format() const { return m_inp_format; } + + TensorFormat out_format() const { return m_out_format; } + +private: + void init_output_static_infer_desc() override; + void scn_do_execute() override; + const TensorFormat m_inp_format; + const TensorFormat m_out_format; +}; + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr); + +void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() { + mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed"); +} + +void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr:: + init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + DepVal deps; + for (auto i : input()) + deps.push_back({i, DepType::SHAPE}); + auto infer_shape = [this](TensorShape& dst, const InpVal& inp) { + TensorShape inp_shape = inp.val[0].shape(); + if (m_inp_format == TensorFormat::NCHW4 && + m_out_format == TensorFormat::NCHW32) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); + dst = inp_shape; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] / 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4] * 8; + } else if (m_inp_format == TensorFormat::NCHW32 && + m_out_format == TensorFormat::NCHW4) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32); + dst = inp_shape; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] * 8; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = inp_shape[4] / 8; + } else if (m_inp_format == TensorFormat::NCHW && + m_out_format == TensorFormat::NCHW4) { + mgb_assert(inp_shape.ndim 
== 4); + dst.ndim = 5; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] / 4; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + dst[4] = 4; + } else if (m_inp_format == TensorFormat::NCHW4 && + m_out_format == TensorFormat::NCHW) { + mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4); + dst.ndim = 4; + dst[0] = inp_shape[0]; + dst[1] = inp_shape[1] * 4; + dst[2] = inp_shape[2]; + dst[3] = inp_shape[3]; + } else if (m_inp_format == TensorFormat::NCHW4 && + m_out_format == TensorFormat::CHWN4) { + dst.ndim = 5; + dst[0] = inp_shape[1]; + dst[1] = inp_shape[2]; + dst[2] = inp_shape[3]; + dst[3] = inp_shape[0]; + dst[4] = inp_shape[4]; + } else if (m_inp_format == TensorFormat::CHWN4 && + m_out_format == TensorFormat::NCHW4) { + dst.ndim = 5; + dst[0] = inp_shape[3]; + dst[1] = inp_shape[0]; + dst[2] = inp_shape[1]; + dst[3] = inp_shape[2]; + dst[4] = inp_shape[4]; + } else { + mgb_throw(InternalError, + "Unsupported input format and output format."); + } + return true; + }; + mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape}); +} + +ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr( + VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) + : Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}), + m_inp_format{inp_format}, + m_out_format{out_format} { + add_input({inpvar}); + add_equivalence_component>(m_inp_format); + add_equivalence_component>(m_out_format); + add_output(None)->dtype(inpvar->dtype()); +} + +SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make( + VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) { + return inpvar->owner_graph() + ->insert_opr(std::make_unique( + inpvar, inp_format, out_format)) + ->output(0); +} + +void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() { + auto rewriter = m_opt_state.graph().make_rewriter(); + auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; + auto try_reshape_shuffle = [&rewriter, + 
&uniq_reader_check](OperatorNodeBase* opr) { + // check shuffle + auto shuffle = try_cast_as_op(opr); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 5) + return false; + bool is_nchw2nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 3 && param.pattern[3] == 4 && + param.pattern[4] == 2 && + opr->output(0)->shape()[4] == 4; + if (!is_nchw2nchw4) + return false; + if (!uniq_reader_check(shuffle->input(0))) + return false; + + // check reshape + auto reshape = try_cast_as_op(opr->input(0)->owner_opr()); + if (reshape == nullptr) + return false; + auto inp_var = rewriter.get_var(reshape->input(0)); + auto abstract_shuffle = AbstractShuffleOpr::make( + inp_var, TensorFormat::NCHW, TensorFormat::NCHW4); + rewriter.replace_var( + opr->output(0), abstract_shuffle.node(), + mgb_cstr_log("replace reformat(nchw -> nchw4) to " + "AbstractShuffleOpr(nchw -> nchw4).")); + return true; + }; + + auto try_reshape_shuffle_reshape = [&rewriter, &uniq_reader_check]( + OperatorNodeBase* opr) { + // check reshape + auto reshape1 = try_cast_as_op(opr); + if (reshape1 == nullptr) + return false; + if (!uniq_reader_check(reshape1->input(0))) + return false; + + // check shuffle + auto shuffle = + try_cast_as_op(opr->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 6) + return false; + bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 3 && param.pattern[3] == 4 && + param.pattern[4] == 2 && param.pattern[5] == 5 && + shuffle->input(0)->shape()[5] == 4 && + shuffle->input(0)->shape()[2] == 8; + bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && param.pattern[5] == 5 && + shuffle->input(0)->shape()[4] == 8 && + shuffle->input(0)->shape()[5] == 4; + if (!is_nchw42nchw32 && 
!is_nchw322nchw4) + return false; + if (!uniq_reader_check(shuffle->input(0))) + return false; + + // check reshape + auto reshape2 = + try_cast_as_op(shuffle->input(0)->owner_opr()); + if (reshape2 == nullptr) + return false; + auto inp_var = rewriter.get_var(reshape2->input(0)); + TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4 + : TensorFormat::NCHW32, + out_format = is_nchw42nchw32 ? TensorFormat::NCHW32 + : TensorFormat::NCHW4; + auto abstract_shuffle = + AbstractShuffleOpr::make(inp_var, inp_format, out_format); + std::string reformat_type = + is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4"; + rewriter.replace_var(opr->output(0), abstract_shuffle.node(), + mgb_cstr_log(ssprintf("replace reformat(%s) to " + "AbstractShuffleOpr(%s).", + reformat_type.c_str(), + reformat_type.c_str()) + .c_str())); + return true; + }; + + auto try_shuffle_reshape = [&rewriter, + &uniq_reader_check](OperatorNodeBase* opr) { + // check reshape + auto reshape = try_cast_as_op(opr); + if (reshape == nullptr) + return false; + if (!uniq_reader_check(reshape->input(0))) + return false; + + // check shuffle + auto shuffle = + try_cast_as_op(opr->input(0)->owner_opr()); + if (shuffle == nullptr) + return false; + auto&& param = shuffle->param(); + if (param.pattern_len != 5) + return false; + bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && + param.pattern[2] == 4 && param.pattern[3] == 2 && + param.pattern[4] == 3 && + shuffle->input(0)->shape()[4] == 4; + if (!is_nchw42nchw) + return false; + auto inp_var = rewriter.get_var(shuffle->input(0)); + auto abstract_shuffle = AbstractShuffleOpr::make( + inp_var, TensorFormat::NCHW4, TensorFormat::NCHW); + rewriter.replace_var( + opr->output(0), abstract_shuffle.node(), + mgb_cstr_log("replace reformat(nchw4 -> nchw) to " + "AbstractShuffleOpr(nchw4 -> nchw).")); + return true; + }; + + auto try_relayout_format = [&rewriter](OperatorNodeBase* opr) { + // check relayout format + auto reformat = 
try_cast_as_op(opr); + if (reformat == nullptr) + return false; + auto&& param = reformat->param(); + if (param.mode != opr::RelayoutFormat::Param::Mode::CHWN4_NCHW4 && + param.mode != opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) + return false; + auto inp_var = rewriter.get_var(reformat->input(0)); + cg::SymbolVar abstract_shuffle; + if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) { + abstract_shuffle = AbstractShuffleOpr::make( + inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4); + } else { + abstract_shuffle = AbstractShuffleOpr::make( + inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4); + } + rewriter.replace_var( + opr->output(0), abstract_shuffle.node(), + mgb_cstr_log("replace reformat(nchw4 -> nchw) to " + "AbstractShuffleOpr(nchw4 -> nchw).")); + return true; + }; + + auto on_opr = [&try_reshape_shuffle, &try_shuffle_reshape, + &try_reshape_shuffle_reshape, &try_relayout_format, + &rewriter, &uniq_reader_check](OperatorNodeBase* opr) { + if (!try_reshape_shuffle_reshape(opr) && !try_reshape_shuffle(opr) && + !try_shuffle_reshape(opr) && !try_relayout_format(opr)) { + auto new_opr = rewriter.auto_replace_outputs(opr); + uniq_reader_check.update_on_opr_auto_replace(opr, new_opr); + } + }; + m_opt_state.graph().iter(on_opr); + rewriter.apply_inplace(); +} + +void ShuffleShuffleRemovePass::Impl::do_replace() { + auto rewriter = m_opt_state.graph().make_rewriter(); + auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()}; + ThinHashMap var2endpoint; + ThinHashSet trt_opr_inps; + SmallVector topo_order; + + auto cb = [&topo_order, &trt_opr_inps](OperatorNodeBase* opr) { + topo_order.push_back(opr); + MGB_MARK_USED_VAR(trt_opr_inps); +#if MGB_ENABLE_TENSOR_RT + if (opr->same_type()) { + for (auto&& inp : opr->input()) + trt_opr_inps.insert(inp); + } +#endif + }; + m_opt_state.graph().iter(cb); + + for (auto&& opr : reverse_adaptor(topo_order)) { + if (opr->same_type() || + opr->same_type()) { + auto find = 
var2endpoint.find(opr->output(0)); + if (find != var2endpoint.end()) { + if (uniq_reader_check(opr->output(0))) { + var2endpoint[opr->input(0)] = find->second; + } else { + var2endpoint[opr->input(0)] = opr->output(0); + } + } else { + var2endpoint[opr->input(0)] = opr->output(0); + } + } + } + + auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps, + &var2endpoint](OperatorNodeBase* opr) { + MGB_MARK_USED_VAR(trt_opr_inps); + bool cond_opr = opr->same_type() || + opr->same_type(); + if (cond_opr) { + bool cond_endpoint = var2endpoint[opr->input(0)] == opr->output(0); + if (!cond_endpoint) + return; + auto cur = opr; + auto var = opr->output(0), inp_var = opr->input(0); + bool force_folding_typecvt = false; + bool first_shuffle = false; + // initialize inp_format and out_format + TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format; + megdnn::DType inp_dtype = cur->input(0)->dtype(), + out_dtype = cur->output(0)->dtype(); + SmallVector out_dtype_vec; + while (cond_opr) { + if (cur->same_type()) { + auto shuffle = try_cast_as_op(cur); + inp_format = shuffle->inp_format(); + if (!first_shuffle) { + out_format = shuffle->out_format(); + first_shuffle = true; + } + } else { + mgb_assert(cur->same_type()); + out_dtype_vec.push_back(cur->output(0)->dtype()); + } + inp_var = cur->input(0); + bool cond_reader = uniq_reader_check(inp_var); + if (!cond_reader) + break; + cur = cur->input(0)->owner_opr(); + cond_opr = cur->same_type() || + cur->same_type(); + } + std::reverse(out_dtype_vec.begin(), out_dtype_vec.end()); +#if MGB_ENABLE_TENSOR_RT + force_folding_typecvt = + inp_var->owner_opr()->same_type() || + trt_opr_inps.count(var); +#endif + auto new_var = rewriter.get_var(inp_var); + if (inp_format != out_format) { + new_var = m_reformat[std::make_pair(inp_format, out_format)]( + new_var); + } + if (force_folding_typecvt) { + inp_dtype = inp_var->dtype(); + if (inp_dtype != out_dtype) { + auto type_cvt = opr::TypeCvt::make(new_var, 
out_dtype); + new_var = type_cvt.node(); + } + } else { + if (out_dtype_vec.back() != var->dtype()) + out_dtype_vec.push_back(var->dtype()); + for (auto&& dtype : out_dtype_vec) { + auto type_cvt = opr::TypeCvt::make(new_var, dtype); + new_var = type_cvt.node(); + } + } + rewriter.replace_var( + var, new_var, + mgb_cstr_log("replace Dimshuffle and TypeCvt chain")); + } else { + auto new_opr = rewriter.auto_replace_outputs(opr); + uniq_reader_check.update_on_opr_auto_replace(opr, new_opr); + } + }; + m_opt_state.graph().iter(on_opr); + rewriter.apply_inplace(); +} + +const char* ShuffleShuffleRemovePass::name() const { + return mgb_cstr_log("shuffle shuffle remove pass"); +} + +void ShuffleShuffleRemovePass::apply(OptState& opt) const { + opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE | + VarReplaceCheckFlag::CHECK_DTYPE); + Impl{opt}; +} + +void gopt::reformat_to_chwn4_transform_dest_vars_inplace( + mgb::cg::VarNodeArray& dest_vars) { + gopt::GraphOptimizer optimizer; + optimizer.add_pass(); + optimizer.add_pass(); + optimizer.add_pass(EnableCHWN4Pass::make_chwn4_converter()); + optimizer.add_pass(); + optimizer.add_pass(); + optimizer.add_pass(); + optimizer.apply_inplace(dest_vars); +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}