You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

blas.cpp 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736
  1. /**
  2. * \file src/opr/impl/blas.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/blas.h"
  12. #include "megbrain/common.h"
  13. #include "megbrain/comp_node_env.h"
  14. #include "megbrain/graph/grad_impl.h"
  15. #include "megbrain/opr/basic_arith_wrapper.h"
  16. #include "megbrain/opr/indexing.h"
  17. #include "megbrain/opr/tensor_gen.h"
  18. #include "megbrain/opr/tensor_manip.h"
  19. #include "megbrain/opr/search_policy/algo_chooser.h"
  20. #include "megbrain/opr/search_policy/profiler.h"
  21. #include "./internal/megdnn_opr_wrapper.inl"
  22. #include "./search_policy/workspace_need_limit_getter.inl"
  23. #include "megdnn/oprs/linalg.h"
  24. using namespace mgb;
  25. using namespace opr;
  26. namespace {
  27. int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
  28. return static_cast<int>(param.transposeA) +
  29. (static_cast<int>(param.transposeB) * 2);
  30. }
  31. }
  32. /* ================= MatrixMul ================= */
  33. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);
/*!
 * \brief construct a MatrixMul operator node computing op(a) * op(b),
 *        where op() optionally transposes according to \p param
 */
MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
                     const ExecutionPolicy& policy,
                     const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});
    // empty inputs are handled explicitly in scn_do_execute(), which zero-
    // fills the output instead of invoking the megdnn kernel
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
//! \brief insert a MatrixMul operator into the graph owning \p a and
//!        return its (single) output var
SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
                          const ExecutionPolicy& policy,
                          const OperatorNodeConfig& config) {
    return a.insert_single_output_opr<MatrixMul>(a.node(), b.node(), param,
                                                 policy, config);
}
void MatrixMul::init_output_dtype() {
    // start from the dtype explicitly requested via the node config (may be
    // invalid/unset), then let megdnn deduce/validate it from input dtypes
    DType output_dtype = config().output_dtype();
    megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
                               output_dtype);
    output(0)->dtype(output_dtype);
}
  55. MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
  56. auto ret = Super::do_make_node_prop();
  57. ret->add_dep_type_existing_var(input(0),
  58. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  59. ret->add_dep_type_existing_var(input(1),
  60. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  61. return ret;
  62. }
  63. bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
  64. mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
  65. layout.to_string().c_str());
  66. return layout.stride[0 ^ transpose] >=
  67. static_cast<ptrdiff_t>(layout.shape[1 ^ transpose]) &&
  68. layout.stride[1 ^ transpose] == 1;
  69. }
  70. void MatrixMul::add_input_layout_constraint() {
  71. auto check = [](const TensorLayout& ly) {
  72. return check_layout(ly, 0) || check_layout(ly, 1);
  73. };
  74. input(0)->add_layout_constraint(check);
  75. input(1)->add_layout_constraint(check);
  76. }
