You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fusion_pass.cpp 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. /**
  2. * \file src/jit/impl/fusion_pass.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/jit/fusion_pass.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/gopt/gtrans.h"
  14. #include "megbrain/jit/ast_c.h"
  15. #include "megbrain/jit/compiler.h"
  16. #include "megbrain/jit/internal_graph.h"
  17. #include "megbrain/opr/tensor_manip.h"
  18. #include "megbrain/serialization/serializer.h"
  19. #if MGB_JIT
  20. #if MGB_JIT_MLIR
  21. #include "./mlir/ir/each_mode.h"
  22. #endif
  23. using namespace mgb;
  24. using namespace gopt;
  25. using namespace jit;
/*!
 * \brief implementation of a single application of JITFusionPass
 *
 * The pass runs in two phases: detect_fusion() grows candidate JIT
 * subgraphs by visiting oprs in reverse topological order, and
 * update_graph() rewrites the graph so each qualifying subgraph endpoint
 * is replaced by a JITExecutor opr.
 */
class JITFusionPass::Impl final {
    using Mode = opr::Elemwise::Mode;
    using DepType = OperatorNodeBase::NodeProp::DepType;

    //! whether this pass runs after gradient computation; before grad only
    //! subgraphs with a single non-const input are fused (see update_graph)
    const bool m_after_grad;
    //! enabled JIT features (e.g. REDUCE / DIMSHUFFLE fusion)
    JITFeatureBits m_feature_bits;
    OptState& m_opt_state;
    //! per-comp-node cache for max_nr_input()
    CompNode::UnorderedMap<size_t> m_cn2max_nr_input;
    SubGraph::Rewriter m_rewriter;
    //! owns the generators pointed to by m_var2igraph_gen
    SmallVector<std::unique_ptr<InternalGraphGenerator>> m_igraph_gen_storage;
    //! map from var to the subgraph generator it currently belongs to
    ThinHashMap<VarNode*, InternalGraphGenerator*> m_var2igraph_gen;
    //! map from var to its reader oprs and the corresponding dependency types
    ThinHashMap<VarNode*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            m_var_readers;
    //! vars that are the output (endpoint) of a candidate JIT subgraph
    ThinHashSet<VarNode*> m_endpoint_set;

    //! create a new InternalGraphGenerator rooted at given opr
    InternalGraphGenerator* create_new_igraph_gen(OperatorNodeBase* opr);

    //! process a single operator, maintaining m_var2igraph_gen
    void process_opr(OperatorNodeBase* opr);

    //! max number of JIT inputs supported on given comp node (cached)
    size_t max_nr_input(CompNode cn);

    //! check whether all oprs which depend on the var are in i_graph
    bool test_all_readers_in_the_graph(VarNode* var, InternalGraphGenerator* i_graph);

    //! check shape to determine whether the opr should be added to the internal
    //! graph
    bool check_shape(cg::OperatorNodeBase* opr, InternalGraphGenerator* i_graph);

    //! use m_rewriter to update graph
    void update_graph();

    //! find the subgraph which can be fused
    void detect_fusion();

    //! check whether an opr can be fused
    bool can_be_fused(cg::OperatorNodeBase* opr) const;

    //! number of vars in \p vars that are not immutable scalars
    static size_t nr_non_const_vars(const VarNodeArray& vars) {
        size_t num = 0;
        for (auto i : vars) {
            num += !SymbolVar{i}.as_immutable_scalar().valid();
        }
        return num;
    }

public:
    Impl(bool after_grad, JITFeatureBits feature_bits, OptState& opt_state)
            : m_after_grad{after_grad},
              m_feature_bits{feature_bits},
              m_opt_state{opt_state},
              m_rewriter{opt_state.graph().make_rewriter()} {
        // the whole pass executes as a side effect of construction
        detect_fusion();
        update_graph();
    }
};
  73. void JITFusionPass::Impl::detect_fusion() {
  74. std::vector<OperatorNodeBase*> topo_order;
  75. m_opt_state.graph().iter([this, &topo_order](OperatorNodeBase* opr) {
  76. topo_order.push_back(opr);
  77. for (auto&& i : opr->node_prop().dep_map()) {
  78. m_var_readers[i.first].emplace_back(opr, i.second);
  79. }
  80. });
  81. for (auto opr : reverse_adaptor(topo_order)) {
  82. if (can_be_fused(opr)) {
  83. process_opr(opr);
  84. }
  85. }
  86. }
/*!
 * \brief rewrite the graph, replacing each qualifying subgraph endpoint
 *      var with the output of a newly created JITExecutor opr
 */
