You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

fuse_nchw4_int8_preprocess.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. /**
  2. * \file src/gopt/impl/fuse_nchw4_int8_preprocess.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/gopt/inference.h"
  13. #include "megbrain/gopt/misc.h"
  14. #include "megbrain/graph/grad_impl.h"
  15. #include "megbrain/opr/cond.h"
  16. #include "megbrain/opr/io.h"
  17. #include "megbrain/opr/tensor_manip.h"
  18. #include "megbrain/opr/utility.h"
  19. #include "megbrain/serialization/opr_shallow_copy.h"
  20. #include "megbrain/serialization/serializer.h"
  21. #include "megbrain/opr/imgproc.h"
  22. using namespace mgb;
  23. using namespace gopt;
  24. namespace {
//! early-return shorthand used by SubGraphMatcher::match below
#define RETURN_IF_FALSE(ok) \
    {                       \
        if (!ok)            \
            return ok;      \
    }

//! matches a pattern tree of Node against an operator and, recursively,
//! its inputs; used to recognize the int8 pre-process subgraphs
struct SubGraphMatcher {
    struct Node {
        using CallBack = std::function<bool(OperatorNodeBase* opr)>;
        Node(Typeinfo* in_op_type) : op_type(in_op_type){};
        Node(Typeinfo* in_op_type, CallBack func)
                : op_type(in_op_type), cbk(func){};
        Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node)
                : op_type(in_op_type), pre_node(in_pre_node){};
        Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
             CallBack func)
                : op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
        Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
             CallBack func, std::string in_msg)
                : op_type(in_op_type),
                  pre_node(in_pre_node),
                  cbk(func),
                  msg(in_msg){};
        //! expected opr type; nullptr acts as a wildcard that matches any opr
        Typeinfo* op_type{nullptr};
        //! alternative input patterns: the node matches if ANY inner vector
        //! matches the opr inputs position by position
        std::vector<std::vector<Node>> pre_node;
        //! cbk used to check param and gather args for creating fusion op
        CallBack cbk;
        //! debug tag for the pattern node
        std::string msg{""};
    };

    //! \return true iff the graph rooted at \p opr matches pattern \p root
    bool match(Node& root, OperatorNodeBase* opr) {
        if (opr == nullptr) {
            return false;
        }
        //! match nullptr node always
        if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
            bool current_match = true;
            if (root.cbk)
                current_match &= root.cbk(opr);
            RETURN_IF_FALSE(current_match);
            auto& inp = opr->input();
            // a node without sub-patterns matches on its own
            bool any_sub_patten_match =
                    root.pre_node.size() == 0 ? true : false;
            for (auto& sub_patten : root.pre_node) {
                bool patten_ok = true;
                for (size_t node_idx = 0; node_idx < sub_patten.size();
                     ++node_idx) {
                    // a pattern longer than the actual input list cannot match
                    bool valid_node_idx = node_idx < inp.size();
                    if (!valid_node_idx) {
                        patten_ok = false;
                        break;
                    }
                    patten_ok = patten_ok && match(sub_patten[node_idx],
                                                   inp[node_idx]->owner_opr());
                    if (!patten_ok) {
                        break;
                    }
                }
                any_sub_patten_match = any_sub_patten_match || patten_ok;
                if (any_sub_patten_match) {
                    break;
                }
            }
            return current_match && any_sub_patten_match;
        } else {
            return false;
        }
    }
};
#undef RETURN_IF_FALSE
  93. struct SubGraphChecker {
  94. using DepType = cg::OperatorNodeProp::DepType;
  95. using ReaderType =
  96. ThinHashMap<OperatorNodeBase*,
  97. SmallVector<std::pair<OperatorNodeBase*, DepType>>>;
  98. SubGraphChecker() {}
  99. bool check(ThinHashSet<OperatorNodeBase*> used_input,
  100. OperatorNodeBase* start_opr, OperatorNodeBase* stop_opr,
  101. ReaderType& readers, bool ignore_immutable = true) {
  102. bool is_all_inp_used = check_all_inp_used(used_input, start_opr,
  103. stop_opr, ignore_immutable);
  104. bool is_all_dep_inside =
  105. check_all_dep_inside_node(start_opr, stop_opr, readers);
  106. return is_all_inp_used && is_all_dep_inside;
  107. }
  108. bool check_all_inp_used(ThinHashSet<OperatorNodeBase*>& used_input,
  109. OperatorNodeBase* start_opr,
  110. OperatorNodeBase* stop_opr,
  111. bool ignore_immutable = true) {
  112. ThinHashSet<OperatorNodeBase*> leaf_set;
  113. get_leaf_node(start_opr, stop_opr, leaf_set);
  114. for (auto in_opr : leaf_set) {
  115. bool skip = in_opr->same_type<opr::ImmutableTensor>() &&
  116. ignore_immutable;
  117. if (used_input.find(in_opr) == used_input.end() && !skip) {
  118. return false;
  119. }
  120. }
  121. return true;
  122. }
  123. bool check_all_dep_inside_node(OperatorNodeBase* start_opr,
  124. OperatorNodeBase* stop_opr,
  125. ReaderType& readers) {
  126. ThinHashSet<OperatorNodeBase*> mid_set;
  127. get_mid_node(start_opr, start_opr, stop_opr, mid_set);
  128. for (auto inner_opr : mid_set) {
  129. if (readers.find(inner_opr) != readers.end()) {
  130. for (auto& out_node : readers[inner_opr]) {
  131. if (mid_set.find(out_node.first) == mid_set.end() &&
  132. out_node.first != start_opr &&
  133. out_node.second ==
  134. cg::OperatorNodeProp::DepType::DEV_VALUE) {
  135. return false;
  136. }
  137. }
  138. }
  139. }
  140. return true;
  141. }
  142. void get_mid_node(OperatorNodeBase* opr, OperatorNodeBase* start_opr,
  143. OperatorNodeBase* stop_opr,
  144. ThinHashSet<OperatorNodeBase*>& mid_set) {
  145. if (opr == nullptr) {
  146. return;
  147. }
  148. if (opr != start_opr) {
  149. mid_set.insert(opr);
  150. }
  151. if (opr == stop_opr) {
  152. return;
  153. }
  154. for (auto& tensor : opr->input()) {
  155. auto pre_opr = tensor->owner_opr();
  156. get_mid_node(pre_opr, start_opr, stop_opr, mid_set);
  157. }
  158. }
  159. void get_leaf_node(OperatorNodeBase* opr, OperatorNodeBase* stop_opr,
  160. ThinHashSet<OperatorNodeBase*>& leaf_set) {
  161. if (opr == nullptr) {
  162. return;
  163. }
  164. if (opr == stop_opr || opr->input().size() == 0) {
  165. leaf_set.insert(opr);
  166. }
  167. if (opr == stop_opr) {
  168. return;
  169. }
  170. for (auto& tensor : opr->input()) {
  171. auto pre_opr = tensor->owner_opr();
  172. get_leaf_node(pre_opr, stop_opr, leaf_set);
  173. }
  174. }
  175. };
  176. static inline bool is_shape_nchw(const TensorShape& shape) {
  177. return shape.ndim == 4;
  178. }
  179. static inline bool is_shape_before_nchw4(const TensorShape& shape) {
  180. return shape.ndim == 5 && shape[2] == 4;
  181. }
  182. static inline bool is_nchw_nchw4_shuffle_vec(
  183. const opr::Dimshuffle::Param param) {
  184. return param.ndim == 5 && param.pattern[0] == 0 && param.pattern[1] == 1 &&
  185. param.pattern[2] == 3 && param.pattern[3] == 4 &&
  186. param.pattern[4] == 2;
  187. }
  188. static inline bool is_shape_before_nhwc(const TensorShape& shape) {
  189. return shape.ndim == 4 && shape[1] == 4;
  190. }
  191. static inline bool is_nchw_nhwc_shuffle(const opr::Dimshuffle::Param param) {
  192. return param.ndim == 4 && param.pattern[0] == 0 && param.pattern[1] == 2 &&
  193. param.pattern[2] == 3 && param.pattern[3] == 1;
  194. }
  195. template <typename T>
  196. static inline bool is_immutable_equal(OperatorNodeBase* opr, T val,
  197. DTypeEnum dtype_enum) {
  198. auto const_opr = opr->try_cast_final<opr::ImmutableTensor>();
  199. if (!const_opr) {
  200. return false;
  201. }
  202. auto& host_value = const_opr->host_value();
  203. bool ok_value = host_value.layout().total_nr_elems() == 1 &&
  204. host_value.dtype().enumv() == dtype_enum &&
  205. host_value.ptr<T>()[0] == val;
  206. return ok_value;
  207. }
  208. template <typename T>
  209. static inline bool is_immutable_all_equal(OperatorNodeBase* opr,
  210. typename DTypeTrait<T>::ctype val) {
  211. auto const_opr = opr->try_cast_final<opr::ImmutableTensor>();
  212. if (!const_opr) {
  213. return false;
  214. }
  215. auto& host_value = const_opr->host_value();
  216. bool ok_value = host_value.dtype().enumv() == DTypeTrait<T>::enumv;
  217. if (!ok_value) {
  218. return false;
  219. }
  220. size_t nr_elem = host_value.layout().total_nr_elems();
  221. for (size_t i = 0; i < nr_elem; ++i) {
  222. if (host_value.ptr<typename DTypeTrait<T>::ctype>()[i] != val) {
  223. ok_value = false;
  224. break;
  225. }
  226. }
  227. return ok_value;
  228. }
  229. } // namespace
//! pass identifier shown in logs and gopt diagnostics
const char* FuseNCHW4Int8Preprocess::name() const {
    return "fuse_pre_process_pass";
}
/*!
 * construct the pass: register pattern-based replace functions for
 * Dimshuffle-rooted and TypeCvt-rooted int8 pre-process subgraphs,
 * replacing them with a single RelayoutFormat(NCHW_NCHW4) opr
 */