size_t MatrixMul::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    // we may change the transpose params in the impl, so get the max possible
    // workspace by trying all four (transposeA, transposeB) combinations;
    // current implementation in megdnn guarantees that workspaces in
    // different cases are on the same order of magnitude
    auto mo = megdnn_opr();
    auto&& tparam = mo->param();
    // workspace sizes for the four transpose combinations
    size_t a, b, c, d;
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    TensorLayout i0(input_shapes[0], input(0)->dtype()),
            i1(input_shapes[1], input(1)->dtype()),
            out(output_shapes[0], output(0)->dtype());
    // swap the two axes of a (contiguous) layout and flip the matching
    // transpose flag, so the product described stays the same
    auto transpose = [](TensorLayout& dst, bool& param) {
        std::swap(dst.shape[0], dst.shape[1]);
        dst.stride[0] = dst[1];
        param ^= 1;
    };
    MGB_TRY {
        megdnn_opr()->execution_policy() = {};
        a = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        //! Here we just want to save the execution policy got from
        //! setup_algo, while changing the declaration of
        //! get_workspace_in_bytes may cause many changes.
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: A transposed
        transpose(i0, tparam.transposeA);
        b = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: A and B transposed
        transpose(i1, tparam.transposeB);
        c = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: only B transposed (A flipped back)
        transpose(i0, tparam.transposeA);
        d = AlgoChooser<megdnn::MatrixMul>::setup_algo({i0, i1, out},
                                                       megdnn_opr(), this);
        const_cast<MatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
    }
    // always restore the user-specified param, even if setup_algo throws
    MGB_FINALLY({ tparam = this->param(); });
    return std::max(std::max(a, b), std::max(c, d));
}
void MatrixMul::scn_do_execute() {
    auto inp0 = input(0)->dev_tensor().as_megdnn(),
         inp1 = input(1)->dev_tensor().as_megdnn(),
         out = output(0)->dev_tensor().as_megdnn();
    // empty input: the product over a zero-extent axis is all-zero, so
    // just zero-fill the output (if non-empty) and skip the megdnn kernel
    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
        if (!out.layout.is_empty()) {
            if (!m_fill_opr) {
                // lazily created and cached across executions
                m_fill_opr = intl::get_megdnn_handle(comp_node())->
                    create_operator<megdnn::Fill>();
            }
            m_fill_opr->param() = 0;
            m_fill_opr->exec(out, {});
        }
        return;
    }
    // if the layout is not directly usable, view it as its transpose and
    // flip the corresponding transpose flag (add_input_layout_constraint()
    // guarantees one of the two views is valid)
    auto transpose = [](TensorLayout& layout, bool& trans) {
        if (!check_layout(layout, 0)) {
            mgb_assert(check_layout(layout, 1));
            std::swap(layout.shape[0], layout.shape[1]);
            std::swap(layout.stride[0], layout.stride[1]);
            trans ^= 1;
        }
    };
    auto&& tparam = megdnn_opr()->param();
    MGB_TRY {
        transpose(inp0.layout, tparam.transposeA);
        transpose(inp1.layout, tparam.transposeB);
        // use the execution policy cached for this transpose combination
        // by get_workspace_size_bytes()
        megdnn_opr()->execution_policy() =
                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
        megdnn_opr()->exec(inp0, inp1, out,
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
    // restore the user-specified param even if exec throws
    MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
//! gradient of C = op(A) * op(B) w.r.t. either input, expressed as another
//! MatrixMul with transpose flags arranged to avoid explicit transposition
MGB_IMPL_OPR_GRAD(MatrixMul) {
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
    if (wrt_idx == 0) {
        // A * B = C, A' = C' * Bt
        if (opr.param().transposeA) {
            // A entered transposed, so the result is transposed back:
            // A' = op(B) * C't
            grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
        } else {
            grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
        }
    } else {
        mgb_assert(wrt_idx == 1);
        // A * B = C, B' = At * C'
        if (opr.param().transposeB) {
            // B entered transposed: B' = C't * op(A)
            grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
        } else {
            grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
        }
    }
    return grad.node();
}
#endif
  190. /* ================= BatchedMatrixMul ================= */
  191. MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);
/*!
 * \brief construct a BatchedMatrixMul node: per-batch op(a[i]) * op(b[i])
 *        on 3-dim inputs (batch, rows, cols)
 */
BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
                                   const ExecutionPolicy& policy,
                                   const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});
    // empty inputs are handled in scn_do_execute() by zero-filling output
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
//! \brief insert a BatchedMatrixMul operator into the graph owning \p a
//!        and return its (single) output var
SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
                                 const ExecutionPolicy& policy,
                                 const OperatorNodeConfig& config) {
    return a.insert_single_output_opr<BatchedMatrixMul>(a.node(), b.node(),
                                                        param, policy, config);
}
void BatchedMatrixMul::add_input_layout_constraint() {
    auto check = [](const TensorLayout& ly) {
        mgb_assert(ly.ndim == 3,
                   "input to BatchedMatrixMul must be 3-dim; got %s",
                   ly.to_string().c_str());
        // the batch stride must cover one whole matrix, whichever of the
        // two inner axes is the major one
        bool good_layout =
                ((ly.stride[0] >=
                  static_cast<ptrdiff_t>(ly.shape[1] * ly.stride[1])) &&
                 (ly.stride[0] >=
                  static_cast<ptrdiff_t>(ly.shape[2] * ly.stride[2])));
        // each matrix must be usable either directly or transposed
        // (scn_do_execute() flips the transpose flag as needed)
        bool ret = good_layout &&
                   (check_layout(ly, true) || check_layout(ly, false));
        return ret;
    };
    input(0)->add_layout_constraint(check);
    input(1)->add_layout_constraint(check);
}
void BatchedMatrixMul::init_output_dtype() {
    // start from the dtype explicitly requested via the node config (may be
    // unset), then let megdnn deduce/validate it from the input dtypes
    DType output_dtype = config().output_dtype();
    megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
                               output_dtype);
    output(0)->dtype(output_dtype);
}
  230. BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
  231. auto ret = Super::do_make_node_prop();
  232. ret->add_dep_type_existing_var(input(0),
  233. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  234. ret->add_dep_type_existing_var(input(1),
  235. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  236. return ret;
  237. }
  238. bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
  239. bool transpose) {
  240. int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
  241. return (layout.stride[lhs] >= static_cast<ptrdiff_t>(layout.shape[rhs])) &&
  242. (layout.stride[rhs] == 1);
  243. }
