/**
 * \file src/gopt/impl/basic_arith/chain.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/opr/basic_arith_wrapper.h"

#include <algorithm>
#include <deque>

//! TODO: some megdnn oprs have to be known here when midout.h is produced;
//! fix this if there is a more graceful way.
#include "megdnn/oprs.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_chain)
#define MIDOUT_B(tag) \
    MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;
using namespace gopt;
using namespace opr;

#define FOREACH_FUSE_ADD_MODE(cb) \
    cb(RELU) cb(SIGMOID) cb(TANH) cb(H_SWISH)

namespace {
    //! calls process_chain() when a chain of the same mode is detected
    class ElemChainImplHelper {
        void on_opr(OperatorNodeBase *opr);

        protected:
            using Mode = Elemwise::Mode;

            OptState &m_opt_state;
            SubGraph::Rewriter m_rewriter;
            UniqReaderCheck m_uniq_reader_check;

            ElemChainImplHelper(OptState &opt_state):
                m_opt_state{opt_state},
                m_rewriter{opt_state.graph().make_rewriter()},
                m_uniq_reader_check{opt_state.graph()}
            {
            }

            ~ElemChainImplHelper() = default;

            void run_elem_chain() {
                using namespace std::placeholders;
                m_opt_state.graph().iter(
                        std::bind(&ElemChainImplHelper::on_opr, this, _1));
                m_rewriter.apply_inplace();
            }

            //! called when an opr on the original graph is visited
            virtual void on_opr_visited(OperatorNodeBase *opr) {
                MGB_MARK_USED_VAR(opr);
            }

            //! called when a chain of the same mode on the original graph is
            //! detected
            virtual void process_chain(VarNode *endpoint, Mode mode) = 0;

            /*!
             * \brief called at the end of visiting an operator
             * \return whether this opr should be further processed by
             *      process_chain() if it is an endpoint
             */
            virtual bool on_opr_visit_finished(Elemwise *opr) {
                MGB_MARK_USED_VAR(opr);
                return true;
            }

            //! check whether a mode should be processed
            virtual bool check_mode(Mode mode) = 0;

            VarNodeArray extract_chain_terms(VarNode *endpoint, Mode mode);
    };
}

void ElemChainImplHelper::on_opr(OperatorNodeBase *opr) {
    m_uniq_reader_check.update_on_opr_auto_replace(
            opr, m_rewriter.auto_replace_outputs(opr));
    on_opr_visited(opr);
    auto elem = try_cast_as_op<Elemwise>(opr);
    Mode mode = elem ? elem->param().mode : Mode::NEGATE;
    bool inp_changed = false;
    for (auto i: opr->input()) {
        if (m_rewriter.has_manual_replace(i)) {
            inp_changed = true;
            continue;
        }
        auto ielem = try_cast_as_op<Elemwise>(i->owner_opr());
        if (ielem) {
            auto imode = ielem->param().mode;
            // Ensure that all leaves (chain terms) found by
            // extract_chain_terms() have already been processed; in other
            // words, process_chain() is called in topological order.
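            // e.g. in ((a + b) + c) * d, the ADD chain ending at (a + b) + c
            // is processed here while visiting the MUL opr, because the
            // input's mode (ADD) differs from the current opr's mode (MUL)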
            if ((!elem || imode != mode || !m_uniq_reader_check(i)) &&
                    check_mode(imode)) {
                inp_changed = true;
                m_opt_state.call_with_opr(i->owner_opr(),
                        [&]{this->process_chain(i, imode);});
            }
        }
    }

    if (inp_changed) {
        m_uniq_reader_check.update_on_opr_auto_replace(
                opr, m_rewriter.auto_replace_outputs(opr));
    }

    if (elem && on_opr_visit_finished(elem)) {
        auto ovar = opr->output(0);
        if (check_mode(mode) && m_opt_state.graph().endpoint_contain(ovar))
            process_chain(ovar, mode);
    }
}

VarNodeArray ElemChainImplHelper::extract_chain_terms(
        VarNode *endpoint, Mode mode) {
    auto pred = [mode, this, eo=endpoint->owner_opr()](OperatorNodeBase *opr) {
        return as_elem_opr(opr, mode) && (
                opr == eo || m_uniq_reader_check(opr->output(0)));
    };
    auto ret = extract_opr_leaves(endpoint, pred);
    mgb_assert(!ret.empty());
    return ret;
}

/* ================ ExpandFusedArithPass ================ */

const char* ExpandFusedArithPass::name() const {
    return mgb_cstr_log("expand_fused_arith");
}

void ExpandFusedArithPass::apply(OptState &opt) const {
    MIDOUT_B("ExpandFusedArithPass::apply")
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase *opr) {
        using Mode = Elemwise::Mode;
        auto repl_opr = rewriter.auto_replace_outputs(opr);
        auto elem = try_cast_as_op<Elemwise>(opr);
        if (elem) {
            auto src = opr->output(0);
            opr = repl_opr;
            SymbolVar out;
            const char *msg = nullptr;
            switch (elem->param().mode) {
                case Mode::FUSE_MUL_ADD3:
                    out = SymbolVar{opr->input(0)} * opr->input(1) +
                          opr->input(2);
                    msg = mgb_cstr_log("expand fma3");
                    break;
                case Mode::FUSE_MUL_ADD4:
                    out = SymbolVar{opr->input(0)} * opr->input(1) +
                          SymbolVar{opr->input(2)} * opr->input(3);
                    msg = mgb_cstr_log("expand fma4");
                    break;
#define cb(m)                                                           \
                case Mode::FUSE_ADD_##m:                                \
                    out = opr::Elemwise::make(                          \
                            {opr::add(opr->input(0), opr->input(1))},   \
                            Mode::m);                                   \
                    msg = mgb_cstr_log("expand FUSE_ADD_" #m);          \
                    break;
                FOREACH_FUSE_ADD_MODE(cb)
#undef cb
                default:
                    break;
            }
            if (auto dst = out.node()) {
                rewriter.replace_var(src, dst, msg);
            }
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}

/* ================ NormalizeArithChainPass ================ */

class NormalizeArithChainPass::Impl {
    using Mode = Elemwise::Mode;

    struct Var2CoeffRec {
        dt_max_float coeff;
        size_t order = 0;

        bool operator < (const Var2CoeffRec &rhs) const {
            return order < rhs.order;
        }
    };

    OptState &m_opt_state;
    SubGraph::Rewriter m_rewriter;
    ThinHashMap<VarNode*, size_t> m_var2nr_val_dep;
    ThinHashSet<VarNode*> m_processed_vars;

    //! passed from process_opr_chain() to sum_var2coeff()
    ThinHashMap<VarNode*, Var2CoeffRec> m_var2coeff;
    //! tmp var used by sum_var2coeff()
    std::vector<std::pair<Var2CoeffRec, VarNode*>> m_var2coeff_sort;

    void sort_var2coeff() {
        auto &&sorted = m_var2coeff_sort;
        sorted.clear();
        for (auto &&i: m_var2coeff)
            sorted.push_back({i.second, i.first});
        std::sort(sorted.begin(), sorted.end());
    }

    //! abstract operator representation
    struct AbstractOpr {
        enum class Type {
            ADD, SUB, COEFF
        };
        Type type;

        //! inputs for ADD/SUB
        VarNode *i0 = nullptr, *i1 = nullptr;

        //! input var for COEFF
        VarNode *ic;

        //! coeff multiplier value
        dt_max_float coeff;
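        //! e.g. under AddTrait (3 * x) becomes COEFF(x, 3) and (-x) becomes
        //! COEFF(x, -1); under MulTrait powf(x, 2.5f) becomes COEFF(x, 2.5)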

        static AbstractOpr make_coeff(VarNode *ic, float coeff) {
            return {Type::COEFF, nullptr, nullptr, ic, coeff};
        }

        template <class Trait>
        static Maybe<AbstractOpr> from(VarNode* var);
    };

    struct AddTrait {
        static constexpr Mode ADD = Mode::ADD, SUB = Mode::SUB;
        static constexpr float UNIT = 0;

        static Maybe<AbstractOpr> extract_coeff(Mode mode, Elemwise *opr);

        static Maybe<AbstractOpr> extract_from_non_elemwise(
                OperatorNodeBase*) {
            return None;
        }

        static SymbolVar neg(SymbolVar x) {
            return -x;
        }

        static SymbolVar make_term(SymbolVar x, dt_max_float coeff) {
            return x * x.make_scalar_dt(coeff);
        }
    };

    struct MulTrait {
        static constexpr Mode ADD = Mode::MUL, SUB = Mode::TRUE_DIV;
        static constexpr float UNIT = 1;

        static Maybe<AbstractOpr> extract_coeff(Mode mode, Elemwise *opr);

        static Maybe<AbstractOpr> extract_from_non_elemwise(
                OperatorNodeBase* opr);

        static SymbolVar neg(SymbolVar x) {
            return opr::powf(x, -1);
        }

        static SymbolVar make_term(SymbolVar x, dt_max_float coeff) {
            return opr::powf(x, coeff);
        }
    };

    struct QueueNode {
        dt_max_float coeff;
        VarNode *var;
    };

    //! sum up the terms collected in m_var2coeff
    template <class Trait>
    VarNode* sum_var2coeff();

    template <class Trait>
    void process_opr_chain(VarNode* endpoint);

    void on_opr(OperatorNodeBase *opr);

    public:
        Impl(OptState &opt_state):
            m_opt_state{opt_state},
            m_rewriter{opt_state.graph().make_rewriter()},
            m_var2nr_val_dep{opt_state.graph().get_var2nr_val_dep_oprs()}
        {
            using namespace std::placeholders;
            opt_state.graph().iter(std::bind(&Impl::on_opr, this, _1));
            m_rewriter.apply_inplace();
        }
};

Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::AddTrait::extract_coeff(
        Mode mode, Elemwise *opr) {
    if (mode == Mode::NEGATE)
        return AbstractOpr::make_coeff(opr->input(0), -1);

    if (mode == Mode::MUL) {
        SymbolVar i0 = opr->input(0), i1 = opr->input(1);
        auto i0v = i0.as_immutable_scalar_require_shape();
        if (!i0v.valid()) {
            std::swap(i0, i1);
            i0v = i0.as_immutable_scalar_require_shape();
            if (!i0v.valid())
                return None;
        }
        return AbstractOpr::make_coeff(
                i1.node(), i0v->get_cast<dt_max_float>());
    }
    return None;
}

Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::MulTrait::extract_coeff(
        Mode mode, Elemwise *opr) {
    if (mode != Mode::POW)
        return None;
    auto exp = SymbolVar{opr->input(1)}.as_immutable_scalar_require_shape();
    if (exp.valid()) {
        return AbstractOpr::make_coeff(
                opr->input(0), exp->get_cast<dt_max_float>());
    }
    return None;
}

Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::MulTrait::extract_from_non_elemwise(
        OperatorNodeBase* opr) {
    if (auto powc = try_cast_as_op<PowC>(opr)) {
        return AbstractOpr::make_coeff(powc->input(0), powc->param().exp);
    }
    return None;
}
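// AbstractOpr::from() recognizes binary ADD/SUB (or MUL/TRUE_DIV) oprs as
// chain links and folds scalar multipliers / exponents into COEFF records,
// e.g. (2 * x) -> COEFF(x, 2) for AddTrait and powf(x, 3.f) -> COEFF(x, 3)
// for MulTrait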
template <class Trait>
Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::AbstractOpr::from(VarNode* var) {
    auto opr = var->owner_opr();
    auto non_elem_ret = Trait::extract_from_non_elemwise(opr);
    if (non_elem_ret.valid()) {
        return non_elem_ret;
    }
    auto elem = try_cast_as_op<Elemwise>(opr);
    if (!elem)
        return None;
    auto mode = elem->param().mode;
    if (mode == Trait::ADD || mode == Trait::SUB) {
        auto type = mode == Trait::ADD ? Type::ADD : Type::SUB;
        return AbstractOpr{type, elem->input(0), elem->input(1), nullptr, 0};
    }
    return Trait::extract_coeff(mode, elem);
}

template <class Trait>
void NormalizeArithChainPass::Impl::process_opr_chain(VarNode* endpoint) {
    if (!m_processed_vars.insert(endpoint).second)
        return;

    if (std::is_same<Trait, MulTrait>::value &&
            endpoint->dtype().category() == DTypeCategory::INT) {
        // do not normalize int mul/div, since int mul/div is not a closed group
        return;
    }

    auto &&var2coeff = m_var2coeff;
    var2coeff.clear();
    std::deque<QueueNode> queue;
    bool has_non_elem_case = false;  // non-elemwise oprs should be canonized
    size_t nr_sub = 0, nr_non1_coeff = 0, nr_term = 0;
    queue.push_back({dt_max_float(1), endpoint});
    while (!queue.empty()) {
        auto qh = queue.front();
        queue.pop_front();
        VarNode* var = qh.var;
        // find leaf nodes on the original graph (without applying the
        // rewriter)
        if ((var == endpoint || m_var2nr_val_dep.at(var) <= 1) &&
                var->comp_node() == endpoint->comp_node()) {
            Maybe<AbstractOpr> aopr = AbstractOpr::from<Trait>(var);
            auto append = [&](VarNode *var, dt_max_float coeff) {
                queue.push_back({qh.coeff * coeff, var});
            };
            if (aopr.valid()) {
                auto &&val = aopr.val();
                using Type = AbstractOpr::Type;
                if (!var->owner_opr()->same_type<Elemwise>()) {
                    has_non_elem_case = true;
                }
                switch (val.type) {
                    case Type::ADD:
                        append(val.i0, 1);
                        append(val.i1, 1);
                        break;
                    case Type::SUB:
                        ++ nr_sub;
                        append(val.i0, 1);
                        append(val.i1, -1);
                        break;
                    case Type::COEFF:
                        if (val.coeff != 1)
                            ++ nr_non1_coeff;
                        append(val.ic, val.coeff);
                        break;
                    default:
                        mgb_assert(0);
                }
                continue;
            }
        }

        // var is a leaf node that can not be expanded
        ++ nr_term;
        var = m_rewriter.get_var(var);  // apply previous trans on leaf nodes
        auto &&dest = var2coeff[var];
        dest.coeff += qh.coeff;
        if (!dest.order) {
            dest.order = nr_term;
        }
    }

    if (nr_sub || nr_non1_coeff >= 2 || nr_term > var2coeff.size() ||
            has_non_elem_case) {
        auto sum = sum_var2coeff<Trait>();
        if (endpoint != sum) {
            m_rewriter.replace_var(
                    endpoint, sum,
                    ssprintf("normalize elemwise chain with %zu terms",
                             nr_term).c_str());
        }
    }
}

template <class Trait>
VarNode* NormalizeArithChainPass::Impl::sum_var2coeff() {
    sort_var2coeff();  // use another function to bypass a GCC-5 bug
    auto &&sorted = m_var2coeff_sort;
    VarNode *sum = nullptr;
    for (auto &&var_cnt_pair: sorted) {
        SymbolVar x = var_cnt_pair.second, term;
        dt_max_float coeff = var_cnt_pair.first.coeff;
        auto eq = [coeff](dt_max_float v) {
            return almost_equal(coeff, v);
        };
        if (eq(0)) {
            term = x.fill_retain_dtype(Trait::UNIT);
        } else if (eq(1)) {
            term = x;
        } else if (eq(-1)) {
            term = Trait::neg(x);
        } else {
            // note: for power 2, 2 * x is better than x + x, because 2 * x * y
            // may be reordered to 2 * y * x, and it does not seem to cause
            // other overhead
            term = Trait::make_term(x, coeff);
        }
        if (!sum) {
            sum = term.node();
        } else {
            sum = Elemwise::make({sum, term}, Trait::ADD).node();
        }
    }
    return sum;
}

void NormalizeArithChainPass::Impl::on_opr(OperatorNodeBase *opr) {
    m_rewriter.auto_replace_outputs(opr);
    using proc_fn_t = void (Impl::*)(VarNode*);
    auto dispatch_proc_fn = [](OperatorNodeBase* opr) -> proc_fn_t {
        if (auto elem = try_cast_as_op<Elemwise>(opr)) {
            auto mode = elem->param().mode;
            if (mode == Mode::ADD || mode == Mode::SUB ||
                    mode == Mode::NEGATE) {
                return &Impl::process_opr_chain<AddTrait>;
            }
            if (mode == Mode::MUL || mode == Mode::TRUE_DIV ||
                    (mode == Mode::POW &&
                     SymbolVar{opr->input(1)}
                             .as_immutable_scalar_require_shape()
                             .valid())) {
                return &Impl::process_opr_chain<MulTrait>;
            }
        }
        if (opr->same_type<PowC>()) {
            return &Impl::process_opr_chain<MulTrait>;
        }
        return nullptr;
    };
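    // e.g. a chain like (a + b) - c dispatches to AddTrait, while (x * y) / z
    // or powf(x, 2.f) dispatches to MulTrait; PowC oprs also join MUL chains
    // via MulTrait::extract_from_non_elemwise()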

    VarNode* out0 = nullptr;
    auto func_self = dispatch_proc_fn(opr);
    if (func_self) {
        out0 = opr->output(0);
    }

    bool inp_changed = false;
    for (auto i: opr->input()) {
        if (m_rewriter.has_manual_replace(i)) {
            inp_changed = true;
            continue;
        }
        auto func_in = dispatch_proc_fn(i->owner_opr());
        if (func_in && (func_in != func_self || m_var2nr_val_dep.at(i) >= 2)) {
            // note: we process starting from an endpoint of a chain of the
            // same mode (either ADD or MUL) to ensure linear time complexity.
            // An endpoint is a var that must be preserved, which is either:
            // (1) read by multiple readers, (2) read by an opr of a different
            // mode or by a non-elemwise opr, or (3) an endpoint of the whole
            // graph. Cases (1) and (2) are handled here; case (3) is handled
            // below by calling func_self().
            inp_changed = true;
            m_opt_state.call_with_opr(i->owner_opr(),
                    [&]{(this->*func_in)(i);});
        }
    }

    if (inp_changed)
        m_rewriter.auto_replace_outputs(opr);

    if (func_self && m_opt_state.graph().endpoint_contain(out0)) {
        (this->*func_self)(out0);
    }
}

const char* NormalizeArithChainPass::name() const {
    return mgb_cstr_log("normalize_arith_expr");
}

void NormalizeArithChainPass::apply(OptState &opt) const {
    MIDOUT_B("NormalizeArithChainPass::apply")
    Impl{opt};
    MIDOUT_E
}
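
// e.g. a + (a - b) is normalized into roughly (a * 2) + (-b): subtractions and
// scalar multipliers are folded into per-term coefficients by
// process_opr_chain(), and sum_var2coeff() re-emits the terms in a canonical
// order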

/* ================ ReorderArithChainPass ================ */

class ReorderArithChainPass::Impl final: public ElemChainImplHelper {
    using ShapedVars = std::vector<std::pair<TensorShape, VarNode*>>;

    ConstVarPropogate m_cvprop;
    TensorShapeArray m_tmp_inp_shp;

    //! tmp var: (shape, is_const) -> terms
    TensorShapeHashKey::Map<std::array<VarNodeArray, 2>> m_shp2terms;
    ShapedVars m_const_terms, m_nonconst_terms;

    //! reduce two terms
    static VarNode* reduce(Mode mode, VarNode *a, VarNode *b);

    //! reduce m_shp2terms into a sum var
    VarNode* reduce_shp2terms(Mode mode);

    //! merge src and dst into dst, if merging does not broadcast both
    bool merge_shape_if_compatible(const TensorShape &src, TensorShape &dst);

    //! merge terms with compatible shapes
    void merge_shaped_terms(Mode mode, ShapedVars &vars, bool allow_compatible);

    void process_chain(VarNode *endpoint, Mode mode) override;

    void on_opr_visited(OperatorNodeBase *opr) override {
        m_cvprop.add_opr(opr);
    }

    bool check_mode(Mode mode) override {
        return mode == Mode::ADD || mode == Mode::MUL ||
               mode == Mode::MAX || mode == Mode::MIN;
    }

    public:
        Impl(const ReorderArithChainPass &pass, OptState &opt_state):
            ElemChainImplHelper(opt_state),
            m_cvprop{pass.m_const_var_type}
        {
            run_elem_chain();
        }
};

VarNode* ReorderArithChainPass::Impl::reduce(
        Mode mode, VarNode *a, VarNode *b) {
    if (!a)
        return b;
    if (!b)
        return a;
    return opr::Elemwise::make({a, b}, mode).node();
}

bool ReorderArithChainPass::Impl::merge_shape_if_compatible(
        const TensorShape &src, TensorShape &dst) {
    m_tmp_inp_shp.resize(2);
    m_tmp_inp_shp[0] = src;
    m_tmp_inp_shp[1] = dst;
    TensorShape out;
    megdnn::Elemwise::deduce_shape(m_tmp_inp_shp, out);
    if (out.eq_shape(src)) {
        dst = out;
        return true;
    }
    return out.eq_shape(dst);
}

VarNode* ReorderArithChainPass::Impl::reduce_shp2terms(Mode mode) {
    // populate m_const_terms and m_nonconst_terms
    m_const_terms.clear();
    m_nonconst_terms.clear();
    for (auto &&i: m_shp2terms) {
        if (!i.second[0].empty()) {
            m_nonconst_terms.emplace_back(
                    i.first.shape(),
                    elemwise_reduce_var_list(i.second[0], mode));
        }
        if (!i.second[1].empty()) {
            m_const_terms.emplace_back(
                    i.first.shape(),
                    elemwise_reduce_var_list(i.second[1], mode));
        }
    }

    {
        // sort by id(), so the same set of input terms always gets the same
        // reduced var
        auto cmp = [](const ShapedVars::value_type &a,
                      const ShapedVars::value_type &b) {
            return a.second->id() < b.second->id();
        };
        small_sort(m_const_terms.begin(), m_const_terms.end(), cmp);
        small_sort(m_nonconst_terms.begin(), m_nonconst_terms.end(), cmp);
    }

    merge_shaped_terms(mode, m_const_terms, true);

    auto &&all_terms = m_const_terms;
    all_terms.insert(all_terms.end(),
            m_nonconst_terms.begin(), m_nonconst_terms.end());
    // merge terms with equal shapes
    merge_shaped_terms(mode, all_terms, false);
    // merge terms with compatible shapes
    merge_shaped_terms(mode, all_terms, true);

    // simple heuristic: reduce in increasing size order
    auto cmp = [](const ShapedVars::value_type &a,
                  const ShapedVars::value_type &b) {
        return a.first.total_nr_elems() < b.first.total_nr_elems();
    };
    small_sort(all_terms.begin(), all_terms.end(), cmp);

    VarNode *sum = nullptr;
    for (auto &&i: all_terms) {
        sum = reduce(mode, sum, i.second);
    }
    mgb_assert(sum);
    return sum;
}

void ReorderArithChainPass::Impl::merge_shaped_terms(
        Mode mode, ShapedVars &vars, bool allow_compatible) {
    for (bool merged = true; merged;) {
        merged = false;
        for (size_t i = 0; !merged && i < vars.size(); ++ i) {
            auto &&src = vars[i];
            if (!src.first.ndim)
                continue;
            TensorShape dst_shape;
            size_t dst_idx = -1;
            auto update_dst = [&](size_t idx, const TensorShape &shp) {
                if (!dst_shape.ndim ||
                        shp.total_nr_elems() < dst_shape.total_nr_elems()) {
                    dst_shape = shp;
                    dst_idx = idx;
                }
            };
            for (size_t j = 0; j < vars.size(); ++ j) {
                auto &&dst = vars[j];
                if (i == j || !dst.first.ndim)
                    continue;
                if (allow_compatible) {
                    auto tshp = dst.first;
                    if (merge_shape_if_compatible(src.first, tshp)) {
                        update_dst(j, tshp);
                    }
                } else {
                    if (src.first.eq_shape(dst.first)) {
                        update_dst(j, dst.first);
                    }
                }
            }
            if (dst_shape.ndim) {
                auto &&dst = vars[dst_idx];
                dst.first = dst_shape;
                dst.second = reduce(mode, src.second, dst.second);
                mgb_assert(
                        (!dst.second->shape().ndim &&
                         !cg::is_static_var_shape(dst.second)) ||
                        dst.second->shape().eq_shape(dst.first));
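                // src has been folded into dst; drop it from vars with
                // swap-and-pop and restart the scan, since dst's shape may
                // have grown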
                std::swap(src, vars.back());
                vars.pop_back();
                merged = true;
                break;
            }
        }
    }
}

void ReorderArithChainPass::Impl::process_chain(VarNode *endpoint, Mode mode) {
    if (m_cvprop.is_const(endpoint))
        return;

    auto vars = extract_chain_terms(endpoint, mode);
    if (vars.size() == 1)
        return;

    // ensure that the same set of input terms gets the same reduced var
    // TODO: consider maintaining a cache (map) of
    //       (sorted input terms -> reduced var)
    std::sort(vars.begin(), vars.end(), [](VarNode *x, VarNode *y){
        return x->id() < y->id();
    });

    m_shp2terms.clear();
    for (auto i: vars) {
        auto inew = m_rewriter.get_var(i);
        m_shp2terms[i->shape()][m_cvprop.is_const(i)].push_back(inew);
    }
    auto sum = reduce_shp2terms(mode);
    if (m_rewriter.get_var(endpoint) != sum) {
        m_rewriter.replace_var(endpoint, sum,
                mgb_ssprintf_log("reorder %zu %s terms", vars.size(),
                        megdnn::Elemwise::ModeTrait::from_mode(mode).name)
                        .c_str());
    }
}

const char* ReorderArithChainPass::name() const {
    return mgb_cstr_log("reorder_arith_chain");
}

void ReorderArithChainPass::apply(OptState &opt) const {
    MIDOUT_B("ReorderArithChainPass::apply")
    Impl{*this, opt};
    MIDOUT_E
}

/* ================ ArithFusePass ================ */

class ArithFusePass::Impl final: public ElemChainImplHelper {
    using MulTermArray = std::vector<std::pair<VarNode*, VarNode*>>;

    class SumVars;

    size_t m_nr_fma3, m_nr_fma4;
    TensorShapeHashKey::PairMap<MulTermArray> m_mul_terms;
    TensorShapeHashKey::Map<VarNodeArray> m_bias_terms;

    bool check_mode(Mode mode) override {
        return mode == Mode::ADD;
    }

    void process_chain(VarNode *endpoint, Mode mode) override;

    VarNode* find_pop_bias_term(const TensorShape &shape) {
        auto iter = m_bias_terms.find(shape);
        if (iter != m_bias_terms.end()) {
            auto ret = elemwise_reduce_var_list(iter->second, Mode::ADD);
            m_bias_terms.erase(iter);
            return ret;
        }
        return nullptr;
    }

    VarNode* process_mul_term(MulTermArray &terms);

    bool on_opr_visit_finished(Elemwise *opr) override;

    public:
        Impl(OptState &opt_state):
            ElemChainImplHelper(opt_state)
        {
            run_elem_chain();
        }
};

class ArithFusePass::Impl::SumVars {
    VarNode *m_sum = nullptr;

    public:
        void add(SymbolVar var) {
            if (!m_sum) {
                m_sum = var.node();
            } else {
                m_sum = opr::add(m_sum, var).node();
            }
        }

        VarNode* get() const {
            return m_sum;
        }
};

void ArithFusePass::Impl::process_chain(VarNode *endpoint, Mode mode) {
    if (!endpoint->shape().ndim)
        return;
    mgb_assert(mode == Mode::ADD);
    m_mul_terms.clear();
    m_bias_terms.clear();
    m_nr_fma3 = m_nr_fma4 = 0;

    auto vars = extract_chain_terms(endpoint, mode);
    for (auto var: vars) {
        auto opr = var->owner_opr();
        Elemwise *mul;
        if (m_uniq_reader_check(var) &&
                (mul = as_elem_opr(opr, Mode::MUL))) {
            auto a = mul->input(0), b = mul->input(1);
            if (a->shape().total_nr_elems() > b->shape().total_nr_elems()) {
                std::swap(a, b);
            }
            a = m_rewriter.get_var(a);
            b = m_rewriter.get_var(b);
            m_mul_terms[{a->shape(), b->shape()}].push_back({a, b});
        } else {
            var = m_rewriter.get_var(var);
            m_bias_terms[var->shape()].push_back(var);
        }
    }
    if (m_mul_terms.empty())
        return;

    // merge mul terms whose operands have the same shape, so they can be used
    // as bias by others
    for (auto i = m_mul_terms.begin(); i != m_mul_terms.end(); ) {
        auto &&s = i->first;
        if (s.first.shape().eq_shape(s.second.shape())) {
            auto merged = process_mul_term(i->second);
            mgb_assert(merged->shape().eq_shape(s.first.shape()));
            m_bias_terms[merged->shape()].push_back(merged);
            mgb_assert(i->second.empty());
            i = m_mul_terms.erase(i);
        } else {
            ++ i;
        }
    }
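
    // the remaining entries pair two distinct operand shapes; they are fused
    // pairwise into FUSE_MUL_ADD4 (and FUSE_MUL_ADD3 when a bias of a matching
    // shape is available), starting from the smallest broadcast shape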
    // sort mul_terms by size
    TensorShapeArray shp_inp(2);
    using SortedTermItem = std::pair<size_t, MulTermArray*>;
    std::vector<SortedTermItem> mul_terms_sorted;
    for (auto &&i: m_mul_terms) {
        shp_inp[0] = i.first.first.shape();
        shp_inp[1] = i.first.second.shape();
        TensorShape tshp;
        megdnn::Elemwise::deduce_shape(shp_inp, tshp);
        mul_terms_sorted.push_back({tshp.total_nr_elems(), &i.second});
    }
    auto cmp = [](const SortedTermItem &a, const SortedTermItem &b) {
        return a.first < b.first || (
                a.first == b.first && a.second->size() < b.second->size());
    };
    std::sort(mul_terms_sorted.begin(), mul_terms_sorted.end(), cmp);

    // merge from smallest to largest
    for (auto &&i: mul_terms_sorted) {
        auto merged = process_mul_term(*i.second);
        mgb_assert(i.second->empty() && merged->shape().ndim);
        m_bias_terms[merged->shape()].push_back(merged);
    }

    SumVars sum_vars;
    for (auto &&i: m_bias_terms) {
        sum_vars.add(elemwise_reduce_var_list(i.second, Mode::ADD));
    }
    auto sum = sum_vars.get();
    m_rewriter.replace_var(endpoint, sum,
            mgb_ssprintf_log("replace %zu fma3, %zu fma4",
                    m_nr_fma3, m_nr_fma4).c_str());
}

VarNode* ArithFusePass::Impl::process_mul_term(MulTermArray &terms) {
    mgb_assert(!terms.empty());
    SumVars sum_vars;
    while (terms.size() >= 2) {
        auto b = terms.back();
        terms.pop_back();
        auto a = terms.back();
        terms.pop_back();
        ++ m_nr_fma4;
        sum_vars.add(Elemwise::make(
                {a.first, a.second, b.first, b.second},
                Mode::FUSE_MUL_ADD4));
    }
    if (!terms.empty()) {
        auto t = terms.back();
        terms.pop_back();
        auto bias = find_pop_bias_term(t.first->shape());
        if (!bias)
            bias = find_pop_bias_term(t.second->shape());
        if (bias) {
            ++ m_nr_fma3;
            sum_vars.add(Elemwise::make(
                    {t.first, t.second, bias}, Mode::FUSE_MUL_ADD3));
        } else {
            sum_vars.add(opr::mul(t.first, t.second));
        }
    }
    return sum_vars.get();
}

bool ArithFusePass::Impl::on_opr_visit_finished(Elemwise *opr) {
    if (opr->input().size() != 1)
        return true;
    if (!m_uniq_reader_check(opr->input(0)))
        return true;
    auto iadd = as_elem_opr(m_rewriter.get_var(opr->input(0)), Mode::ADD);
    if (!iadd)
        return true;
    if (opr->input(0)->dtype().category() == DTypeCategory::QUANTIZED)
        return true;

    Mode fmode;
    const char *msg;
    switch (opr->param().mode) {
#define cb(m)                                           \
        case Mode::m:                                   \
            fmode = Mode::FUSE_ADD_##m;                 \
            msg = mgb_cstr_log("fuse " #m "(x + y)");   \
            break;
        FOREACH_FUSE_ADD_MODE(cb)
#undef cb
        default:
            return true;
    }

    m_opt_state.call_with_opr(opr, [&]{
        auto fused = opr::Elemwise::make(
                {iadd->input(0), iadd->input(1)}, fmode).node();
        m_rewriter.replace_var(opr->output(0), fused, msg);
        m_uniq_reader_check.update_on_opr_auto_replace(
                opr, fused->owner_opr());
    });
    return false;
}

const char* ArithFusePass::name() const {
    return mgb_cstr_log("arith_fuse");
}

void ArithFusePass::apply(OptState &opt) const {
    MIDOUT_B("ArithFusePass::apply")
    Impl{opt};
    MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}