/**
 * \file src/gopt/impl/basic_arith/chain.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/opr/basic_arith_wrapper.h"

#include <deque>

//! TODO: some megdnn oprs have to be known here when midout.h is generated;
//! fix this if there is a more graceful way.
#include "megdnn/oprs.h"

#include "megbrain/utils/hash_ct.h"
#include "midout.h"

MIDOUT_DECL(megbrain_chain)
#define MIDOUT_B(tag) \
    MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
    }            \
    MIDOUT_END();

using namespace mgb;
using namespace gopt;
using namespace opr;

#define FOREACH_FUSE_ADD_MODE(cb) \
    cb(RELU) cb(SIGMOID) cb(TANH) cb(H_SWISH)
namespace {
//! call process_chain() when a chain of the same mode is detected
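//! For example, with Mode::ADD, a var computed as ((a + b) + (c + d)) forms a
//! chain whose terms (leaves) are {a, b, c, d}; see extract_chain_terms().
//! (a, b, c, d are hypothetical vars, for illustration only)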
class ElemChainImplHelper {
    void on_opr(OperatorNodeBase *opr);

protected:
    using Mode = Elemwise::Mode;

    OptState &m_opt_state;
    SubGraph::Rewriter m_rewriter;
    UniqReaderCheck m_uniq_reader_check;

    ElemChainImplHelper(OptState &opt_state):
        m_opt_state{opt_state},
        m_rewriter{opt_state.graph().make_rewriter()},
        m_uniq_reader_check{opt_state.graph()}
    {
    }

    ~ElemChainImplHelper() = default;

    void run_elem_chain() {
        using namespace std::placeholders;
        m_opt_state.graph().iter(
                std::bind(&ElemChainImplHelper::on_opr, this, _1));
        m_rewriter.apply_inplace();
    }

    //! called when an opr on original graph is visited
    virtual void on_opr_visited(OperatorNodeBase *opr) {
        MGB_MARK_USED_VAR(opr);
    }

    //! called when a chain of same mode on original graph is detected
    virtual void process_chain(VarNode *endpoint, Mode mode) = 0;

    /*!
     * \brief called at the end of visiting an operator
     * \return whether this opr should be further processed by
     *      process_chain() if it is an endpoint
     */
    virtual bool on_opr_visit_finished(Elemwise *opr) {
        MGB_MARK_USED_VAR(opr);
        return true;
    }

    //! check whether a mode should be processed
    virtual bool check_mode(Mode mode) = 0;

    VarNodeArray extract_chain_terms(VarNode *endpoint, Mode mode);
};
}  // anonymous namespace
void ElemChainImplHelper::on_opr(OperatorNodeBase *opr) {
    m_uniq_reader_check.update_on_opr_auto_replace(
            opr, m_rewriter.auto_replace_outputs(opr));
    on_opr_visited(opr);
    auto elem = try_cast_as_op<Elemwise>(opr);
    Mode mode = elem ? elem->param().mode : Mode::NEGATE;
    bool inp_changed = false;
    for (auto i: opr->input()) {
        if (m_rewriter.has_manual_replace(i)) {
            inp_changed = true;
            continue;
        }
        auto ielem = try_cast_as_op<Elemwise>(i->owner_opr());
        if (ielem) {
            auto imode = ielem->param().mode;
            // Ensure that all leaves (chain terms) found by
            // extract_chain_terms() have already been processed; in other
            // words, process_chain() is called in topological order.
            if ((!elem || imode != mode || !m_uniq_reader_check(i))
                    && check_mode(imode)) {
                inp_changed = true;
                m_opt_state.call_with_opr(i->owner_opr(),
                        [&]{this->process_chain(i, imode);});
            }
        }
    }
    if (inp_changed) {
        m_uniq_reader_check.update_on_opr_auto_replace(
                opr, m_rewriter.auto_replace_outputs(opr));
    }
    if (elem && on_opr_visit_finished(elem)) {
        auto ovar = opr->output(0);
        if (check_mode(mode) && m_opt_state.graph().endpoint_contain(ovar))
            process_chain(ovar, mode);
    }
}
VarNodeArray ElemChainImplHelper::extract_chain_terms(
        VarNode *endpoint, Mode mode) {
    auto pred = [mode, this, eo=endpoint->owner_opr()](OperatorNodeBase *opr) {
        return as_elem_opr(opr, mode) && (
                opr == eo || m_uniq_reader_check(opr->output(0)));
    };
    auto ret = extract_opr_leaves(endpoint, pred);
    mgb_assert(!ret.empty());
    return ret;
}
/* ================ ExpandFusedArithPass ================ */
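// Rough sketch of the rewrites performed below (the FUSE_ADD_* modes listed in
// FOREACH_FUSE_ADD_MODE are all expanded analogously to the RELU case):
//     FUSE_MUL_ADD3(a, b, c)    ->  a * b + c
//     FUSE_MUL_ADD4(a, b, c, d) ->  a * b + c * d
//     FUSE_ADD_RELU(a, b)       ->  RELU(a + b)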
const char* ExpandFusedArithPass::name() const {
    return mgb_cstr_log("expand_fused_arith");
}

void ExpandFusedArithPass::apply(OptState &opt) const {
    MIDOUT_B("ExpandFusedArithPass::apply")
    auto rewriter = opt.graph().make_rewriter();
    auto on_opr = [&](OperatorNodeBase *opr) {
        using Mode = Elemwise::Mode;
        auto repl_opr = rewriter.auto_replace_outputs(opr);
        auto elem = try_cast_as_op<Elemwise>(opr);
        if (elem) {
            auto src = opr->output(0);
            opr = repl_opr;
            SymbolVar out;
            const char *msg = nullptr;
            switch (elem->param().mode) {
                case Mode::FUSE_MUL_ADD3:
                    out = SymbolVar{opr->input(0)} * opr->input(1) +
                          opr->input(2);
                    msg = mgb_cstr_log("expand fma3");
                    break;
                case Mode::FUSE_MUL_ADD4:
                    out = SymbolVar{opr->input(0)} * opr->input(1) +
                          SymbolVar{opr->input(2)} * opr->input(3);
                    msg = mgb_cstr_log("expand fma4");
                    break;
#define cb(m)                                                         \
                case Mode::FUSE_ADD_##m:                              \
                    out = opr::Elemwise::make(                        \
                            {opr::add(opr->input(0), opr->input(1))}, \
                            Mode::m);                                 \
                    msg = mgb_cstr_log("expand FUSE_ADD_" #m);        \
                    break;
                FOREACH_FUSE_ADD_MODE(cb)
#undef cb
                default:
                    break;
            }
            if (auto dst = out.node()) {
                rewriter.replace_var(src, dst, msg);
            }
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
/* ================ NormalizeArithChainPass ================ */
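// Rough sketch of the normalization (hypothetical vars; coefficients must be
// immutable scalars):
//     a - b + 2 * a    ->  a * 3 + (-b)              (AddTrait chain)
//     x / y * x        ->  powf(x, 2) * powf(y, -1)  (MulTrait chain)
// Terms are collected into m_var2coeff and re-emitted in first-seen order.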
class NormalizeArithChainPass::Impl {
    using Mode = Elemwise::Mode;

    struct Var2CoeffRec {
        dt_max_float coeff;
        size_t order = 0;

        bool operator < (const Var2CoeffRec &rhs) const {
            return order < rhs.order;
        }
    };

    OptState &m_opt_state;
    SubGraph::Rewriter m_rewriter;
    ThinHashMap<VarNode*, size_t> m_var2nr_val_dep;
    ThinHashSet<VarNode*> m_processed_vars;

    //! passed from process_opr_chain() to sum_var2coeff()
    ThinHashMap<VarNode*, Var2CoeffRec> m_var2coeff;

    //! tmp var used by sum_var2coeff()
    std::vector<std::pair<Var2CoeffRec, VarNode*>> m_var2coeff_sort;

    void sort_var2coeff() {
        auto &&sorted = m_var2coeff_sort;
        sorted.clear();
        for (auto &&i: m_var2coeff)
            sorted.push_back({i.second, i.first});
        std::sort(sorted.begin(), sorted.end());
    }

    //! abstract operator representation
    struct AbstractOpr {
        enum class Type {
            ADD, SUB, COEFF
        };

        Type type;

        //! inputs for ADD/SUB
        VarNode *i0 = nullptr, *i1 = nullptr;

        //! input var for COEFF
        VarNode *ic;

        //! coeff mul value
        dt_max_float coeff;

        static AbstractOpr make_coeff(VarNode *ic, float coeff) {
            return {Type::COEFF, nullptr, nullptr, ic, coeff};
        }

        template<class Trait>
        static Maybe<AbstractOpr> from(VarNode* var);
    };

    struct AddTrait {
        static constexpr Mode ADD = Mode::ADD, SUB = Mode::SUB;
        static constexpr float UNIT = 0;

        static Maybe<AbstractOpr> extract_coeff(Mode mode, Elemwise *opr);

        static Maybe<AbstractOpr> extract_from_non_elemwise(OperatorNodeBase*) {
            return None;
        }

        static SymbolVar neg(SymbolVar x) {
            return -x;
        }

        static SymbolVar make_term(SymbolVar x, dt_max_float coeff) {
            return x * x.make_scalar_dt(coeff);
        }
    };

    struct MulTrait {
        static constexpr Mode ADD = Mode::MUL, SUB = Mode::TRUE_DIV;
        static constexpr float UNIT = 1;

        static Maybe<AbstractOpr> extract_coeff(Mode mode, Elemwise *opr);

        static Maybe<AbstractOpr> extract_from_non_elemwise(
                OperatorNodeBase* opr);

        static SymbolVar neg(SymbolVar x) {
            return opr::powf(x, -1);
        }

        static SymbolVar make_term(SymbolVar x, dt_max_float coeff) {
            return opr::powf(x, coeff);
        }
    };

    struct QueueNode {
        dt_max_float coeff;
        VarNode *var;
    };

    //! sum m_var2coeff
    template<class Trait>
    VarNode* sum_var2coeff();

    template<class Trait>
    void process_opr_chain(VarNode* endpoint);

    void on_opr(OperatorNodeBase *opr);

public:
    Impl(OptState &opt_state):
        m_opt_state{opt_state},
        m_rewriter{opt_state.graph().make_rewriter()},
        m_var2nr_val_dep{opt_state.graph().get_var2nr_val_dep_oprs()}
    {
        using namespace std::placeholders;
        opt_state.graph().iter(std::bind(&Impl::on_opr, this, _1));
        m_rewriter.apply_inplace();
    }
};
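// Note on the traits above: AddTrait recognizes NEGATE and (scalar * x) as
// coefficient terms, while MulTrait recognizes pow(x, scalar) and PowC;
// Trait::UNIT is the identity element of the chain (0 for addition, 1 for
// multiplication), used when the collected coefficients cancel out.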
Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::AddTrait::extract_coeff(
        Mode mode, Elemwise *opr) {
    if (mode == Mode::NEGATE)
        return AbstractOpr::make_coeff(opr->input(0), -1);

    if (mode == Mode::MUL) {
        SymbolVar i0 = opr->input(0), i1 = opr->input(1);
        auto i0v = i0.as_immutable_scalar_require_shape();
        if (!i0v.valid()) {
            std::swap(i0, i1);
            i0v = i0.as_immutable_scalar_require_shape();
            if (!i0v.valid())
                return None;
        }
        return AbstractOpr::make_coeff(
                i1.node(), i0v->get_cast<dt_max_float>());
    }
    return None;
}

Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::MulTrait::extract_coeff(
        Mode mode, Elemwise *opr) {
    if (mode != Mode::POW)
        return None;
    auto exp = SymbolVar{opr->input(1)}.as_immutable_scalar_require_shape();
    if (exp.valid()) {
        return AbstractOpr::make_coeff(
                opr->input(0), exp->get_cast<dt_max_float>());
    }
    return None;
}

Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::MulTrait::extract_from_non_elemwise(
        OperatorNodeBase* opr) {
    if (auto powc = try_cast_as_op<PowC>(opr)) {
        return AbstractOpr::make_coeff(powc->input(0), powc->param().exp);
    }
    return None;
}
template <class Trait>
Maybe<NormalizeArithChainPass::Impl::AbstractOpr>
NormalizeArithChainPass::Impl::AbstractOpr::from(VarNode* var) {
    auto opr = var->owner_opr();
    auto non_elem_ret = Trait::extract_from_non_elemwise(opr);
    if (non_elem_ret.valid()) {
        return non_elem_ret;
    }
    auto elem = try_cast_as_op<Elemwise>(opr);
    if (!elem)
        return None;
    auto mode = elem->param().mode;
    if (mode == Trait::ADD || mode == Trait::SUB) {
        auto type = mode == Trait::ADD ? Type::ADD : Type::SUB;
        return AbstractOpr{type, elem->input(0), elem->input(1), nullptr, 0};
    }
    return Trait::extract_coeff(mode, elem);
}
template<class Trait>
void NormalizeArithChainPass::Impl::process_opr_chain(VarNode* endpoint) {
    if (!m_processed_vars.insert(endpoint).second)
        return;

    if (std::is_same<Trait, MulTrait>::value &&
            endpoint->dtype().category() == DTypeCategory::INT) {
        // do not normalize int mul/div, since int mul/div is not a closed group
        return;
    }

    auto &&var2coeff = m_var2coeff;
    var2coeff.clear();
    std::deque<QueueNode> queue;
    bool has_non_elem_case = false;  // non-elemwise oprs should be canonized
    size_t nr_sub = 0, nr_non1_coeff = 0, nr_term = 0;
    queue.push_back({dt_max_float(1), endpoint});
    while (!queue.empty()) {
        auto qh = queue.front();
        queue.pop_front();
        VarNode* var = qh.var;
        // find leaf nodes on original graph (without applying rewriter)
        if ((var == endpoint || m_var2nr_val_dep.at(var) <= 1) &&
                var->comp_node() == endpoint->comp_node()) {
            Maybe<AbstractOpr> aopr = AbstractOpr::from<Trait>(var);
            auto append = [&](VarNode *var, dt_max_float coeff) {
                queue.push_back({qh.coeff * coeff, var});
            };
            if (aopr.valid()) {
                auto &&val = aopr.val();
                using Type = AbstractOpr::Type;
                if (!var->owner_opr()->same_type<opr::Elemwise>()) {
                    has_non_elem_case = true;
                }
                switch (val.type) {
                    case Type::ADD:
                        append(val.i0, 1);
                        append(val.i1, 1);
                        break;
                    case Type::SUB:
                        ++ nr_sub;
                        append(val.i0, 1);
                        append(val.i1, -1);
                        break;
                    case Type::COEFF:
                        if (val.coeff != 1)
                            ++ nr_non1_coeff;
                        append(val.ic, val.coeff);
                        break;
                    default:
                        mgb_assert(0);
                }
                continue;
            }
        }
        // var is a leaf node that cannot be expanded
        ++ nr_term;
        var = m_rewriter.get_var(var);  // apply previous transforms on leaf nodes
        auto &&dest = var2coeff[var];
        dest.coeff += qh.coeff;
        if (!dest.order) {
            dest.order = nr_term;
        }
    }

    if (nr_sub || nr_non1_coeff >= 2 || nr_term > var2coeff.size() ||
            has_non_elem_case) {
        auto sum = sum_var2coeff<Trait>();
        if (endpoint != sum) {
            m_rewriter.replace_var(
                    endpoint, sum,
                    ssprintf("normalize elemwise chain with %zu terms", nr_term)
                            .c_str());
        }
    }
}
template<class Trait>
VarNode* NormalizeArithChainPass::Impl::sum_var2coeff() {
    sort_var2coeff();  // use another function to bypass GCC-5 bug
    auto &&sorted = m_var2coeff_sort;
    VarNode *sum = nullptr;
    for (auto &&var_cnt_pair: sorted) {
        SymbolVar x = var_cnt_pair.second, term;
        dt_max_float coeff = var_cnt_pair.first.coeff;
        auto eq = [coeff](dt_max_float v) {
            return almost_equal(coeff, v);
        };
        if (eq(0)) {
            term = x.fill_retain_dtype(Trait::UNIT);
        } else if (eq(1)) {
            term = x;
        } else if (eq(-1)) {
            term = Trait::neg(x);
        } else {
            // note: for power 2, 2 * x is better than x + x, because 2 * x * y
            // may be reordered to 2 * y * x, and it does not seem to cause
            // other overhead
            term = Trait::make_term(x, coeff);
        }
        if (!sum) {
            sum = term.node();
        } else {
            sum = Elemwise::make({sum, term}, Trait::ADD).node();
        }
    }
    return sum;
}
void NormalizeArithChainPass::Impl::on_opr(OperatorNodeBase *opr) {
    m_rewriter.auto_replace_outputs(opr);
    using proc_fn_t = void (Impl::*)(VarNode*);
    auto dispatch_proc_fn = [](OperatorNodeBase* opr) -> proc_fn_t {
        if (auto elem = try_cast_as_op<Elemwise>(opr)) {
            auto mode = elem->param().mode;
            if (mode == Mode::ADD || mode == Mode::SUB ||
                    mode == Mode::NEGATE) {
                return &Impl::process_opr_chain<AddTrait>;
            }
            if (mode == Mode::MUL || mode == Mode::TRUE_DIV ||
                    (mode == Mode::POW &&
                     SymbolVar{opr->input(1)}
                             .as_immutable_scalar_require_shape()
                             .valid())) {
                return &Impl::process_opr_chain<MulTrait>;
            }
        }
        if (opr->same_type<opr::PowC>()) {
            return &Impl::process_opr_chain<MulTrait>;
        }
        return nullptr;
    };

    VarNode* out0 = nullptr;
    auto func_self = dispatch_proc_fn(opr);
    if (func_self) {
        out0 = opr->output(0);
    }
    bool inp_changed = false;
    for (auto i: opr->input()) {
        if (m_rewriter.has_manual_replace(i)) {
            inp_changed = true;
            continue;
        }
        auto func_in = dispatch_proc_fn(i->owner_opr());
        if (func_in && (func_in != func_self || m_var2nr_val_dep.at(i) >= 2)) {
            // note: we process starting from an endpoint of a chain of the same
            // mode (either ADD or MUL) to ensure linear time complexity. An
            // endpoint is a var that must be preserved, which is either: (1)
            // received by multiple readers (2) received by an opr of different
            // mode or non-elemwise opr (3) the endpoint of the whole graph. The
            // cases (1) and (2) are handled here, and case (3) is handled
            // below by calling func_self().
            inp_changed = true;
            m_opt_state.call_with_opr(i->owner_opr(),
                    [&]{(this->*func_in)(i);});
        }
    }
    if (inp_changed)
        m_rewriter.auto_replace_outputs(opr);

    if (func_self && m_opt_state.graph().endpoint_contain(out0)) {
        (this->*func_self)(out0);
    }
}
const char* NormalizeArithChainPass::name() const {
    return mgb_cstr_log("normalize_arith_expr");
}

void NormalizeArithChainPass::apply(OptState &opt) const {
    MIDOUT_B("NormalizeArithChainPass::apply")
    Impl{opt};
    MIDOUT_E
}
/* ================ ReorderArithChainPass ================ */
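// Rough sketch of the reordering (hypothetical vars and shapes): given an ADD
// chain such as
//     x{N,C,H,W} + b0{1,C,1,1} + y{N,C,H,W} + b1{1,C,1,1}
// terms are grouped by shape (constant and non-constant terms separately),
// each group is reduced, and the partial sums are then merged from smaller to
// larger shapes, roughly yielding
//     (b0 + b1) + (x + y)
// so that broadcasting elemwise computations happen as few times as possible.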
class ReorderArithChainPass::Impl final: public ElemChainImplHelper {
    using ShapedVars = std::vector<std::pair<TensorShape, VarNode*>>;

    ConstVarPropogate m_cvprop;
    TensorShapeArray m_tmp_inp_shp;

    //! tmp var: (shape, is_const) -> terms
    TensorShapeHashKey::Map<std::array<VarNodeArray, 2>> m_shp2terms;
    ShapedVars m_const_terms, m_nonconst_terms;

    //! reduce two terms
    static VarNode* reduce(Mode mode, VarNode *a, VarNode *b);

    //! reduce m_shp2terms into a sum var
    VarNode* reduce_shp2terms(Mode mode);

    //! merge src and dst into dst, if merging does not broadcast both
    bool merge_shape_if_compatible(const TensorShape &src, TensorShape &dst);

    //! merge compatible shapes
    void merge_shaped_terms(Mode mode, ShapedVars &vars, bool allow_compatible);

    void process_chain(VarNode *endpoint, Mode mode) override;

    void on_opr_visited(OperatorNodeBase *opr) override {
        m_cvprop.add_opr(opr);
    }

    bool check_mode(Mode mode) override {
        return mode == Mode::ADD || mode == Mode::MUL ||
               mode == Mode::MAX || mode == Mode::MIN;
    }

public:
    Impl(const ReorderArithChainPass &pass, OptState &opt_state):
        ElemChainImplHelper(opt_state),
        m_cvprop{pass.m_const_var_type}
    {
        run_elem_chain();
    }
};
VarNode* ReorderArithChainPass::Impl::reduce(
        Mode mode, VarNode *a, VarNode *b) {
    if (!a)
        return b;
    if (!b)
        return a;
    return opr::Elemwise::make({a, b}, mode).node();
}

bool ReorderArithChainPass::Impl::merge_shape_if_compatible(
        const TensorShape &src, TensorShape &dst) {
    m_tmp_inp_shp.resize(2);
    m_tmp_inp_shp[0] = src;
    m_tmp_inp_shp[1] = dst;
    TensorShape out;
    megdnn::Elemwise::deduce_shape(m_tmp_inp_shp, out);
    if (out.eq_shape(src)) {
        dst = out;
        return true;
    }
    return out.eq_shape(dst);
}
VarNode* ReorderArithChainPass::Impl::reduce_shp2terms(Mode mode) {
    // populate m_const_terms and m_nonconst_terms
    m_const_terms.clear();
    m_nonconst_terms.clear();
    for (auto &&i: m_shp2terms) {
        if (!i.second[0].empty()) {
            m_nonconst_terms.emplace_back(
                    i.first.shape(),
                    elemwise_reduce_var_list(i.second[0], mode));
        }
        if (!i.second[1].empty()) {
            m_const_terms.emplace_back(
                    i.first.shape(),
                    elemwise_reduce_var_list(i.second[1], mode));
        }
    }

    {
        // sorted by id(), so the same set of input terms would get the
        // same reduced var
        auto cmp = [](const ShapedVars::value_type &a,
                      const ShapedVars::value_type &b) {
            return a.second->id() < b.second->id();
        };
        small_sort(m_const_terms.begin(), m_const_terms.end(), cmp);
        small_sort(m_nonconst_terms.begin(), m_nonconst_terms.end(), cmp);
    }

    merge_shaped_terms(mode, m_const_terms, true);

    auto &&all_terms = m_const_terms;
    all_terms.insert(all_terms.end(),
            m_nonconst_terms.begin(), m_nonconst_terms.end());

    // merge eq shape
    merge_shaped_terms(mode, all_terms, false);
    // merge compatible shape
    merge_shaped_terms(mode, all_terms, true);

    // simple heuristic: reduce in increasing size order
    auto cmp = [](const ShapedVars::value_type &a,
                  const ShapedVars::value_type &b) {
        return a.first.total_nr_elems() < b.first.total_nr_elems();
    };
    small_sort(all_terms.begin(), all_terms.end(), cmp);

    VarNode *sum = nullptr;
    for (auto &&i: all_terms) {
        sum = reduce(mode, sum, i.second);
    }
    mgb_assert(sum);
    return sum;
}
void ReorderArithChainPass::Impl::merge_shaped_terms(
        Mode mode, ShapedVars &vars, bool allow_compatible) {
    for (bool merged = true; merged;) {
        merged = false;
        for (size_t i = 0; !merged && i < vars.size(); ++ i) {
            auto &&src = vars[i];
            if (!src.first.ndim)
                continue;
            TensorShape dst_shape;
            size_t dst_idx = -1;
            auto update_dst = [&](size_t idx, const TensorShape &shp) {
                if (!dst_shape.ndim || shp.total_nr_elems() <
                        dst_shape.total_nr_elems()) {
                    dst_shape = shp;
                    dst_idx = idx;
                }
            };
            for (size_t j = 0; j < vars.size(); ++ j) {
                auto &&dst = vars[j];
                if (i == j || !dst.first.ndim)
                    continue;
                if (allow_compatible) {
                    auto tshp = dst.first;
                    if (merge_shape_if_compatible(src.first, tshp)) {
                        update_dst(j, tshp);
                    }
                } else {
                    if (src.first.eq_shape(dst.first)) {
                        update_dst(j, dst.first);
                    }
                }
            }
            if (dst_shape.ndim) {
                auto &&dst = vars[dst_idx];
                dst.first = dst_shape;
                dst.second = reduce(mode, src.second, dst.second);
                mgb_assert(
                        (!dst.second->shape().ndim &&
                         !cg::is_static_var_shape(dst.second)) ||
                        dst.second->shape().eq_shape(dst.first));
                std::swap(src, vars.back());
                vars.pop_back();
                merged = true;
                break;
            }
        }
    }
}
void ReorderArithChainPass::Impl::process_chain(VarNode *endpoint, Mode mode) {
    if (m_cvprop.is_const(endpoint))
        return;
    auto vars = extract_chain_terms(endpoint, mode);
    if (vars.size() == 1)
        return;

    // to ensure the same set of input terms get the same reduced var
    // TODO: consider maintaining a cache (map) of (sorted input terms -> reduced var)
    std::sort(vars.begin(), vars.end(),
            [](VarNode *x, VarNode *y){ return x->id() < y->id(); });

    m_shp2terms.clear();
    for (auto i: vars) {
        auto inew = m_rewriter.get_var(i);
        m_shp2terms[i->shape()][m_cvprop.is_const(i)].push_back(inew);
    }
    auto sum = reduce_shp2terms(mode);
    if (m_rewriter.get_var(endpoint) != sum) {
        m_rewriter.replace_var(endpoint, sum,
                mgb_ssprintf_log("reorder %zu %s terms", vars.size(),
                    megdnn::Elemwise::ModeTrait::from_mode(mode).name).c_str());
    }
}
const char* ReorderArithChainPass::name() const {
    return mgb_cstr_log("reorder_arith_chain");
}

void ReorderArithChainPass::apply(OptState &opt) const {
    MIDOUT_B("ReorderArithChainPass::apply")
    Impl{*this, opt};
    MIDOUT_E
}
/* ================ ArithFusePass ================ */
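// Rough sketch of the fusions performed below (hypothetical vars):
//     a * b + c * d  ->  FUSE_MUL_ADD4(a, b, c, d)
//     a * b + c      ->  FUSE_MUL_ADD3(a, b, c)  (when a bias term of a
//                                                 matching shape is available)
//     RELU(x + y)    ->  FUSE_ADD_RELU(x, y)     (likewise for SIGMOID, TANH
//                                                 and H_SWISH)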
class ArithFusePass::Impl final: public ElemChainImplHelper {
    using MulTermArray = std::vector<std::pair<VarNode*, VarNode*>>;

    class SumVars;

    size_t m_nr_fma3, m_nr_fma4;
    TensorShapeHashKey::PairMap<MulTermArray> m_mul_terms;
    TensorShapeHashKey::Map<VarNodeArray> m_bias_terms;

    bool check_mode(Mode mode) override {
        return mode == Mode::ADD;
    }

    void process_chain(VarNode *endpoint, Mode mode) override;

    VarNode* find_pop_bias_term(const TensorShape &shape) {
        auto iter = m_bias_terms.find(shape);
        if (iter != m_bias_terms.end()) {
            auto ret = elemwise_reduce_var_list(iter->second, Mode::ADD);
            m_bias_terms.erase(iter);
            return ret;
        }
        return nullptr;
    }

    VarNode* process_mul_term(MulTermArray &terms);

    bool on_opr_visit_finished(Elemwise *opr) override;

public:
    Impl(OptState &opt_state): ElemChainImplHelper(opt_state) {
        run_elem_chain();
    }
};
class ArithFusePass::Impl::SumVars {
    VarNode *m_sum = nullptr;

public:
    void add(SymbolVar var) {
        if (!m_sum) {
            m_sum = var.node();
        } else {
            m_sum = opr::add(m_sum, var).node();
        }
    }

    VarNode* get() const {
        return m_sum;
    }
};
void ArithFusePass::Impl::process_chain(VarNode *endpoint, Mode mode) {
    if (!endpoint->shape().ndim)
        return;
    mgb_assert(mode == Mode::ADD);
    m_mul_terms.clear();
    m_bias_terms.clear();
    m_nr_fma3 = m_nr_fma4 = 0;

    auto vars = extract_chain_terms(endpoint, mode);
    for (auto var: vars) {
        auto opr = var->owner_opr();
        Elemwise *mul;
        if (m_uniq_reader_check(var) && (mul = as_elem_opr(opr, Mode::MUL))) {
            auto a = mul->input(0), b = mul->input(1);
            if (a->shape().total_nr_elems() > b->shape().total_nr_elems()) {
                std::swap(a, b);
            }
            a = m_rewriter.get_var(a);
            b = m_rewriter.get_var(b);
            m_mul_terms[{a->shape(), b->shape()}].push_back({a, b});
        } else {
            var = m_rewriter.get_var(var);
            m_bias_terms[var->shape()].push_back(var);
        }
    }

    if (m_mul_terms.empty())
        return;

    // merge same shapes, so they can be used as bias by others
    for (auto i = m_mul_terms.begin(); i != m_mul_terms.end(); ) {
        auto &&s = i->first;
        if (s.first.shape().eq_shape(s.second.shape())) {
            auto merged = process_mul_term(i->second);
            mgb_assert(merged->shape().eq_shape(s.first.shape()));
            m_bias_terms[merged->shape()].push_back(merged);
            mgb_assert(i->second.empty());
            i = m_mul_terms.erase(i);
        } else {
            ++ i;
        }
    }

    // sort mul_terms by size
    TensorShapeArray shp_inp(2);
    using SortedTermItem = std::pair<size_t, MulTermArray*>;
    std::vector<SortedTermItem> mul_terms_sorted;
    for (auto &&i: m_mul_terms) {
        shp_inp[0] = i.first.first.shape();
        shp_inp[1] = i.first.second.shape();
        TensorShape tshp;
        megdnn::Elemwise::deduce_shape(shp_inp, tshp);
        mul_terms_sorted.push_back({tshp.total_nr_elems(), &i.second});
    }
    auto cmp = [](const SortedTermItem &a, const SortedTermItem &b) {
        return a.first < b.first || (
                a.first == b.first && a.second->size() < b.second->size());
    };
    std::sort(mul_terms_sorted.begin(), mul_terms_sorted.end(), cmp);

    // merge from smallest to largest
    for (auto &&i: mul_terms_sorted) {
        auto merged = process_mul_term(*i.second);
        mgb_assert(i.second->empty() && merged->shape().ndim);
        m_bias_terms[merged->shape()].push_back(merged);
    }

    SumVars sum_vars;
    for (auto &&i: m_bias_terms) {
        sum_vars.add(elemwise_reduce_var_list(i.second, Mode::ADD));
    }
    auto sum = sum_vars.get();
    m_rewriter.replace_var(endpoint, sum,
            mgb_ssprintf_log(
                "replace %zu fma3, %zu fma4", m_nr_fma3, m_nr_fma4).c_str());
}
VarNode* ArithFusePass::Impl::process_mul_term(MulTermArray &terms) {
    mgb_assert(!terms.empty());
    SumVars sum_vars;
    while (terms.size() >= 2) {
        auto b = terms.back();
        terms.pop_back();
        auto a = terms.back();
        terms.pop_back();
        ++ m_nr_fma4;
        sum_vars.add(Elemwise::make({a.first, a.second, b.first, b.second},
                    Mode::FUSE_MUL_ADD4));
    }
    if (!terms.empty()) {
        auto t = terms.back();
        terms.pop_back();
        auto bias = find_pop_bias_term(t.first->shape());
        if (!bias)
            bias = find_pop_bias_term(t.second->shape());
        if (bias) {
            ++ m_nr_fma3;
            sum_vars.add(Elemwise::make({t.first, t.second, bias},
                        Mode::FUSE_MUL_ADD3));
        } else {
            sum_vars.add(opr::mul(t.first, t.second));
        }
    }
    return sum_vars.get();
}
bool ArithFusePass::Impl::on_opr_visit_finished(Elemwise *opr) {
    if (opr->input().size() != 1)
        return true;
    if (!m_uniq_reader_check(opr->input(0)))
        return true;
    auto iadd = as_elem_opr(m_rewriter.get_var(opr->input(0)), Mode::ADD);
    if (!iadd)
        return true;
    if (opr->input(0)->dtype().category() == DTypeCategory::QUANTIZED)
        return true;

    Mode fmode;
    const char *msg;
    switch (opr->param().mode) {
#define cb(m)                                           \
        case Mode::m:                                   \
            fmode = Mode::FUSE_ADD_##m;                 \
            msg = mgb_cstr_log("fuse " #m "(x + y)");   \
            break;
        FOREACH_FUSE_ADD_MODE(cb)
#undef cb
        default:
            return true;
    }

    m_opt_state.call_with_opr(opr, [&]{
        auto fused = opr::Elemwise::make({iadd->input(0), iadd->input(1)},
                fmode).node();
        m_rewriter.replace_var(opr->output(0), fused, msg);
        m_uniq_reader_check.update_on_opr_auto_replace(opr, fused->owner_opr());
    });
    return false;
}
const char* ArithFusePass::name() const {
    return mgb_cstr_log("arith_fuse");
}

void ArithFusePass::apply(OptState &opt) const {
    MIDOUT_B("ArithFusePass::apply")
    Impl{opt};
    MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