size_t BatchedMatrixMul::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    // we may change the transpose params in the impl, so get the max possible
    // workspace by trying all four (transposeA, transposeB) combinations;
    // current implementation in megdnn guarantees that workspaces in
    // different cases are on the same order of magnitude
    auto mo = megdnn_opr();
    auto&& tparam = mo->param();
    // workspace sizes for the four transpose combinations
    size_t a, b, c, d;
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    TensorLayout i0(input_shapes[0], input(0)->dtype()),
            i1(input_shapes[1], input(1)->dtype()),
            out(output_shapes[0], output(0)->dtype());
    // swap the two inner (matrix) axes, keeping the batch axis, and flip
    // the matching transpose flag
    auto transpose = [](TensorLayout& dst, bool& param) {
        std::swap(dst.shape[1], dst.shape[2]);
        dst.stride[1] = dst[2];
        param ^= 1;
    };
    MGB_TRY {
        megdnn_opr()->execution_policy() = {};
        a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        // cache the chosen policy for this transpose combination; it is
        // looked up again in scn_do_execute()
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: A transposed
        transpose(i0, tparam.transposeA);
        b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: A and B transposed
        transpose(i1, tparam.transposeB);
        c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
        // case: only B transposed (A flipped back)
        transpose(i0, tparam.transposeA);
        d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
                {i0, i1, out}, megdnn_opr(), this);
        const_cast<BatchedMatrixMul*>(this)
                ->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
                megdnn_opr()->execution_policy();
        megdnn_opr()->execution_policy() = {};
    }
    // always restore the user-specified param, even if setup_algo throws
    MGB_FINALLY({ tparam = this->param(); });
    return std::max(std::max(a, b), std::max(c, d));
}
void BatchedMatrixMul::scn_do_execute() {
    auto inp0 = input(0)->dev_tensor().as_megdnn(),
         inp1 = input(1)->dev_tensor().as_megdnn(),
         out = output(0)->dev_tensor().as_megdnn();
    // empty input: zero-fill the (possibly non-empty) output and skip the
    // megdnn kernel
    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
        if (!out.layout.is_empty()) {
            if (!m_fill_opr) {
                // lazily created and cached across executions
                m_fill_opr = intl::get_megdnn_handle(comp_node())->
                    create_operator<megdnn::Fill>();
            }
            m_fill_opr->param() = 0;
            m_fill_opr->exec(out, {});
        }
        return;
    }
    // if the matrix part of the layout is not directly usable, view each
    // batch as its transpose and flip the corresponding transpose flag
    auto transpose = [](TensorLayout& layout, bool& trans) {
        if (!check_layout(layout, false)) {
            mgb_assert(check_layout(layout, true));
            std::swap(layout.shape[1], layout.shape[2]);
            std::swap(layout.stride[1], layout.stride[2]);
            mgb_assert(layout.stride[2] == 1);
            trans ^= 1;
        }
    };
    auto&& tparam = megdnn_opr()->param();
    MGB_TRY {
        transpose(inp0.layout, tparam.transposeA);
        transpose(inp1.layout, tparam.transposeB);
        // use the execution policy cached for this transpose combination
        // by get_workspace_size_bytes()
        megdnn_opr()->execution_policy() =
                m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
        megdnn_opr()->exec(inp0, inp1, out,
                           intl::get_megdnn_workspace_from_var(output(1)));
    }
    // restore the user-specified param even if exec throws
    MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