void JITFusionPass::Impl::update_graph() {
    auto process = [this](OperatorNodeBase* opr) {
        if (!Compiler::is_supported_device(opr->output(0)->comp_node().device_type()))
            return;
        // replace var with a JITExecutor if it is a subgraph endpoint whose
        // subgraph contains at least two oprs
        auto fuse_varnode = [this](VarNode* var) {
            auto ig_gen_iter = m_var2igraph_gen.find(var);
            if (ig_gen_iter == m_var2igraph_gen.end()) {
                return;
            }
            auto ig_gen = ig_gen_iter->second;
            if (m_endpoint_set.count(var) != 0 && ig_gen->opr_set().size() >= 2) {
                auto igraph = ig_gen->generate();
                auto&& inputs = ig_gen->orig_inps();
                if (m_after_grad || nr_non_const_vars(inputs) == 1) {
                    // in the forward pass, only fuse oprs with one non-const
                    // inp
                    VarNodeArray rewritten_inputs;
                    for (auto&& input : inputs) {
                        // map original inputs through replacements made earlier
                        // in this rewrite
                        auto new_input = m_rewriter.get_var(input);
                        rewritten_inputs.push_back(new_input);
                    }
                    auto fusion_op = JITExecutor::make(igraph, rewritten_inputs);
                    m_rewriter.replace_var(
                            var, fusion_op.node(),
                            mgb_ssprintf_log(
                                    "fuse endpoint: %s", var->owner_opr()->cname())
                                    .c_str());
                }
            }
        };
        for (auto i : opr->input()) {
            if (!m_rewriter.has_manual_replace(i)) {
                // if input i is a endpoint, and number of oprs in this subgraph
                // is greater than 2
                m_opt_state.call_with_opr(i->owner_opr(), [&] { fuse_varnode(i); });
            }
        }
        m_rewriter.auto_replace_outputs(opr);
        if (m_opt_state.graph().endpoint_contain(opr->output(0))) {
            // process final endpoint
            fuse_varnode(opr->output(0));
        }
    };
    m_opt_state.graph().iter(process);
    m_rewriter.apply_inplace();
}
  133. bool JITFusionPass::Impl::test_all_readers_in_the_graph(
  134. VarNode* var, InternalGraphGenerator* ig_gen) {
  135. for (auto&& reader : m_var_readers.at(var)) {
  136. if (reader.second & DepType::DEV_VALUE) {
  137. if (ig_gen->opr_set().count(reader.first) == 0) {
  138. return false;
  139. }
  140. }
  141. }
  142. return true;
  143. }
/*!
 * \brief check whether \p opr's shapes allow adding it to the subgraph of
 *      \p ig_gen
 *
 * Without the REDUCE feature, the opr output shape must equal the final
 * subgraph output shape (or, when it feeds a fused dimshuffle, that
 * dimshuffle's input shape). With REDUCE enabled, oprs are checked against
 * the before-reduce shape or the final output shape depending on which
 * side of a reduce they would join.
 */