std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
    using SGM = SubGraphMatcher;
    //! builds the shared tail pattern of the pre-process subgraph:
    //!   in_node --concat(pad, axis=1)--> [reshape] --> dimshuffle
    //! NOTE(review): shape_cbk is accepted but never used in this body --
    //! the reshape node below installs its own lambda; confirm intent
    auto gen_pad_dimshuffle_graph = [&](SGM::Node& in_node,
                                        SGM::Node::CallBack& pad_cbk,
                                        SGM::Node::CallBack& shape_cbk) {
        // padding operand may be a bare ImmutableTensor or a Broadcast of one
        SGM::Node::CallBack check_pad = [&](OperatorNodeBase* opr) {
            SGM sub_matcher;
            SGM::Node immu_node{opr::ImmutableTensor::typeinfo(), pad_cbk};
            if (opr->same_type<opr::ImmutableTensor>()) {
                return sub_matcher.match(immu_node, opr);
            } else if (opr->same_type<opr::Broadcast>()) {
                return sub_matcher.match(immu_node,
                                         opr->input()[0]->owner_opr());
            } else {
                return false;
            }
        };
        SGM::Node broadcast_or_immutable{
                nullptr, {}, check_pad, "broadcast_or_immutable"};
        // channel padding: concat along axis 1
        SGM::Node broadcast_concat{
                opr::Concat::typeinfo(),
                {{in_node, broadcast_or_immutable}},
                [](OperatorNodeBase* opr) {
                    auto concat_pad = opr->try_cast_final<opr::Concat>();
                    return concat_pad->axis() == 1;
                },
                "broadcast_concat"};
        // reshape whose data input is 4-d NCHW; the target-shape input is
        // matched as a wildcard node
        SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
                                {{broadcast_concat, SGM::Node(nullptr)}},
                                [](OperatorNodeBase* opr) {
                                    auto inp0 = opr->input()[0];
                                    return is_shape_nchw(inp0->shape());
                                }};
        // root: either a 5-d NCHW->NCHW4 shuffle after the reshape, or a
        // 4-d NCHW->NHWC shuffle directly after the concat
        SGM::Node shuffle_root{
                opr::Dimshuffle::typeinfo(),
                {{nchwx_reshape}, {broadcast_concat}},
                [](OperatorNodeBase* opr) {
                    auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
                    auto& input_vec = shuffle_opr.input();
                    bool nchw_nchw4_ok =
                            is_shape_before_nchw4(input_vec[0]->shape()) &&
                            is_nchw_nchw4_shuffle_vec(shuffle_opr.param());
                    bool nchw_nhwc_ok =
                            is_shape_before_nhwc(input_vec[0]->shape()) &&
                            is_nchw_nhwc_shuffle(shuffle_opr.param());
                    return nchw_nchw4_ok || nchw_nhwc_ok;
                }};
        return shuffle_root;
    };
    //! pattern: u8 source -> typecvt(f32) -> add(-128); captures the u8
    //! source opr and the -128 constant opr through the out-params
    auto gen_u8_cvt2_q8 = [](OperatorNodeBase*& src_node,
                             OperatorNodeBase*& neg_128_immu_node) {
        SGM::Node input_data_u8{nullptr, [&](OperatorNodeBase* opr) {
                                    auto src_dtype = opr->output()[0]->dtype();
                                    if (src_dtype.enumv() == DTypeEnum::Uint8) {
                                        src_node = opr;
                                        return true;
                                    } else {
                                        return false;
                                    }
                                }};
        SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
                           {{input_data_u8}},
                           [](OperatorNodeBase* opr) {
                               auto cvt_op =
                                       opr->try_cast_final<opr::TypeCvt>();
                               bool is_fp32 = cvt_op->param().enumv() ==
                                              DTypeEnum::Float32;
                               return is_fp32;
                           }};
        // elemwise ADD with a -128.f scalar on either input side
        SGM::Node sub_128{
                opr::Elemwise::typeinfo(),
                {{cvt_fp32, nullptr}, {nullptr, cvt_fp32}},
                [&](OperatorNodeBase* opr) {
                    auto elem_op = opr->try_cast_final<opr::Elemwise>();
                    bool is_add_op = elem_op->param().mode ==
                                     opr::Elemwise::Param::Mode::ADD;
                    auto neg_128_op = elem_op->input()[1]->owner_opr();
                    bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
                                                         DTypeEnum::Float32);
                    // NOTE(review): when the -128 constant is input[1],
                    // neg_128_op has been reassigned to input[0] here, so
                    // neg_128_immu_node may capture the non-constant input;
                    // appears harmless since ImmutableTensor leaves are
                    // skipped by the checker, but confirm
                    neg_128_op = elem_op->input()[0]->owner_opr();
                    is_neg_128 = is_neg_128 ||
                                 is_immutable_equal(neg_128_op, -128.f,
                                                    DTypeEnum::Float32);
                    neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
                    return is_add_op && is_neg_128;
                },
                "sub_128"};
        return sub_128;
    };
    //! replace func for a Dimshuffle-rooted match: the source is either a
    //! qu8 tensor converted to q8, or the u8 -> f32 -> add(-128) chain
    auto replace_shuffle_opr = [&](OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp,
                                   SubGraph::Rewriter& rewriter,
                                   ReaderType& reader) {
        SGM matcher;
        OperatorNodeBase* src_node = nullptr;
        OperatorNodeBase* neg_128_immu_node = nullptr;
        auto u8_q8_input = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
        SGM::Node input_data_qu8{
                nullptr, [&](OperatorNodeBase* opr) {
                    auto src_dtype = opr->output()[0]->dtype();
                    if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
                        src_node = opr;
                        return true;
                    } else {
                        return false;
                    }
                }};
        SGM::Node type_cvt{opr::TypeCvt::typeinfo(),
                           {{input_data_qu8}, {u8_q8_input}},
                           [](OperatorNodeBase* opr) {
                               auto cvt_op =
                                       opr->try_cast_final<opr::TypeCvt>();
                               if (cvt_op) {
                                   return cvt_op->param().enumv() ==
                                          DTypeEnum::QuantizedS8;
                               } else {
                                   return false;
                               }
                           }};
        // padding constant must be all-zero (f32 / i32 / q8)
        SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
            bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
            bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
            bool is_q8_pad = is_immutable_all_equal<dtype::QuantizedS8>(
                    opr, dt_qint8(0));
            return is_fp32_pad || is_i32_pad || is_q8_pad;
        };
        SGM::Node::CallBack const_reshape_cbk = [](OperatorNodeBase* opr) {
            return true;
        };
        auto&& shuffle_root = gen_pad_dimshuffle_graph(type_cvt, const_pad_cbk,
                                                       const_reshape_cbk);
        bool match = matcher.match(shuffle_root, opr);
        bool check_ok = false;
        if (match) {
            check_ok =
                    SubGraphChecker().check({src_node}, opr, src_node, reader);
        }
        if (match && check_ok) {
            opr::RelayoutFormat::Param param;
            param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
            OperatorNodeConfig config(opr->output()[0]->dtype());
            auto out_node = opr::RelayoutFormat::make(
                    rewriter.get_var(src_node->output()[0]), param.mode,
                    config);
            const auto& outshp = opr->output(0)->shape();
            // NHWC variant: flatten the 5-d NCHW4 result (N, C/4, H, W, 4)
            // back into the expected 4-d output (N, H, W, 4)
            if (outshp.ndim == 4) {
                auto shpvar = opr::GetVarShape::make(out_node);
                auto cv = [&out_node](int v) {
                    return out_node.make_scalar(v);
                };
                auto sub = [&shpvar, &cv](int idx) {
                    return opr::IndexAt::make(shpvar, {{0, cv(idx)}});
                };
                auto nhwc_shp =
                        opr::Concat::make({sub(0), sub(2), sub(3), sub(4)}, 0);
                out_node = opr::Reshape::make(out_node, nhwc_shp);
            }
            return out_node.node()->owner_opr();
        } else {
            // no fusion: keep the opr, just remap its inputs
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    //! replace func for a TypeCvt(q8)-rooted match over the u8 chain
    auto replace_astype_opr = [&](OperatorNodeBase* opr,
                                  const VarNodeArray& new_inp,
                                  SubGraph::Rewriter& rewriter,
                                  ReaderType& reader) {
        SGM matcher;
        OperatorNodeBase* src_node = nullptr;
        OperatorNodeBase* neg_128_immu_node = nullptr;
        OperatorNodeBase* pad0_immu_node = nullptr;
        // NOTE(review): since shape_cbk is unused above, this node is never
        // captured and stays nullptr in the checker's used_input set
        OperatorNodeBase* const_reshape_last_dim_node = nullptr;
        auto sub_128 = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
        // padding constant must be all-zero (f32 / i32)
        SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
            pad0_immu_node = opr;
            bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
            bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
            return is_fp32_pad || is_i32_pad;
        };
        SGM::Node::CallBack const_reshape_cbk = [&](OperatorNodeBase* opr) {
            const_reshape_last_dim_node = opr;
            return true;
        };
        auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk,
                                                       const_reshape_cbk);
        SGM::Node::CallBack cvt_q8_cbk = [](OperatorNodeBase* opr) {
            auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
            if (cvt_op) {
                return cvt_op->param().enumv() == DTypeEnum::QuantizedS8;
            } else {
                return false;
            }
        };
        SGM::Node astype_root{
                opr::TypeCvt::typeinfo(), {{shuffle_root}}, cvt_q8_cbk};
        bool match = matcher.match(astype_root, opr);
        bool check_ok = false;
        if (match) {
            check_ok = SubGraphChecker().check(
                    {src_node, neg_128_immu_node, pad0_immu_node,
                     const_reshape_last_dim_node},
                    opr, src_node, reader);
        }
        if (match && check_ok) {
            opr::RelayoutFormat::Param param;
            param.mode = opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
            OperatorNodeConfig config(opr->output()[0]->dtype());
            auto out_node = opr::RelayoutFormat::make(
                    rewriter.get_var(src_node->output()[0]), param.mode,
                    config);
            return out_node.node()->owner_opr();
        } else {
            // no fusion: keep the opr, just remap its inputs
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
    };
    auto ret = std::make_unique<FuseNCHW4Int8Preprocess>();
    auto&& replace_func = ret->m_opr_replace_func;
    MGB_MARK_USED_VAR(replace_astype_opr);
    MGB_MARK_USED_VAR(replace_shuffle_opr);
    replace_func[opr::Dimshuffle::typeinfo()] = replace_shuffle_opr;
    replace_func[opr::TypeCvt::typeinfo()] = replace_astype_opr;
    return ret;
}
/*!
 * run the pass: build the reader map for the whole graph, then iterate
 * oprs and dispatch to the replace function registered for each type
 */