//! gradient of per-batch C = op(A) * op(B); same transpose-flag algebra as
//! the MatrixMul gradient, applied batchwise
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    // output(1) is the workspace var; it must not receive a gradient
    mgb_assert(out_grad.size() == 2 && !out_grad[1]);
    SymbolVar grad, i0{opr.input(0)}, i1{opr.input(1)}, og{out_grad[0]};
    if (wrt_idx == 0) {
        // A * B = C, A' = C' * Bt
        if (opr.param().transposeA) {
            // A entered transposed: A' = op(B) * C't
            grad = BatchedMatrixMul::make(
                    i1, og, {opr.param().transposeB, true});
        } else {
            grad = BatchedMatrixMul::make(
                    og, i1, {false, !opr.param().transposeB});
        }
    } else {
        mgb_assert(wrt_idx == 1);
        // A * B = C, B' = At * C'
        if (opr.param().transposeB) {
            // B entered transposed: B' = C't * op(A)
            grad = BatchedMatrixMul::make(
                    og, i0, {true, opr.param().transposeA});
        } else {
            grad = BatchedMatrixMul::make(
                    i0, og, {!opr.param().transposeA, false});
        }
    }
    return grad.node();
}
#endif
  360. /* ================= Dot ================= */
  361. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Dot);
/*!
 * \brief construct a Dot operator node computing the inner product of two
 *        1-dim vars; a length-1 input is broadcast at execution time
 */
Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config):
    Super{opr0->owner_graph(), config, "dot", {opr0, opr1}}
{
    init_megdnn_opr(*this, {});
    // dot is commutative, so inputs are added in sorted order; presumably
    // this canonicalizes equivalent expressions -- TODO confirm
    add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
    // empty inputs are handled in scn_do_execute() (output filled with 0)
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    static_assert(std::is_empty<Param>::value, "Dot param should be empty");
    mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
               opr1->dtype().category() != DTypeCategory::QUANTIZED,
               "Dot does not support quantized input.");
}
void Dot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto &&mgr = owner_graph()->static_infer_manager();
    // output(0) is always a 1-element tensor, independent of inputs
    auto infer_shp = [](TensorShape &dest, const InpVal &){
        dest = {1};
        return true;
    };
    // workspace is queried with the larger of the two input lengths, since
    // a length-1 input is broadcast to the other's length in
    // scn_do_execute()
    auto infer_workspace = [this](TensorShape &dest, const InpVal &iv) {
        auto dtype = input(0)->dtype();
        TensorLayout ily(
                {std::max(
                        iv.val[0].shape().total_nr_elems(),
                        iv.val[1].shape().total_nr_elems())},
                dtype);
        dest.ndim = 1;
        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
                ily, ily, {{1}, dtype});
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, infer_shp});
    mgr.register_shape_infer(output(1),
            {SourceType::DEP,
             {{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}},
             infer_workspace});
}
void Dot::scn_do_execute() {
    auto i0 = input(0)->dev_tensor().as_megdnn(),
         i1 = input(1)->dev_tensor().as_megdnn();
    mgb_throw_if(i0.layout.ndim != 1 || i1.layout.ndim != 1, GraphError,
                 "Invalid input shapes for Dot: %s",
                 cg::dump_var_info(input()).c_str());
    // mismatched lengths are only allowed when one side is a scalar
    // (length 1); broadcast it by setting stride 0 over the other's length
    if (i0.layout.shape[0] != i1.layout.shape[0]) {
        bool s0 = i0.layout.shape[0] == 1, s1 = i1.layout.shape[0] == 1;
        mgb_throw_if(!s0 && !s1, GraphError,
                     "Invalid input shapes for Dot: %s",
                     cg::dump_var_info(input()).c_str());
        if (s0) {
            i0.layout.shape[0] = i1.layout.shape[0];
            i0.layout.stride[0] = 0;
        }
        else {
            i1.layout.shape[0] = i0.layout.shape[0];
            i1.layout.stride[0] = 0;
        }
    }
    // dot of empty vectors is defined as 0: fill the scalar output
    if ((i0.layout.is_empty() || i1.layout.is_empty())) {
        if (!m_fill_opr) {
            // lazily created and cached across executions
            m_fill_opr = intl::get_megdnn_handle(comp_node())->
                create_operator<megdnn::Fill>();
        }
        m_fill_opr->param() = 0;
        m_fill_opr->exec(output(0)->dev_tensor().as_megdnn(), {});
        return;
    }
    megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output(1)));
}
  430. Dot::NodeProp* Dot::do_make_node_prop() const {
  431. auto ret = Super::do_make_node_prop();
  432. ret->add_dep_type_existing_var(input(0),
  433. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  434. ret->add_dep_type_existing_var(input(1),
  435. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  436. return ret;
  437. }
  438. void Dot::add_input_layout_constraint() {
  439. auto check = [](const TensorLayout &ly) {
  440. mgb_throw_if(ly.ndim != 1, GraphError,
  441. "Dot input must be 1-dim; got %s", ly.to_string().c_str());
  442. return ly.stride[0] >= 0;
  443. };
  444. input(0)->add_layout_constraint(check);
  445. input(1)->add_layout_constraint(check);
  446. }
#if MGB_ENABLE_GRAD
//! d dot(a, b) / da = out_grad * b, broadcast to the common length and then
//! reduced back to the wrt input's own shape (handles the case where one
//! input had length 1 and was broadcast in the forward pass)
MGB_IMPL_OPR_GRAD(Dot) {
    auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
    auto ishp0 = opr::GetVarShape::make(opr.input(0)),
         ishp1 = opr::GetVarShape::make(opr.input(1));
    // shape both inputs broadcast to
    auto max_ishp = opr::GetVarShape::make({opr.input(0), opr.input(1)});
    return reduce_sum(
            Broadcast::make(mul(out_grad[0], other_input), max_ishp),
            wrt_idx ? ishp1 : ishp0).node();
}
#endif
//! \brief insert a Dot operator into the graph owning \p opr0 and return
//!        its (single, scalar) output var
SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1,
                    const OperatorNodeConfig &config) {
    return opr0.insert_single_output_opr<Dot>(opr0.node(), opr1.node(), config);
}
//! keep the megdnn operator alive as an execution dependency
void Dot::record_execute_deps(ExecDependencyArray &deps) {
    record_megdnn_opr(deps);
}
  465. /* ================= MatrixInverse ================= */
  466. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
  467. MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
#if MGB_ENABLE_GRAD
//! for Y = A^-1 the gradient is dA = -Y^T * dY * Y^T, computed batchwise
//! after flattening all leading axes into a single batch dimension
MGB_IMPL_OPR_GRAD(MatrixInverse) {
    SymbolVar a = opr.output(0);
    // TODO: use unified MatrixMul interface when we have it
    // n: extent of the last axis (matrices are square); tshp = (batch, n, n)
    // with batch extent left for Reshape to derive -- TODO confirm that a 0
    // extent in the target shape means "inferred" here
    auto n = opr::Subtensor::make(a.symshape(),
            {opr::Subtensor::AxisIndexer::make_index(0, a.make_scalar(-1))}),
         tshp = opr::Concat::make({a.make_scalar(0), n, n}, 0),
         // our hard disk is limited so derivation of the gradient is omitted:)
         // a_bnn: batched transpose of Y = A^-1
         a_bnn = opr::Dimshuffle::make(opr::Reshape::make(a, tshp, 0),
                                       {0, 2, 1}),
         dy = opr::Reshape::make(out_grad.at(0), tshp, 0),
         da = - BatchedMatrixMul::make(BatchedMatrixMul::make(a_bnn, dy),
                                       a_bnn);
    return da.reshape(a.symshape()).node();
}
#endif
  484. /* ================= SVD ================= */
  485. MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);
  486. SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
  487. Super(OperatorNodeBaseCtorParam{src->owner_graph(),
  488. config, "svd", {src}}) {
  489. mgb_throw_if(src->dtype() != megdnn::dtype::Float32(), MegDNNError,
  490. "Singular Value Decomposition on non-float32 tensors is not "
  491. "supoorted.");
  492. init_megdnn_opr(*this, param);
  493. add_input({src});
  494. if (!param.compute_uv) {
  495. output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  496. .add_flag(VarNode::Flag::VOLATILE_CONTENT);
  497. output(2)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
  498. .add_flag(VarNode::Flag::VOLATILE_CONTENT);
  499. }
  500. }