bool JITFusionPass::Impl::check_shape(
        cg::OperatorNodeBase* opr, InternalGraphGenerator* ig_gen) {
    if (!cg::is_static_var_shape(opr->output(0))) {
        // currently we do not handle dynamic shape in JIT
        return false;
    }
    if (!(m_feature_bits & JITFeatureBits::REDUCE)) {
        // By requiring opr output shape to be the same as final output shape,
        // we permit only one broadcast. If multiple broadcasts are fused
        // together, execution would be actually slower.
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) && ig_gen->has_dimshuffle() &&
            ig_gen->oprs_depended_by_dimshuffe().count(opr)) {
            // opr feeds a fused dimshuffle: compare against that dimshuffle's
            // input shape rather than the subgraph output shape
            return opr->output(0)->shape().eq_shape(
                    ig_gen->oprs_depended_by_dimshuffe().at(opr)->input(0)->shape());
        } else {
            return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
        }
    }
    // REDUCE enabled: does opr lie before some reduce already in the subgraph?
    bool before_reduce = false;
    for (auto&& op_set : ig_gen->reduce_out_var_deps()) {
        if (op_set.second.count(opr)) {
            before_reduce = true;
            break;
        }
    }
    if (opr->same_type<JITExecutor>()) {
        // fusing an existing JIT subgraph into this one
        auto jit = &opr->cast_final<JITExecutor>();
        bool jit_has_reduce = jit->has_reduce();
        auto jit_inp_shp = jit->broadcasted_input_shape();
        if (jit_has_reduce) {
            if (before_reduce)
                // inner JIT contains a reduce: only acceptable before our
                // reduce if its reduce does not actually change the shape
                return jit_inp_shp.eq_shape(jit->output(0)->shape()) &&
                       jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
            else {
                bool ret = true;
                if (ig_gen->has_reduce()) {
                    ret &= jit_inp_shp.eq_shape(ig_gen->before_reduce_shape());
                }
                ret &= jit->output(0)->shape().eq_shape(ig_gen->output()->shape());
                return ret;
            }
        }
    }
    if (opr->same_type<opr::Reduce>()) {
        // TODO: handle reduce target shape in sub graph (especially considering
        // placeholder has constant shape)
        //
        // The best way is to have a dedicated AST for the internal graph; but
        // we want to reuse the deduplication and gradient mechanisms from the
        // mgb cg
        auto reduce = &opr->cast_final<opr::Reduce>();
        if (before_reduce) {
            // a reduce placed before another reduce is only allowed when it
            // keeps the before-reduce shape unchanged
            return reduce->input(0)->shape().eq_shape(ig_gen->before_reduce_shape()) &&
                   reduce->output(0)->shape().eq_shape(ig_gen->before_reduce_shape());
        } else {
            bool ret = true;
            if (ig_gen->has_reduce()) {
                ret &= reduce->input(0)->shape().eq_shape(
                        ig_gen->before_reduce_shape());
            }
            ret &= reduce->output(0)->shape().eq_shape(ig_gen->output()->shape());
            return ret;
        }
    }
    // all other oprs: match the shape of the region they would join
    if (before_reduce) {
        return opr->output(0)->shape().eq_shape(ig_gen->before_reduce_shape());
    } else {
        return opr->output(0)->shape().eq_shape(ig_gen->output()->shape());
    }
}
  214. InternalGraphGenerator* JITFusionPass::Impl::create_new_igraph_gen(
  215. OperatorNodeBase* opr) {
  216. auto uptr = std::make_unique<InternalGraphGenerator>(opr);
  217. auto ptr = uptr.get();
  218. m_igraph_gen_storage.emplace_back(std::move(uptr));
  219. m_var2igraph_gen[opr->output(0)] = ptr;
  220. m_endpoint_set.insert(opr->output(0));
  221. return ptr;
  222. }
  223. void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
  224. auto max_nr_input = this->max_nr_input(opr->output(0)->comp_node());
  225. if (nr_non_const_vars(opr->input()) > max_nr_input ||
  226. !cg::is_static_var_shape(opr->output(0))) {
  227. return;
  228. }
  229. // dimshuffle should not be an endpoint, because megbrain has lazy
  230. // dimshuffle machanism
  231. InternalGraphGenerator* ig_gen = nullptr;
  232. if (m_var2igraph_gen.count(opr->output(0)) == 0) {
  233. // because of the reverse traversal, when an operator is being
  234. // processed but not in m_var2igraph_gen, means it is a endpoint of a
  235. // JIT subgraph.
  236. if (opr->same_type<opr::Dimshuffle>()) {
  237. return;
  238. }
  239. ig_gen = create_new_igraph_gen(opr);
  240. } else {
  241. ig_gen = m_var2igraph_gen[opr->output(0)];
  242. // if all oprs which depend on this elemwise opr's output were already
  243. // in the subgraph and the opr's comp_node is same with the subgraph's,
  244. // then this opr can be fused to this graph as an internal node rather
  245. // than a leaf.
  246. bool cond_readers = test_all_readers_in_the_graph(opr->output(0), ig_gen),
  247. cond_cn = opr->output(0)->comp_node() == ig_gen->output()->comp_node(),
  248. cond_shp = check_shape(opr, ig_gen),
  249. cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
  250. cond_mlir_specific = true;
  251. if (cond_readers && cond_cn && cond_shp && cond_nr_inp && cond_mlir_specific) {
  252. ig_gen->add_opr(opr);
  253. } else {
  254. if (opr->same_type<opr::Dimshuffle>()) {
  255. return;
  256. }
  257. // create a new sub graph starting from this opr
  258. mgb_log_debug(
  259. "JIT graph stopped at opr %s{%s}: cond: readers=%d cn=%d "
  260. "shp=%d nr_inp=%d",
  261. opr->cname(), opr->dyn_typeinfo()->name, cond_readers, cond_cn,
  262. cond_shp, cond_nr_inp);
  263. ig_gen = create_new_igraph_gen(opr);
  264. }
  265. }
  266. // handle const inputs
  267. for (auto&& i : opr->node_prop().dep_map()) {
  268. if (i.second & cg::OperatorNodeBase::NodeProp::DepType::DEV_VALUE) {
  269. if (SymbolVar{i.first}.as_immutable_scalar_require_shape().valid()) {
  270. auto opr = i.first->owner_opr();
  271. mgb_assert(
  272. opr->same_type<opr::ImmutableTensor>(),
  273. "got imm scalar from non ImmutableTensor: %s{%s}", opr->cname(),
  274. opr->dyn_typeinfo()->name);
  275. ig_gen->add_opr(opr);
  276. continue;
  277. }
  278. }
  279. m_var2igraph_gen[i.first] = ig_gen;
  280. }
  281. }
  282. size_t JITFusionPass::Impl::max_nr_input(CompNode cn) {
  283. auto&& ret = m_cn2max_nr_input[cn];
  284. if (!ret) {
  285. ret = Compiler::get(*m_opt_state.graph().comp_graph(), cn)
  286. ->property()
  287. .max_nr_input;
  288. mgb_assert(ret);
  289. }
  290. return ret;
  291. }