void FuseNCHW4Int8Preprocess::apply(OptState& state) const {
    state.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_DTYPE |
                                     VarReplaceCheckFlag::CHECK_SHAPE);
    auto rewriter = state.graph().make_rewriter();
    VarNodeArray new_inp_cache;
    // map: opr -> list of (consumer opr, dep type), consulted by
    // SubGraphChecker to ensure no intermediate value escapes a fused region
    ReaderType readers;
    state.graph().iter([&readers](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            readers[i.first->owner_opr()].emplace_back(opr, i.second);
        }
    });
    auto on_opr = [this, &rewriter, &new_inp_cache,
                   &readers](OperatorNodeBase* opr) {
        auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
        if (it != m_opr_replace_func.end()) {
            // reuse one cached array to avoid per-opr reallocation
            auto&& new_inp = new_inp_cache;
            new_inp.clear();
            new_inp.reserve(opr->input().size());
            for (auto i : opr->input()) {
                new_inp.push_back(rewriter.get_var(i));
            }
            auto new_opr = (it->second)(opr, new_inp, rewriter, readers);
            if (new_opr->try_cast_final<opr::RelayoutFormat>()) {
                // fusion happened: the whole subgraph collapses onto the
                // single RelayoutFormat output
                auto &&origin_out = opr->output(),
                     &&cur_out = new_opr->output();
                rewriter.replace_var(origin_out[0], cur_out[0], nullptr);
            } else {
                // shallow copy: outputs must correspond one-to-one
                auto &&origin_out = opr->output(),
                     &&cur_out = new_opr->output();
                mgb_assert(origin_out.size() == cur_out.size(),
                           "bad opr replace: src=%s{%s} dst=%s{%s}, %zu != %zu",
                           opr->cname(), opr->dyn_typeinfo()->name,
                           new_opr->cname(), new_opr->dyn_typeinfo()->name,
                           origin_out.size(), cur_out.size());
                for (size_t i = 0; i < origin_out.size(); i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    state.graph().iter(on_opr);
    rewriter.apply_inplace();
}
  502. /* ==================== FuseWarpPerspectiveDimshufflePass ================= */
//! pass identifier shown in logs and gopt diagnostics
const char* FuseWarpPerspectiveDimshufflePass::name() const {
    return mgb_cstr_log("Fuse warp perspective dimshuffle pass");
}
/*!
 * fuse warp_perspective followed by dimshuffle / typecvt / relayout into
 * a single WarpPerspective whose Format performs the layout conversion
 */
void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const {
    auto rewriter = opt.graph().make_rewriter();
    auto uniq_reader_check = UniqReaderCheck{opt.graph()};
    //! rebuild the warp opr with a new param and output dtype, remapping
    //! its 3 inputs (src, mat, out_shape) or 4 (src, mat, mat_idx, out_shape)
    auto make_new_warp = [&rewriter](opr::WarpPerspective* warp,
                                     opr::WarpPerspective::Param new_param,
                                     megdnn::DType dst_dtype,
                                     SymbolVar& new_warp) {
        OperatorNodeConfig new_config = warp->config();
        new_config.output_dtype(dst_dtype);
        if (warp->input().size() == 3) {
            auto src = rewriter.get_var(warp->input(0)),
                 mat = rewriter.get_var(warp->input(1)),
                 out_shape = rewriter.get_var(warp->input(2));
            new_warp = opr::WarpPerspective::make(src, mat, out_shape,
                                                  new_param, new_config);
        } else {
            mgb_assert(warp->input().size() == 4);
            auto src = rewriter.get_var(warp->input(0)),
                 mat = rewriter.get_var(warp->input(1)),
                 mat_idx = rewriter.get_var(warp->input(2)),
                 out_shape = rewriter.get_var(warp->input(3));
            new_warp = opr::WarpPerspective::make(src, mat, mat_idx, out_shape,
                                                  new_param, new_config);
        }
    };
    //! bottom_opr must be a u8/qu8 NCHW warp with a uniquely-read input;
    //! on success top_opr is set to the warp
    auto is_warp_nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr,
                                             OperatorNodeBase*& top_opr) {
        // check warp
        auto warp = try_cast_as_op<opr::WarpPerspective>(bottom_opr);
        if (warp == nullptr)
            return false;
        auto inp_dtype = warp->input(0)->dtype();
        bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
                            inp_dtype.enumv() == DTypeEnum::Uint8;
        bool is_nchw = warp->param().format ==
                       megdnn::param::WarpPerspective::Format::NCHW;
        if (!(is_u8_or_qu8 && is_nchw))
            return false;
        if (!uniq_reader_check(warp->input(0)))
            return false;
        top_opr = warp;
        return true;
    };
    //! bottom_opr must be a (0,3,1,2) NHWC->NCHW dimshuffle fed by a
    //! u8/qu8 NHWC warp; on success top_opr is set to the warp
    auto is_warp_nhwc2nchw = [&uniq_reader_check](OperatorNodeBase* bottom_opr,
                                                  OperatorNodeBase*& top_opr) {
        // check shuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(bottom_opr);
        if (shuffle == nullptr)
            return false;
        auto&& shuffle_param = shuffle->param();
        if (shuffle_param.pattern_len != 4)
            return false;
        bool is_nhwc2nchw = shuffle_param.pattern[0] == 0 &&
                            shuffle_param.pattern[1] == 3 &&
                            shuffle_param.pattern[2] == 1 &&
                            shuffle_param.pattern[3] == 2;
        if (!is_nhwc2nchw)
            return false;
        if (!uniq_reader_check(shuffle->input(0)))
            return false;
        // check warp
        auto warp = try_cast_as_op<opr::WarpPerspective>(
                shuffle->input(0)->owner_opr());
        if (warp == nullptr)
            return false;
        auto inp_dtype = warp->input(0)->dtype();
        bool is_u8_or_qu8 = inp_dtype.enumv() == DTypeEnum::Quantized8Asymm ||
                            inp_dtype.enumv() == DTypeEnum::Uint8;
        bool is_nhwc = warp->param().format ==
                       megdnn::param::WarpPerspective::Format::NHWC;
        if (!(is_u8_or_qu8 && is_nhwc))
            return false;
        top_opr = warp;
        return true;
    };
    //! warp(NCHW) + typecvt(f32) -> warp with f32 output dtype
    auto try_warp_nchw_typecvt = [&rewriter, &uniq_reader_check, &is_warp_nchw,
                                  &make_new_warp](OperatorNodeBase* opr) {
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        bool is_to_f32 =
                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
        if (!is_to_f32)
            return false;
        if (!uniq_reader_check(typecvt->input(0)))
            return false;
        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nchw(typecvt->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        SymbolVar new_warp;
        make_new_warp(warp, warp->param(), opr->output()[0]->dtype(), new_warp);
        rewriter.replace_var(opr->output(0), new_warp.node(),
                             mgb_cstr_log("replace warp + typecvt"
                                          "fuse warp_dimshuffle(NCHW)"));
        return true;
    };
    //! warp(NHWC) + dimshuffle(NHWC->NCHW) + typecvt(f32)
    //! -> warp with Format::NHWC_NCHW
    auto try_warp_nhwc2nchw_typecvt = [&rewriter, &uniq_reader_check,
                                       &is_warp_nhwc2nchw,
                                       &make_new_warp](OperatorNodeBase* opr) {
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        bool is_to_f32 =
                typecvt->output(0)->dtype().enumv() == DTypeEnum::Float32;
        if (!is_to_f32)
            return false;
        if (!uniq_reader_check(typecvt->input(0)))
            return false;
        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nhwc2nchw(typecvt->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format = megdnn::param::WarpPerspective::Format::NHWC_NCHW;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);
        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace conv_bias + dimshuffle + "
                             "typecvt to warp_dimshuffle(NHWC_NCHW)"));
        return true;
    };
    //! warp(NHWC, C<4) + dimshuffle(NHWC->NCHW) + relayout(NCHW_NCHW4, q8)
    //! -> warp with Format::NHWC_NCHW4_IC_SMALL
    auto try_warp_nhwc2nchw4_typecvt = [&rewriter, &uniq_reader_check,
                                        &is_warp_nhwc2nchw,
                                        &make_new_warp](OperatorNodeBase* opr) {
        // check relayout
        auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
        if (relayout == nullptr)
            return false;
        bool is_to_q8 =
                relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        bool is_to_nchw2nchw4 = relayout->param().mode ==
                                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
        if (!(is_to_q8 && is_to_nchw2nchw4))
            return false;
        if (!uniq_reader_check(relayout->input(0)))
            return false;
        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nhwc2nchw(relayout->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        // NHWC input: channel is the last dim
        bool is_small_chn = warp->input(0)->shape()[3] < 4;
        if (!is_small_chn)
            return false;
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format =
                megdnn::param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);
        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace warp + dimshuffle + relayout(NCHW_NCHW4)"
                             "to warp_dimshuffle(NHWC_NCHW4_IC_SMALL)"));
        return true;
    };
    //! warp(NCHW, C<4) + relayout(NCHW_NCHW4, q8)
    //! -> warp with Format::NCHW_NCHW4_IC_SMALL
    auto try_warp_nchw2nchw4_typecvt = [&rewriter, &uniq_reader_check,
                                        &is_warp_nchw,
                                        &make_new_warp](OperatorNodeBase* opr) {
        // check relayout
        auto relayout = try_cast_as_op<opr::RelayoutFormat>(opr);
        if (relayout == nullptr)
            return false;
        bool is_to_q8 =
                relayout->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
        bool is_to_nchw2nchw4 = relayout->param().mode ==
                                opr::RelayoutFormat::Param::Mode::NCHW_NCHW4;
        if (!(is_to_q8 && is_to_nchw2nchw4))
            return false;
        if (!uniq_reader_check(relayout->input(0)))
            return false;
        OperatorNodeBase* top_opr = nullptr;
        if (!is_warp_nchw(relayout->input(0)->owner_opr(), top_opr))
            return false;
        auto warp = try_cast_as_op<opr::WarpPerspective>(top_opr);
        // NCHW input: channel is dim 1
        bool is_small_chn = warp->input(0)->shape()[1] < 4;
        if (!is_small_chn)
            return false;
        opr::WarpPerspective::Param new_param = warp->param();
        new_param.format =
                megdnn::param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL;
        SymbolVar new_warp;
        make_new_warp(warp, new_param, opr->output()[0]->dtype(), new_warp);
        rewriter.replace_var(
                opr->output(0), new_warp.node(),
                mgb_cstr_log("replace warp + relayout(NCHW_NCHW4)"
                             "to warp_dimshuffle(NCHW_NCHW4_IC_SMALL)"));
        return true;
    };
    // longer chains are tried first; at most one rewrite per opr
    auto on_opr = [&try_warp_nchw_typecvt, &try_warp_nhwc2nchw_typecvt,
                   &try_warp_nhwc2nchw4_typecvt, &try_warp_nchw2nchw4_typecvt,
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_warp_nhwc2nchw4_typecvt(opr) &&
            !try_warp_nchw2nchw4_typecvt(opr) &&
            !try_warp_nchw_typecvt(opr) && !try_warp_nhwc2nchw_typecvt(opr)) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台