#if MGB_ENABLE_GRAD
namespace {
/*!
 * \brief a wrapper similar to SymbolVar but can safely contain nullptr as zero
 *
 * Note: here we introduce a new class of SymbolVar representation, which allows
 * nullptr to represent zero values, and overloads other C++ operators
 * accordingly. Therefore we can avoid testing nullptr values everywhere in SVD
 * grad.
 *
 * This is a general approach. It can be moved to some header file if we
 * encounter another operator that also has complex gradient computation.
 */
class SafeSymbolVar {
    VarNode* m_node;

public:
    explicit SafeSymbolVar(VarNode* node) : m_node{node} {}
    SafeSymbolVar(SymbolVar x) : m_node{x.node()} {}
    //! default-constructed value represents zero (nullptr node)
    SafeSymbolVar() : m_node{nullptr} {}
    VarNode* node() const { return m_node; }
    //! view as a plain SymbolVar (node may be nullptr)
    SymbolVar s() const { return m_node; }
// forward a SymbolVar member function, propagating the zero (nullptr) value
#define FWD(name) \
    template <typename... Args> \
    SafeSymbolVar name(Args&&... args) { \
        if (!m_node) \
            return {}; \
        return SymbolVar{m_node}.name(std::forward<Args>(args)...); \
    }
    FWD(reshape)
    FWD(broadcast)
#undef FWD
};

//! unwrap to SymbolVar; overloaded so templates below work on both kinds
SymbolVar unsafe(SymbolVar x) {
    return x;
}
SymbolVar unsafe(SafeSymbolVar x) {
    return x.s();
}

//! reshape with the leading batch extent derived from the data
template <typename T>
T reshape_anybatch(T x, SymbolVar tshp) {
    if (!x.node())
        return x;
    return opr::Reshape::make(unsafe(x), tshp, 0);
}

//! batched transpose: swap the two trailing axes
template <typename T>
T trans(T x) {
    if (!x.node())
        return x;
    return opr::Dimshuffle::make(unsafe(x), {0, 2, 1});
}

//! batched matmul; zero (nullptr) operand yields zero result
template <typename T>
T matmul(T a, T b, const opr::BatchedMatrixMul::Param& param = {}) {
    if (!a.node() || !b.node())
        return {};
    return opr::BatchedMatrixMul::make(unsafe(a), unsafe(b), param);
}
SafeSymbolVar matmuls(SafeSymbolVar x, SafeSymbolVar y,
                      const opr::BatchedMatrixMul::Param& param = {}) {
    return matmul(x, y, param);
}

SafeSymbolVar operator-(SafeSymbolVar x) {
    if (x.node())
        return -x.s();
    return {};
}

// binary operators treating nullptr as zero: a_ / b_ are the results when
// a / b respectively is zero
#define OP(x, a_, b_) \
    SafeSymbolVar operator x(SafeSymbolVar a, SafeSymbolVar b) { \
        if (!a.node()) \
            return a_; \
        if (!b.node()) \
            return b_; \
        return a.s() x b.s(); \
    }
OP(+, b, a)
OP(-, -b, a)
OP(*, {}, {})
#undef OP
}  // anonymous namespace
#endif
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(SVD) {
    /**
     * The formula is copied from
     * https://j-towns.github.io/papers/svd-derivative.pdf
     * It is hard to compare m, n here, so I do not refer this paper :
     * http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
     */
    mgb_throw_if(!opr.param().compute_uv, MegBrainError,
                 "Singular value decomposition gradient computation depends "
                 "on U and V, please set compute_uv = True");
    SymbolVar a{opr.input(0)}, u_raw{opr.output(0)}, s_raw{opr.output(1)},
            vt_raw{opr.output(2)};
    // wrapped so that missing (nullptr) gradients act as zero
    SafeSymbolVar grad_u_raw{out_grad[0]}, grad_s_raw{out_grad[1]},
            grad_vt_raw{out_grad[2]};
    // transpose-flag combinations for batched matmul: (tA, tB)
    auto param10 = BatchedMatrixMul::Param{true, false},
         param00 = BatchedMatrixMul::Param{false, false},
         param01 = BatchedMatrixMul::Param{false, true};
    // m, n: trailing input dims; r: number of singular values
    auto n = opr::Subtensor::make(a.symshape(),
                 {opr::Subtensor::AxisIndexer::make_index(
                         0, a.make_scalar(-1))}),
         m = opr::Subtensor::make(a.symshape(),
                 {opr::Subtensor::AxisIndexer::make_index(
                         0, a.make_scalar(-2))}),
         r = opr::Subtensor::make(s_raw.symshape(),
                 {opr::Subtensor::AxisIndexer::make_index(
                         0, s_raw.make_scalar(-1))});
    // reshape everything to one batch dim (extent derived) + matrix dims
    SymbolVar sshp = opr::Concat::make({a.make_scalar(0), r}, 0),
              ushp = opr::Concat::make({a.make_scalar(0), m, r}, 0),
              vtshp = opr::Concat::make({a.make_scalar(0), r, n}, 0),
              u = reshape_anybatch(u_raw, ushp),
              vt = reshape_anybatch(vt_raw, vtshp), v = trans(vt);
    SafeSymbolVar grad_u = reshape_anybatch(grad_u_raw, ushp),
            grad_vt = reshape_anybatch(grad_vt_raw, vtshp),
            grad_v = trans(grad_vt);
    auto batches = opr::Subtensor::make(
            u.symshape(),
            {opr::Subtensor::AxisIndexer::make_index(0, u.make_scalar(-3))});
    auto brr = opr::Concat::make({batches, r, r}, 0);
    // I_r: batched r-by-r identity; filter_matrix masks out the diagonal
    auto I_r = opr::Eye::make(r, {0, DTypeEnum::Float32})
                    .reshape(opr::Concat::make({a.make_scalar(1), r, r}, 0))
                    .broadcast(brr),
         filter_matrix = 1 - I_r;
    // sf / grad_sf: singular values (and their grads) broadcast to r-by-r
    auto sf = reshape_anybatch(s_raw, sshp)
                    .reshape(opr::Concat::make({batches, r, a.make_scalar(1)},
                                               0))
                    .broadcast(brr);
    auto grad_sf = reshape_anybatch(grad_s_raw, sshp)
                           .reshape(opr::Concat::make(
                                   {batches, r, a.make_scalar(1)}, 0))
                           .broadcast(brr);
    auto s = I_r * sf;
    auto grad_s = I_r * grad_sf;
    // inverse of the diagonal matrix s; adding filter_matrix off-diagonal
    // avoids division by zero there
    auto s_inv = 1 / (s + filter_matrix) - filter_matrix;
    // f[i][j] = 1 / (s_j^2 - s_i^2) off-diagonal, 0 on the diagonal;
    // the +I_r keeps the diagonal division well-defined
    auto s_rhs = sf * sf, s_mid = trans(s_rhs) - s_rhs,
         s_avoid_nan = s_mid + I_r, f = filter_matrix / s_avoid_nan;
    auto I_m = opr::Eye::make(m, {0, DTypeEnum::Float32})
                    .reshape(opr::Concat::make({a.make_scalar(1), m, m}, 0))
                    .broadcast(opr::Concat::make({batches, m, m}, 0)),
         I_n = opr::Eye::make(n, {0, DTypeEnum::Float32})
                    .reshape(opr::Concat::make({a.make_scalar(1), n, n}, 0))
                    .broadcast(opr::Concat::make({batches, n, n}, 0));
    auto ut_du = matmuls(u, grad_u, param10),
         vt_dv = matmuls(v, grad_v, param10);
    // assemble dA from the U-term, the S-term and the V-term of the formula
    auto ret =
            matmuls(matmuls(matmuls(u, f * (ut_du - trans(ut_du))), s,
                            param00) +
                            matmuls(matmuls(I_m - matmul(u, u, param01),
                                            grad_u),
                                    s_inv),
                    v, param01) +
            matmuls(matmuls(u, I_r * grad_s), v, param01) +
            matmuls(u, matmuls(matmuls(s, f * (vt_dv - trans(vt_dv)), param00),
                               v, param01) +
                               matmuls(matmuls(s_inv, grad_v, param01),
                                       I_n - matmul(v, v, param01)));
    return ret.reshape(a.symshape()).node();
}
#endif
  659. SymbolVarArray SVD::make(const SymbolVar& src, const Param& param,
  660. const OperatorNodeConfig& config) {
  661. auto&& out = src.node()
  662. ->owner_graph()
  663. ->insert_opr(std::make_unique<SVD>(src.node(), param,
  664. config))
  665. ->output();
  666. SymbolVarArray ret(out.size());
  667. for (size_t i = 0; i < ret.size(); i++) {
  668. ret[i] = out[i];
  669. }
  670. return ret;
  671. }
  672. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台