/*!
 * \brief whether \p opr is of a kind that the active JIT backend can fuse
 *      at all (type, dtype and backend-specific checks; no shape checks)
 */
bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
    if (!Compiler::is_supported_device(opr->output(0)->comp_node().device_type())) {
        return false;
    }
    //! As MLIR backend has some constraints
    const char* backend = MGB_GETENV("MGB_JIT_BACKEND");
    if (!backend) {
        backend = "DEFAULT";
    }
    // float elemwise
    if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
        bool ret = true;
#if MGB_JIT_MLIR
        if (!strcmp(backend, "MLIR")) {
            // MLIR supports only the elemwise modes listed by the
            // MLIR_MGB_FOREACH_* macros; everything else is rejected
            switch (elem->param().mode) {
#define cb(_, _mode)                 \
    case opr::Elemwise::Mode::_mode: \
        ret = true;                  \
        break;
                MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
                MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
                default:
                    ret = false;
#undef cb
            }
#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
            //! FIXME mlir on cuda doesn't support sin currently.
            if (opr->output(0)->comp_node().device_type() ==
                CompNode::DeviceType::CUDA) {
                switch (elem->param().mode) {
#define cb(_mode)                    \
    case opr::Elemwise::Mode::_mode: \
        ret = false;                 \
        break;
                    FOREACH_ELEMWISE_SKIP_MODE(cb)
                    default:
                        break;
#undef cb
                }
            }
#undef FOREACH_ELEMWISE_SKIP_MODE
        }
#endif  // MGB_JIT_MLIR
        return ret && ast_c::check_elem_mode(elem->param().mode) &&
               elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
    }
    // the opr kinds below are supported only by non-MLIR backends
    if (strcmp(backend, "MLIR")) {
        if (opr->same_type<opr::PowC>()) {
            return true;
        }
        // float typecvt (e.g. used in f16 training)
        if (opr->same_type<opr::TypeCvt>()) {
            auto category = opr->input(0)->dtype().category();
            if (category != opr->output(0)->dtype().category())
                return false;
            return category == DTypeCategory::FLOAT;
        }
        // float reduce
        if ((m_feature_bits & JITFeatureBits::REDUCE) &&
            opr->same_type<opr::Reduce>()) {
            return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
        }
        // dimshuffle
        if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
            opr->same_type<opr::Dimshuffle>()) {
            auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
            return param.pattern_len <= 4;
        }
    }
    // existing JITExecutor
    if (opr->same_type<JITExecutor>())
        return true;
    return false;
}
  367. JITFusionPass::JITFusionPass(
  368. bool after_grad, int jit_opt_level, const JITConfig& jit_config)
  369. : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
  370. // get default config from jit_opt_level
  371. JITConfig config;
  372. if (jit_opt_level == 1) {
  373. config.fuse_dimshuffle = JITConfig::ON;
  374. config.fuse_reduce = JITConfig::OFF;
  375. } else if (jit_opt_level >= 2) {
  376. config.fuse_dimshuffle = JITConfig::OFF;
  377. config.fuse_reduce = JITConfig::ON;
  378. }
  379. // overwrite default config with custom settings
  380. config.update(jit_config);
  381. bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
  382. bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
  383. if (fuse_dimshuffle && fuse_reduce) {
  384. mgb_assert(false, "reduce and dimshuffle can not coexist now");
  385. }
  386. if (fuse_dimshuffle) {
  387. m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
  388. }
  389. if (fuse_reduce) {
  390. m_feature_bits |= JITFeatureBits::REDUCE;
  391. }
  392. }
//! pass name reported to the graph optimizer / logs
const char* JITFusionPass::name() const {
    return mgb_cstr_log("fusion_pass");
}
void JITFusionPass::apply(OptState& opt) const {
    // Impl performs the whole pass in its constructor; the temporary is
    // discarded immediately afterwards
    Impl{m_after_grad, m_feature_bits, opt};
}
  399. #endif
  400. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}