
tensor_manip.cpp 53 kB

  1. /**
  2. * \file src/opr/impl/tensor_manip.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/tensor_manip.h"
  12. #include "megbrain/opr/basic_arith.h"
  13. #include "megbrain/opr/param_defs.h"
  14. #include "megbrain/opr/utility.h"
  15. #include "megbrain/opr/io.h"
  16. #include "megbrain/graph/event.h"
  17. #include "megbrain/comp_node_env.h"
  18. #include "megbrain/utils/arith_helper.h"
  19. #include "megbrain/graph/grad_impl.h"
  20. #include "megbrain/graph/exc_extra_info.h"
  21. #include "./internal/megdnn_opr_wrapper.inl"
  22. using namespace mgb;
  23. using namespace opr;
  24. using namespace intl;
  25. /* f{{{ ======================= local utils ======================= */
  26. namespace {
  27. using OptionalAxis = megdnn::param::OptionalAxisV1;
  28. //! check whether shp is GetVarShape(a)
  29. bool check_is_shape_of(SymbolVar shp, SymbolVar a) {
  30. #if MGB_BUILD_SLIM_SERVING
  31. return false;
  32. #else
  33. auto op = shp.node()->owner_opr();
  34. if (op->same_type<GetVarShape>() && op->input().size() == 1 &&
  35. op->input()[0] == a.node() &&
  36. op->cast_final<GetVarShape>().param().axis ==
  37. OptionalAxis::INVALID_AXIS) {
  38. return true;
  39. }
  40. using namespace cg::static_infer;
  41. auto &&mgr = a.node()->owner_graph()->static_infer_manager();
  42. if ((mgr.get_infer_type(shp.node()).value & InferType::CONST) &&
  43. (mgr.get_infer_type(a.node()).shape & InferType::CONST)) {
  44. auto &&a_shp = mgr.infer_shape(a.node());
  45. auto &&shp_val = mgr.infer_value(shp.node());
  46. TensorShape shp_shp;
  47. cg::copy_tensor_value_to_shape(shp_shp, shp_val);
  48. return a_shp.eq_shape(shp_shp);
  49. }
  50. return false;
  51. #endif
  52. }
  53. #if !MGB_BUILD_SLIM_SERVING
  54. // return a GetVarShape opr whose output equals shape_of(var), or nullptr
  55. GetVarShape* get_shape_shortcut(VarNode *var) {
  56. auto opr = var->owner_opr();
  57. auto otype = opr->dyn_typeinfo();
  58. if (!(otype == Reshape::typeinfo() &&
  59. opr->cast_final<Reshape>().param().axis ==
  60. OptionalAxis::INVALID_AXIS) &&
  61. otype != Broadcast::typeinfo()) {
  62. return nullptr;
  63. }
  64. auto i1 = opr->input(1)->owner_opr();
  65. if (i1->same_type<GetVarShape>())
  66. return &i1->cast_final<GetVarShape>();
  67. return nullptr;
  68. }
  69. #endif
  70. } // anonymous namespace
  71. // f}}}
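// Usage sketch (illustrative, assuming x is an existing SymbolVar): the helpers
// above let Reshape::make / Broadcast::make fold away no-op calls whose target
// shape is the shape of the same variable, e.g.
//     auto y = opr::Reshape::make(x, opr::GetVarShape::make(x));
// returns x itself because check_is_shape_of() recognises the target shape.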
  72. /* f{{{ ======================= GetVarShape ======================= */
  73. MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape);
  74. GetVarShape::GetVarShape(const VarNodeArrayView &inp, Param axis,
  75. const OperatorNodeConfig &config):
  76. Super(inp.at(0)->owner_graph(), config, "shape_of", inp),
  77. m_axis{axis}
  78. {
  79. m_src_shapes.resize(inp.size());
  80. for (auto i: inp)
  81. add_input({i});
  82. add_input({}, AddInputSortType::ALL);
  83. add_output(None)->dtype(dtype::Int32());
  84. add_equivalence_component<PODHash<Param>>(&m_axis);
  85. mgb_assert(abs(m_axis.axis) <= m_axis.MAX_NDIM);
  86. }
  87. void GetVarShape::update_cached_shape() {
  88. TensorShape ishp;
  89. if (m_src_shapes.size() == 1) {
  90. ishp = m_src_shapes[0];
  91. } else {
  92. megdnn::Elemwise::deduce_shape(m_src_shapes, ishp);
  93. }
  94. mgb_assert(ishp.ndim);
  95. // check whether m_cached_shape is valid and update it if not
  96. if (m_axis.axis != OptionalAxis::INVALID_AXIS) {
  97. int axis = m_axis.axis;
  98. if (axis < 0) {
  99. axis += ishp.ndim;
  100. }
  101. mgb_assert(axis >= 0 && axis < (int)ishp.ndim);
  102. if (m_cached_shape.ndim == 1 &&
  103. m_cached_shape.shape[0] == ishp.shape[axis])
  104. return;
  105. m_cached_shape = {ishp.shape[axis]};
  106. } else {
  107. if (m_cached_shape.eq_shape(ishp))
  108. return;
  109. m_cached_shape = ishp;
  110. }
  111. cg::copy_shape_to_tensor_value(m_cached_shape_cpu_v, m_cached_shape);
  112. m_cached_shape_dev_v_synced = false;
  113. }
  114. void GetVarShape::scn_do_execute() {
  115. for (size_t i = 0; i < m_src_shapes.size(); ++ i) {
  116. m_src_shapes[i] = input()[i]->shape();
  117. }
  118. update_cached_shape();
  119. if (!m_cached_shape_dev_v_synced) {
  120. m_cached_shape_dev_v.copy_from(m_cached_shape_cpu_v);
  121. m_cached_shape_dev_v_synced = true;
  122. }
  123. output(0)->dev_tensor().copy_from_fixlayout(m_cached_shape_dev_v);
  124. }
  125. void GetVarShape::update_for_static_infer(const cg::static_infer::InpVal &inp) {
  126. for (size_t i = 0; i < m_src_shapes.size(); ++ i) {
  127. m_src_shapes[i] = inp.val.at(i).shape();
  128. }
  129. update_cached_shape();
  130. }
  131. void GetVarShape::init_output_static_infer_desc() {
  132. using namespace cg::static_infer;
  133. auto infer_shape = [this](TensorShape &dest, const InpVal &inp) {
  134. update_for_static_infer(inp);
  135. dest = m_cached_shape_cpu_v.shape();
  136. return true;
  137. };
  138. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  139. update_for_static_infer(inp);
  140. dest = m_cached_shape_cpu_v;
  141. return true;
  142. };
  143. DepVal deps;
  144. for (auto i: input()) {
  145. deps.push_back({i, DepType::SHAPE});
  146. }
  147. auto &&mgr = owner_graph()->static_infer_manager();
  148. mgr.register_shape_infer(output(0),
  149. {SourceType::DEP, deps, infer_shape});
  150. mgr.register_value_infer(output(0),
  151. {SourceType::DEP, deps, infer_value});
  152. }
  153. #ifdef MGB_ENABLE_GRAD
  154. MGB_IMPL_OPR_GRAD(GetVarShape) {
  155. MGB_MARK_USED_VAR(wrt_idx);
  156. MGB_MARK_USED_VAR(out_grad);
  157. return nullptr;
  158. }
  159. #endif
  160. SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param,
  161. const OperatorNodeConfig& config) {
  162. mgb_assert(!inp.empty());
  163. #if !MGB_BUILD_SLIM_SERVING
  164. // try to apply shortcut and omit scalar shapes to optimize
  165. VarNodeArray inp_vp;
  166. inp_vp.reserve(inp.size());
  167. auto&& mgr = inp[0]->owner_graph()->static_infer_manager();
  168. for (auto var : inp) {
  169. auto&& it = mgr.get_infer_type(var);
  170. if (it.shape & cg::static_infer::InferType::CONST) {
  171. if (mgr.infer_shape(var).is_scalar()) {
  172. // scalar does not affect broadcast result
  173. continue;
  174. }
  175. }
  176. if (auto opr = get_shape_shortcut(var)) {
  177. // current var replaced by a shortcut
  178. auto&& op_inp = opr->input();
  179. inp_vp.insert(inp_vp.end(), op_inp.begin(), op_inp.end());
  180. continue;
  181. }
  182. inp_vp.push_back(var);
  183. }
  184. if (inp_vp.empty()) {
  185. // all inputs are scalar
  186. mgb_assert(param.axis == OptionalAxis::INVALID_AXIS || param.axis == 0);
  187. return SymbolVar{inp[0]}.make_scalar(1);
  188. }
  189. #else
  190. auto&& inp_vp = inp;
  191. #endif
  192. return SymbolVar{inp[0]}.insert_single_output_opr<GetVarShape>(
  193. inp_vp, param, config);
  194. }
  195. cg::OperatorNodeBase::NodeProp* GetVarShape::do_make_node_prop() const {
  196. auto prop = Super::do_make_node_prop();
  197. using DT = NodeProp::DepType;
  198. SmallVector<DT> dt(input().size(), DT::SHAPE);
  199. prop->reset_dep_type(input(), dt);
  200. return prop;
  201. }
  202. class GetVarShape::ShapeDevValueExecDep final : public ExecDependency {
  203. DeviceTensorStorage m_val;
  204. public:
  205. explicit ShapeDevValueExecDep(DeviceTensorStorage val)
  206. : m_val(std::move(val)) {}
  207. };
  208. void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
  209. deps.emplace_back(std::make_unique<ShapeDevValueExecDep>(
  210. m_cached_shape_dev_v.storage()));
  211. }
  212. // f}}}
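// Usage sketch (illustrative, assuming x is an existing SymbolVar): the whole
// shape is returned as a 1-D Int32 tensor, or a single dimension when an axis
// is given in Param (negative axes count from the back):
//     auto full_shape = opr::GetVarShape::make(x);        // shape of x
//     auto last_dim   = opr::GetVarShape::make(x, {-1});  // size of last axis
// With multiple inputs the result is their broadcasted shape, as computed by
// Elemwise::deduce_shape() in update_cached_shape().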
  213. /* f{{{ ======================= ReshapeBrdcastHelper ======================= */
  214. void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode *inp, VarNode *tshp) {
  215. add_input({inp, tshp});
  216. add_output(None)->dtype(inp->dtype());
  217. if (reshapebrdcast_output_shape_need_input_shape())
  218. outshape_by_symvar_enable(1, 1);
  219. else
  220. outshape_by_symvar_enable(0, 1);
  221. }
  222. void ReshapeBrdcastHelper::mem_plan_fwd_in2out_readonly() {
  223. auto &&tshape = output(0)->shape();
  224. auto inp_layout = input(0)->layout();
  225. auto dst_layout = reshapebrdcast_get_dest_layout(inp_layout, tshape);
  226. if (!dst_layout.valid()) {
  227. // retry after making input contiguous
  228. mgb_assert(dyn_typeinfo() == Reshape::typeinfo());
  229. inp_layout.init_contiguous_stride(input(0)->shape());
  230. dst_layout = reshapebrdcast_get_dest_layout(inp_layout, tshape);
  231. mgb_assert(dst_layout.valid());
  232. m_rofwd_subspec = SubTensorSpec::make_from_layout(dst_layout.val());
  233. m_incompatible_inp_layout = true;
  234. return;
  235. }
  236. m_rofwd_subspec = SubTensorSpec::make_from_layout(dst_layout.val());
  237. m_incompatible_inp_layout = false;
  238. rofwd_init_mem_plan();
  239. }
  240. void ReshapeBrdcastHelper::outshape_by_symvar_do_get_output_shape(
  241. TensorShape &dest,
  242. const ShapeInferInfo &shpinfo) {
  243. if (reshapebrdcast_output_shape_need_input_shape()) {
  244. TensorShape oshp_given;
  245. cg::copy_tensor_value_to_shape(oshp_given,
  246. *shpinfo.shpval_inp_val.at(0));
  247. TensorLayout src;
  248. src.init_contiguous_stride(shpinfo.shape_inp_shp.at(0));
  249. dest = reshapebrdcast_get_dest_layout(src, oshp_given).val();
  250. } else {
  251. cg::copy_tensor_value_to_shape(dest, *shpinfo.shpval_inp_val.at(0));
  252. }
  253. }
  254. void ReshapeBrdcastHelper::scn_do_execute() {
  255. if (m_incompatible_inp_layout) {
  256. // only happens in reshape
  257. auto &&iv = input(0)->dev_tensor();
  258. auto ishp = iv.shape();
  259. auto &&ov = output(0)->dev_tensor();
  260. mgb_assert(ishp.total_nr_elems() == ov.shape().total_nr_elems());
  261. ov.sub(SubTensorSpec::make_from_layout({ishp, iv.dtype()})).
  262. copy_from_fixlayout(iv);
  263. } else
  264. rofwd_execute();
  265. }
  266. void ReshapeBrdcastHelper::add_input_layout_constraint() {
  267. if (!cg::is_static_var_value(input(1)))
  268. return;
  269. auto check_layout = [this](const TensorLayout &layout) {
  270. MGB_TRY {
  271. TensorShape oshp;
  272. outshape_by_symvar_do_get_output_shape(
  273. oshp, outshape_by_symvar_get_shape_infer_info());
  274. return reshapebrdcast_get_dest_layout(layout, oshp).valid();
  275. } MGB_CATCH(MegBrainError &exc, {
  276. if (!exc.extra_info())
  277. cg::OperatorNodeExcExtraInfo::record(this, exc);
  278. throw;
  279. })
  280. };
  281. input(0)->add_layout_constraint(check_layout);
  282. }
  283. void ReshapeBrdcastHelper::init_output_static_infer_desc() {
  284. Super::init_output_static_infer_desc();
  285. using namespace cg::static_infer;
  286. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  287. TensorShape oshp;
  288. cg::copy_tensor_value_to_shape(oshp, inp.val.at(1).value());
  289. auto &&iv = inp.val[0].value();
  290. auto sub_layout = reshapebrdcast_get_dest_layout(iv.layout(), oshp);
  291. if (sub_layout.valid()) {
  292. dest = const_cast<DeviceTensorND&>(iv).sub(
  293. SubTensorSpec::make_from_layout(sub_layout.val()));
  294. } else {
  295. // use contig dest
  296. dest = {};
  297. dest.copy_from(iv);
  298. sub_layout = reshapebrdcast_get_dest_layout(dest.layout(), oshp);
  299. mgb_assert(sub_layout.valid());
  300. dest = dest.sub(SubTensorSpec::make_from_layout(sub_layout.val()));
  301. }
  302. return true;
  303. };
  304. owner_graph()->static_infer_manager().register_value_infer(
  305. output(0), {SourceType::DEP,
  306. {{input(0), DepType::VALUE}, {input(1), DepType::VALUE}},
  307. infer_value});
  308. }
  309. // f}}}
  310. /* f{{{ ======================= Reshape ======================= */
  311. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Reshape);
  312. Reshape::Reshape(VarNode *inp, VarNode *tshp, Param unspec_axis,
  313. const OperatorNodeConfig &config):
  314. Super{inp->owner_graph(), config, "reshape", {inp}},
  315. m_unspec_axis{unspec_axis}
  316. {
  317. reshapebrdcast_init(inp, tshp);
  318. add_equivalence_component<PODHash<Param>>(&m_unspec_axis);
  319. }
  320. SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp,
  321. Param unspec_axis, const OperatorNodeConfig &config) {
  322. if (check_is_shape_of(tshp, inp))
  323. return inp;
  324. return inp.insert_single_output_opr<Reshape>(
  325. inp.node(), tshp.node(), unspec_axis, config);
  326. }
  327. #ifdef MGB_ENABLE_GRAD
  328. MGB_IMPL_OPR_GRAD(Reshape) {
  329. if (wrt_idx)
  330. return InvalidGrad::make(opr, wrt_idx);
  331. return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
  332. }
  333. #endif
  334. Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
  335. const TensorLayout &src, const TensorShape &tshape) const {
  336. if (m_unspec_axis.axis == OptionalAxis::INVALID_AXIS) {
  337. TensorLayout ret;
  338. if (src.try_reshape(ret, tshape))
  339. return ret;
  340. return None;
  341. }
  342. int original_unspec = m_unspec_axis.axis;
  343. if (original_unspec < 0) {
  344. original_unspec += tshape.ndim;
  345. }
  346. size_t unspec = original_unspec;
  347. mgb_assert(unspec < tshape.ndim);
  348. auto actual_tshape = tshape;
  349. size_t rem_nr_elem = 1;
  350. for (size_t i = 0; i < tshape.ndim; ++ i) {
  351. if (i != unspec)
  352. rem_nr_elem *= tshape.shape[i];
  353. }
  354. auto tot_nr_elem = src.total_nr_elems();
  355. actual_tshape.shape[unspec] = 0;
  356. mgb_throw_if(tot_nr_elem % rem_nr_elem, TensorReshapeError,
  357. "could not reshape: src=%s tshape=%s unspec_axis=%zd",
  358. static_cast<const TensorShape&>(src).to_string().c_str(),
  359. actual_tshape.to_string().c_str(),
  360. unspec);
  361. actual_tshape.shape[unspec] = tot_nr_elem / rem_nr_elem;
  362. TensorLayout ret;
  363. if (src.try_reshape(ret, actual_tshape))
  364. return ret;
  365. return None;
  366. }
  367. bool Reshape::reshapebrdcast_output_shape_need_input_shape() const {
  368. return m_unspec_axis.axis != OptionalAxis::INVALID_AXIS;
  369. }
  370. // f}}}
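// Usage sketch (illustrative): reshape one var to the (possibly dynamic) shape
// of another, as the gradient implementation above does:
//     auto g_in = opr::Reshape::make(g_out, opr::GetVarShape::make(x));
// Param::axis marks one target dimension as unspecified; its extent is then
// deduced from the remaining elements (cf. the -1 dimension of numpy.reshape),
// see reshapebrdcast_get_dest_layout() above.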
  371. /* f{{{ ======================= Broadcast ======================= */
  372. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
  373. Broadcast::Broadcast(VarNode *inp, VarNode *tshp,
  374. const OperatorNodeConfig &config):
  375. Super{inp->owner_graph(), config, "broadcast", {inp}}
  376. {
  377. reshapebrdcast_init(inp, tshp);
  378. }
  379. SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp,
  380. const OperatorNodeConfig &config) {
  381. if (check_is_shape_of(tshp, inp))
  382. return inp;
  383. return inp.insert_single_output_opr<Broadcast>(
  384. inp.node(), tshp.node(), config);
  385. }
  386. #ifdef MGB_ENABLE_GRAD
  387. MGB_IMPL_OPR_GRAD(Broadcast) {
  388. if (wrt_idx)
  389. return InvalidGrad::make(opr, wrt_idx);
  390. return Reduce::make(out_grad.at(0), Reduce::Mode::SUM,
  391. GetVarShape::make(opr.input(0))).node();
  392. }
  393. #endif
  394. Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout(
  395. const TensorLayout &src, const TensorShape &tshape) const {
  396. return src.broadcast(tshape);
  397. }
  398. bool Broadcast::reshapebrdcast_output_shape_need_input_shape() const {
  399. return false;
  400. }
  401. // f}}}
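// Usage sketch (illustrative, assuming bias and feat are SymbolVars): broadcast
// a smaller tensor to the shape of a larger one; the gradient reduces back
// along the broadcasted axes with Reduce::Mode::SUM, as implemented above:
//     auto b_full = opr::Broadcast::make(bias, opr::GetVarShape::make(feat));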
  402. /* f{{{ ======================= AxisManipOprBase ======================= */
  403. void AxisManipOprBase::mem_plan_fwd_in2out_readonly() {
  404. m_rofwd_subspec = SubTensorSpec::make_from_layout(
  405. axis_manip_get_output_layout(input(0)->layout()));
  406. rofwd_init_mem_plan();
  407. }
  408. void AxisManipOprBase::scn_do_execute() {
  409. rofwd_execute();
  410. }
  411. void AxisManipOprBase::init_output_static_infer_desc() {
  412. using namespace cg::static_infer;
  413. auto &&mgr = owner_graph()->static_infer_manager();
  414. auto infer_shape = [this](TensorShape &dest, const InpVal &inp) {
  415. dest = axis_manip_get_output_layout({
  416. inp.val.at(0).shape(), input(0)->dtype()});
  417. return true;
  418. };
  419. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  420. auto &&iv = inp.val.at(0).value();
  421. auto oly = axis_manip_get_output_layout(iv.layout());
  422. dest = const_cast<DeviceTensorND&>(iv).sub(
  423. SubTensorSpec::make_from_layout(oly));
  424. return true;
  425. };
  426. mgr.register_shape_infer(output(0),
  427. {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
  428. mgr.register_value_infer(output(0),
  429. {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
  430. }
  431. // f}}}
  432. /* f{{{ ======================= Dimshuffle ======================= */
  433. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Dimshuffle);
  434. Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
  435. size_t ndim, const OperatorNodeConfig &config):
  436. Super{inp->owner_graph(), config, "dimshuffle", {inp}},
  437. m_pattern(pattern),
  438. m_inp_ndim(ndim)
  439. {
  440. mgb_throw_if(m_pattern.size() > TensorShape::MAX_NDIM,
  441. GraphError, "Dimshuffle pattern exceeds max length of %zd",
  442. TensorShape::MAX_NDIM);
  443. for (auto i: m_pattern) {
  444. mgb_throw_if(i < -1 || i >= int(ndim), GraphError,
  445. "bad Dimshuffle pattern");
  446. }
  447. add_input({inp});
  448. add_output(None);
  449. add_equivalence_component<PODHash<int>>(m_pattern.data(), m_pattern.size());
  450. }
  451. SymbolVar Dimshuffle::make(
  452. SymbolVar inp, const std::vector<int> &pattern,
  453. size_t ndim, const OperatorNodeConfig &config) {
  454. if (!ndim)
  455. ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
  456. return inp.insert_single_output_opr<Dimshuffle>(inp.node(),
  457. pattern, ndim, config);
  458. }
  459. TensorLayout Dimshuffle::axis_manip_get_output_layout(
  460. const TensorLayout &ily) const {
  461. mgb_assert(ily.ndim == m_inp_ndim,
  462. "input ndim mismatch for Dimshuffle: expect=%zd actual=%zd",
  463. m_inp_ndim, ily.ndim);
  464. TensorLayout oly{ily.dtype};
  465. oly.ndim = m_pattern.size();
  466. size_t idx = 0;
  467. bool input_used[TensorLayout::MAX_NDIM] = {0};
  468. for (auto i: m_pattern) {
  469. if (i < 0) {
  470. oly.shape[idx] = 1;
  471. oly.stride[idx] = 1;
  472. } else {
  473. input_used[i] = true;
  474. oly.shape[idx] = ily.shape[i];
  475. oly.stride[idx] = ily.stride[i];
  476. }
  477. ++ idx;
  478. }
  479. for (size_t i = 0; i < m_inp_ndim; ++ i) {
  480. mgb_assert(input_used[i] || ily.shape[i] == 1,
  481. "non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
  482. static_cast<const TensorShape&>(ily).to_string().c_str(),
  483. i);
  484. }
  485. return oly;
  486. }
  487. VarNode* Dimshuffle::grad(
  488. size_t /*wrt_idx*/, const VarNodeArray &out_grad) const {
  489. std::vector<int> back(m_inp_ndim, -1);
  490. for (size_t i = 0; i < m_pattern.size(); i ++) {
  491. // outdim[i] is indim[j]
  492. auto j = m_pattern[i];
  493. if (j >= 0) {
  494. mgb_assert(back[j] == -1,
  495. "taking grad for Dimshuffle with duplicated "
  496. "input axis unsupported");
  497. back[j] = i;
  498. }
  499. }
  500. return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node();
  501. }
  502. #ifdef MGB_ENABLE_GRAD
  503. MGB_IMPL_OPR_GRAD(Dimshuffle) {
  504. return opr.grad(wrt_idx, out_grad);
  505. }
  506. #endif
  507. // f}}}
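// Usage sketch (illustrative): pattern entries index input axes and -1 inserts
// a new axis of size 1, so NCHW -> NHWC is
//     auto y = opr::Dimshuffle::make(x, {0, 2, 3, 1}, 4);
// Passing ndim == 0 deduces it from the largest pattern entry (see make()).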
  508. /* f{{{ ======================= AxisAddRemove ======================= */
  509. MGB_DYN_TYPE_OBJ_FINAL_IMPL(AxisAddRemove);
  510. AxisAddRemove::AxisAddRemove(
  511. VarNode *inp, const std::vector<AxisDesc> &desc,
  512. const OperatorNodeConfig &config):
  513. Super{inp->owner_graph(), config, "axis_add_rm", {inp}},
  514. m_desc(desc)
  515. {
  516. mgb_throw_if(desc.empty(), GraphError,
  517. "desc for AxisAddRemove could not be empty");
  518. add_input({inp});
  519. add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  520. add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(), m_desc.size());
  521. }
  522. SymbolVar AxisAddRemove::make(SymbolVar inp,
  523. const std::vector<AxisDesc> &desc,
  524. const OperatorNodeConfig &config) {
  525. return inp.insert_single_output_opr<AxisAddRemove>(inp.node(), desc, config);
  526. }
  527. TensorLayout AxisAddRemove::axis_manip_get_output_layout(
  528. const TensorLayout &input_layout) const {
  529. auto layout = input_layout;
  530. for (auto &&i: m_desc) {
  531. using M = AxisDesc::Method;
  532. switch (i.method) {
  533. case M::REMOVE:
  534. {
  535. auto axis = i.axis.get(layout.ndim);
  536. if (layout.ndim == 1) {
  537. mgb_assert(layout.shape[0] == 1 && axis == 0,
  538. "can not remove axis %zu from tensor of shape=%s",
  539. axis,
  540. layout.megdnn::TensorShape::to_string().c_str());
  541. } else {
  542. mgb_assert(axis < layout.ndim &&
  543. layout.shape[axis] == 1,
  544. "can not remove axis %zu from tensor of shape=%s",
  545. axis,
  546. layout.megdnn::TensorShape::to_string().c_str());
  547. layout.remove_axis_inplace(axis);
  548. }
  549. break;
  550. }
  551. case M::ADD_1:
  552. layout.add_axis_cont_inplace(i.axis.get(layout.ndim + 1));
  553. break;
  554. }
  555. }
  556. return layout;
  557. }
  558. AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
  559. auto ret = Super::do_make_node_prop();
  560. ret->add_dep_type_existing_var(input(0),
  561. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  562. return ret;
  563. }
  564. #ifdef MGB_ENABLE_GRAD
  565. MGB_IMPL_OPR_GRAD(AxisAddRemove) {
  566. MGB_MARK_USED_VAR(wrt_idx);
  567. return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node();
  568. }
  569. #endif
  570. // f}}}
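// Usage sketch (illustrative; assumes the AxisDesc::make_add / make_remove
// helpers declared in tensor_manip.h): insert a leading axis of size 1, or
// drop a singleton axis:
//     using AD = opr::AxisAddRemove::AxisDesc;
//     auto y = opr::AxisAddRemove::make(x, {AD::make_add(0)});
//     auto z = opr::AxisAddRemove::make(y, {AD::make_remove(0)});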
  571. /* f{{{ ======================= Subtensor ======================= */
  572. MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);
  573. #ifdef MGB_ENABLE_GRAD
  574. MGB_IMPL_OPR_GRAD(Subtensor) {
  575. if (wrt_idx)
  576. return InvalidGrad::make(opr, wrt_idx);
  577. return IncrSubtensor::make(
  578. SymbolVar{opr.input(0)}.fill_retain_dtype(0),
  579. out_grad.at(0), opr.index_desc()).node();
  580. }
  581. #endif
  582. void Subtensor::init_output_static_infer_desc() {
  583. using namespace cg::static_infer;
  584. DepVal deps;
  585. // shape inference only needs slices
  586. deps.push_back({input(0), DepType::SHAPE});
  587. for (size_t i = 1; i < m_input2idxonly_axis_indexer.size(); ++ i) {
  588. if (!m_input2idxonly_axis_indexer[i])
  589. deps.push_back({input(i), DepType::VALUE});
  590. }
  591. auto infer_shape = [this](TensorShape &dest, const InpVal &inp) {
  592. auto &&ishp = inp.val[0].shape();
  593. auto subspec = fancy_indexing_make_sub_spec(
  594. {ishp, input(0)->dtype()}, inp, 1, true);
  595. dest = subspec.layout();
  596. return true;
  597. };
  598. owner_graph()->static_infer_manager().register_shape_infer(
  599. output(0), {SourceType::DEP, deps, infer_shape});
  600. deps.clear();
  601. for (auto i: input())
  602. deps.push_back({i, DepType::VALUE});
  603. deps[0].type = DepType::VALUE;
  604. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  605. auto &&iv = inp.val[0].value();
  606. auto subspec = fancy_indexing_make_sub_spec(iv.layout(), inp, 1);
  607. dest = const_cast<DeviceTensorND&>(iv).sub(subspec);
  608. return true;
  609. };
  610. owner_graph()->static_infer_manager().register_value_infer(
  611. output(0), {SourceType::DEP, deps, infer_value});
  612. }
  613. void Subtensor::scn_do_execute() {
  614. rofwd_execute();
  615. }
  616. void Subtensor::mem_plan_fwd_in2out_readonly() {
  617. m_rofwd_subspec = fancy_indexing_make_sub_spec(input(0)->layout());
  618. rofwd_init_mem_plan();
  619. }
  620. void Subtensor::init_rt_force_dynamic_mem_alloc_imply_chain() {
  621. auto inp = input(0), out = output(0);
  622. inp->add_rt_force_dynamic_mem_alloc_imply_chain(out);
  623. out->add_rt_force_dynamic_mem_alloc_imply_chain(inp);
  624. }
  625. // f}}}
  626. /* f{{{ ================== ModifySubtensorImplHelper ================== */
  627. void ModifySubtensorImplHelper::scn_do_execute() {
  628. auto mod = fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
  629. modify(mod.first, mod.second);
  630. }
  631. void ModifySubtensorImplHelper::init_output_static_infer_desc() {
  632. using namespace cg::static_infer;
  633. auto &&mgr = owner_graph()->static_infer_manager();
  634. // try to register shape infer with subtensor shape check
  635. auto try_infer_shape_with_check = [&]() -> bool{
  636. if (!cg::is_static_var_shape(input(0)) ||
  637. !cg::is_static_var_shape(input(1)))
  638. return false;
  639. for (size_t i = 2; i < input().size(); ++ i) {
  640. if (!cg::is_static_var_value(input(i)))
  641. return false;
  642. }
  643. auto infer_shape = [this](TensorShape &dest, const InpVal &inp) {
  644. dest = inp.val.at(0).shape();
  645. // throw exception if shapes mismatch
  646. auto subspec = fancy_indexing_make_sub_spec(
  647. {dest, input(0)->dtype()}, inp, 2);
  648. auto &&subshp = inp.val.at(1).shape();
  649. mgb_throw_if(!subspec.layout().eq_shape(subshp), TensorReshapeError,
  650. "SetSubtensor shape mismatch: subspec=%s value_shape=%s",
  651. subspec.layout().TensorShape::to_string().c_str(),
  652. subshp.to_string().c_str());
  653. return true;
  654. };
  655. DepVal deps;
  656. for (auto i: input())
  657. deps.push_back({i, DepType::VALUE});
  658. deps[0].type = deps[1].type = DepType::SHAPE;
  659. mgr.register_shape_infer(output(0), {
  660. SourceType::DEP, deps, infer_shape});
  661. return true;
  662. };
  663. if (has_input_tensor_replacer()) {
  664. mgr.register_shape_infer(output(0), ShapeInferDesc::make_const({}));
  665. } else {
  666. if (!try_infer_shape_with_check()) {
  667. auto infer_shape = [](TensorShape &dest, const InpVal &inp) {
  668. dest = inp.val.at(0).shape();
  669. return true;
  670. };
  671. mgr.register_shape_infer(output(0), {
  672. SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
  673. }
  674. }
  675. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  676. dest.copy_from(inp.val.at(0).value());
  677. auto subspec = fancy_indexing_make_sub_spec(dest.layout(), inp, 2);
  678. auto dsub = dest.sub(subspec);
  679. modify(dsub, inp.val.at(1).value());
  680. return true;
  681. };
  682. DepVal value_deps;
  683. for (auto i: input())
  684. value_deps.push_back({i, DepType::VALUE});
  685. mgr.register_value_infer(output(0), {
  686. SourceType::DEP, value_deps, infer_value});
  687. }
  688. // f}}}
  689. /* f{{{ ======================= SetSubtensor ======================= */
  690. MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor, "set_subtensor", true);
  691. void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
  692. sub.copy_from_fixlayout(val);
  693. }
  694. #ifdef MGB_ENABLE_GRAD
  695. MGB_IMPL_OPR_GRAD(SetSubtensor) {
  696. if (wrt_idx >= 2)
  697. return InvalidGrad::make(opr, wrt_idx);
  698. if (wrt_idx == 0) {
  699. return SetSubtensor::make(out_grad.at(0),
  700. SymbolVar{opr.input(1)}.fill_retain_dtype(0),
  701. opr.index_desc()).node();
  702. }
  703. return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
  704. }
  705. #endif
  706. // f}}}
  707. /* f{{{ ======================= IncrSubtensor ======================= */
  708. MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrSubtensor, "incr_subtensor", true);
  709. void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) {
  710. CompNode opr_comp_node;
  711. if (sub.comp_node().locator().device ==
  712. CompNode::Locator::DEVICE_CPU_DEFAULT) {
  713. // for static infer
  714. opr_comp_node = CompNode::default_cpu();
  715. } else {
  716. opr_comp_node = comp_node();
  717. }
  718. auto opr = intl::get_megdnn_global_opr<megdnn::AddUpdate>(opr_comp_node);
  719. opr->exec(sub.as_megdnn(), val.as_megdnn());
  720. }
  721. #ifdef MGB_ENABLE_GRAD
  722. MGB_IMPL_OPR_GRAD(IncrSubtensor) {
  723. if (wrt_idx >= 2)
  724. return InvalidGrad::make(opr, wrt_idx);
  725. if (wrt_idx == 0) {
  726. return out_grad.at(0);
  727. }
  728. return Subtensor::make(out_grad.at(0), opr.index_desc()).node();
  729. }
  730. #endif
  731. // f}}}
  732. /* f{{{ ======================= IndexAt ======================= */
  733. SymbolVar IndexAt::make(SymbolVar inp,
  734. const std::vector<std::pair<size_t, SymbolVar>> &index,
  735. const OperatorNodeConfig &config) {
  736. Subtensor::IndexDesc desc;
  737. for (auto &&i: index) {
  738. desc.emplace_back();
  739. desc.back().axis = i.first;
  740. desc.back().idx = i.second;
  741. }
  742. return Subtensor::make(inp, desc, config);
  743. }
  744. // f}}}
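// Usage sketch (illustrative, assuming i is a scalar index SymbolVar): IndexAt
// builds a Subtensor index description from (axis, index) pairs, e.g. selecting
// x[2, :, i] from a 3-D tensor:
//     auto y = opr::IndexAt::make(x, {{0, x.make_scalar(2)}, {2, i}});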
  745. /* f{{{ ======================= Split ======================= */
  746. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Split);
  747. Split::Options Split::Options::make_average(int axis, size_t nr_part) {
  748. auto cb = [nr_part](size_t size) {
  749. std::vector<size_t> part(nr_part, size / nr_part);
  750. for (size_t i = 0, it = size % nr_part; i < it; ++ i)
  751. ++ part[i];
  752. return part;
  753. };
  754. return make_callback(axis, nr_part, cb);
  755. }
  756. Split::Options Split::Options::make_partition(int axis,
  757. const SymbolVarArray &partition) {
  758. mgb_assert(!partition.empty());
  759. Options rst;
  760. rst.method = Method::SPECIFY;
  761. rst.axis = axis;
  762. rst.partition = partition;
  763. return rst;
  764. }
  765. Split::Options Split::Options::make_partition(SymbolVar inp, int axis,
  766. const std::vector<size_t> &partition) {
  767. SymbolVarArray sym_partition;
  768. for (auto i: partition)
  769. sym_partition.push_back(inp.make_scalar(static_cast<int>(i)));
  770. return make_partition(axis, sym_partition);
  771. }
  772. Split::Options Split::Options::make_callback(
  773. int axis, size_t nr_part, callback_t callback) {
  774. mgb_assert(nr_part);
  775. Options rst;
  776. rst.method = Method::CALLBACK;
  777. rst.axis = axis;
  778. rst.callback = callback;
  779. rst.nr_part = nr_part;
  780. return rst;
  781. }
  782. SymbolVarArray Split::make(SymbolVar inp, Options opt,
  783. const OperatorNodeConfig &config) {
  784. SymbolVarArray ret;
  785. auto &&output = inp.node()->owner_graph()->insert_opr(
  786. std::make_unique<Split>(inp.node(), opt, config))->output();
  787. for (auto i: output) {
  788. ret.emplace_back(i);
  789. }
  790. return ret;
  791. }
  792. Split::Split(VarNode *inp, const Options &opt, const OperatorNodeConfig &config):
  793. Super{inp->owner_graph(), config, "split", {inp}},
  794. m_opt(opt)
  795. {
  796. add_input({inp});
  797. add_equivalence_component<ScalarHash<size_t>>(m_opt.axis);
  798. if (m_opt.method == Options::Method::SPECIFY) {
  799. mgb_assert(!m_opt.partition.empty());
  800. for (auto &&i: m_opt.partition)
  801. add_input({i.node()});
  802. outshape_by_symvar_enable(0, 1);
  803. m_opt.nr_part = m_opt.partition.size();
  804. } else {
  805. // disable dedup
  806. add_equivalence_component<ScalarHash<void*>>(this);
  807. mgb_assert(m_opt.method == Options::Method::CALLBACK);
  808. mgb_assert(m_opt.nr_part);
  809. }
  810. for (size_t i = 0; i < m_opt.nr_part; ++ i)
  811. add_output(ssprintf("o%zd", i))->dtype(inp->dtype());
  812. m_output_spec.resize(m_opt.nr_part);
  813. }
  814. void Split::init_output_static_infer_desc() {
  815. using namespace cg::static_infer;
  816. using namespace std::placeholders;
  817. auto &&mgr = owner_graph()->static_infer_manager();
  818. DepVal shp_deps{{input(0), DepType::SHAPE}};
  819. if (m_opt.method == Options::Method::SPECIFY) {
  820. for (size_t i = 1; i < input().size(); ++ i)
  821. shp_deps.push_back({input(i), DepType::VALUE});
  822. }
  823. auto infer_value = [this](size_t oidx,
  824. DeviceTensorND &dest, const InpVal &inp) {
  825. auto &&cur_shp = m_output_spec[oidx].shape;
  826. mgb_assert(cur_shp.eq_shape(inp.val[1].shape()));
  827. auto axis = m_opt.axis;
  828. if (axis < 0)
  829. axis += m_output_spec[0].shape.ndim;
  830. size_t offset = 0;
  831. for (size_t i = 0; i < oidx; ++ i)
  832. offset += m_output_spec[i].shape[axis];
  833. auto &&iv = inp.val[0].value();
  834. auto subspec = Slice(offset, offset + cur_shp[axis]).apply(
  835. iv.layout(), axis);
  836. dest.copy_from(const_cast<DeviceTensorND&>(iv).sub(subspec));
  837. return true;
  838. };
  839. for (size_t i = 0; i < output().size(); ++ i) {
  840. auto ov = output(i);
  841. mgr.register_shape_infer(ov,
  842. {SourceType::DEP, shp_deps, std::bind(
  843. &Split::infer_shape, this, i, _1, _2)});
  844. mgr.register_value_infer(ov, {
  845. SourceType::DEP,
  846. {{input(0), DepType::VALUE}, {ov, DepType::SHAPE}},
  847. std::bind(infer_value, i, _1, _2)});
  848. }
  849. }
  850. bool Split::infer_shape(size_t out_idx, TensorShape &dest,
  851. const cg::static_infer::InpVal &inp) {
  852. if (inp.run_id != m_output_shape_version) {
  853. std::vector<size_t> partition;
  854. auto ishp = inp.val.at(0).shape();
  855. auto axis = m_opt.axis;
  856. if (axis < 0)
  857. axis += ishp.ndim;
  858. if (m_opt.method == Options::Method::SPECIFY) {
  859. for (size_t i = 0; i < m_opt.nr_part; ++ i) {
  860. auto &&val = inp.val.at(i + 1).value();
  861. mgb_assert(val.shape().is_scalar(),
  862. "shapes for Split must be scalars");
  863. size_t cvt;
  864. static_cast_dtype_safe(&cvt, val.dtype(), val.raw_ptr());
  865. partition.push_back(cvt);
  866. }
  867. } else {
  868. partition = m_opt.callback(ishp.shape[axis]);
  869. mgb_assert(partition.size() == m_opt.nr_part,
  870. "nr_part=%zu but split callback returned %zu parts",
  871. m_opt.nr_part, partition.size());
  872. }
  873. size_t size = 0;
  874. for (size_t i = 0; i < m_opt.nr_part; ++ i) {
  875. auto p = partition[i];
  876. mgb_assert(p,
  877. "got zero partition size at part %zu, tot_size=%zu",
  878. i, ishp.shape[axis]);
  879. size += p;
  880. auto &&cur = m_output_spec[i].shape;
  881. cur = ishp;
  882. cur.shape[axis] = p;
  883. }
  884. mgb_assert(size == ishp.shape[axis],
  885. "split size sums to %zd, but shape at the axis is %zd",
  886. size, ishp.shape[axis]);
  887. m_output_shape_version = inp.run_id;
  888. }
  889. dest = m_output_spec.at(out_idx).shape;
  890. return true;
  891. }
  892. void Split::init_output_comp_node() {
  893. auto &&conf_node = config().comp_node();
  894. auto &&cn_opt = owner_graph()->seq_comp_node_optimizer();
  895. // details of each comp_node specified
  896. if (conf_node.size() > 1) {
  897. mgb_assert(conf_node.size() == output().size(),
  898. "number of CompNodes specified in config should equal to number"
  899. " of output, but got %zd configured CompNodes while there are"
  900. " %zd output (node_name=%s node_type=%s)",
  901. conf_node.size(), output().size(),
  902. cname(), dyn_typeinfo()->name);
  903. auto cn0 = input(0)->comp_node();
  904. for (size_t i = 0; i < output().size(); i ++) {
  905. auto dvar = output(i);
  906. dvar->comp_node(conf_node[i]);
  907. if (conf_node[i].mem_node() != cn0.mem_node())
  908. cn_opt.register_stream_var(
  909. dvar, {CompNode::Stream::COPY,
  910. cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
  911. }
  912. return;
  913. }
  914. CompNode cn;
  915. if (conf_node.size() == 1) {
  916. cn = conf_node[0];
  917. } else {
  918. cn = input(0)->comp_node();
  919. }
  920. for (auto i: output())
  921. i->comp_node(cn);
  922. if (cn.mem_node() != input(0)->comp_node().mem_node()) {
  923. for (auto i: output())
  924. cn_opt.register_stream_var(
  925. i, {CompNode::Stream::COPY,
  926. cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
  927. }
  928. }
  929. cg::OperatorNodeBase::NodeProp* Split::do_make_node_prop() const {
  930. auto rst = OperatorNodeBase::do_make_node_prop();
  931. rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
  932. outshape_by_symvar_reset_node_dep_type(rst);
  933. return rst;
  934. }
  935. void Split::do_execute(ExecEnv &env) {
  936. for (size_t idx = 0; idx < output().size(); ++ idx) {
  937. auto out = output(idx);
  938. if (!owner_graph()->var_receiver_in_current_comp_seq(out
  939. ).value_needed())
  940. continue;
  941. auto runner = [idx, this]() {
  942. auto &&in = input(0)->dev_tensor();
  943. auto &&out = output(idx)->dev_tensor();
  944. auto &&spec = m_output_spec.at(idx);
  945. owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
  946. this, out.comp_node());
  947. if (spec.mem_fwd_success) {
  948. mgb_assert(out.raw_ptr() ==
  949. in.raw_ptr() + spec.subspec.offset_byte());
  950. } else {
  951. out.comp_node().activate();
  952. out.copy_from_fixlayout(in.sub(spec.subspec));
  953. }
  954. owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
  955. this, out.comp_node());
  956. };
  957. env.dispatch_on_comp_node(out->comp_node(), runner);
  958. }
  959. }
  960. #ifdef MGB_ENABLE_GRAD
  961. MGB_IMPL_OPR_GRAD(Split) {
  962. if (wrt_idx)
  963. return InvalidGrad::make(opr, wrt_idx);
  964. mgb_assert(out_grad.size() == opr.output().size());
  965. SymbolVarArray grad;
  966. for (size_t i = 0; i < out_grad.size(); ++ i) {
  967. auto gval = out_grad[i];
  968. if (!gval) {
  969. gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node();
  970. }
  971. grad.emplace_back(gval);
  972. }
  973. return Concat::make(grad, opr.options().axis,
  974. OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node();
  975. }
  976. #endif
  977. void Split::mem_plan_fwd_in2out_readonly() {
  978. m_readonly_fwd_called = true;
  979. init_subspec(true);
  980. }
  981. void Split::init_subspec(bool memfwd) {
  982. auto in = input(0);
  983. size_t begin = 0, end = 0;
  984. for (size_t i = 0; i < output().size(); ++ i) {
  985. auto &&spec = m_output_spec[i];
  986. auto out = output(i);
  987. auto real_axis = m_opt.axis;
  988. if (real_axis < 0)
  989. real_axis += spec.shape.ndim;
  990. begin = end;
  991. mgb_assert(out->shape().eq_shape(spec.shape));
  992. end = begin + spec.shape.shape[real_axis];
  993. spec.subspec = Slice(begin, end).apply(in->layout(), real_axis);
  994. if (out->comp_node() == in->comp_node() && memfwd) {
  995. spec.mem_fwd_success = out->set_fwd_in2out_readonly(
  996. in, spec.subspec);
  997. } else {
  998. spec.mem_fwd_success = false;
  999. }
  1000. }
  1001. }
  1002. void Split::outshape_by_symvar_do_get_output_shape(
  1003. TensorShape &dest, const ShapeInferInfo &shpinfo) {
  1004. // shape infer handled in this class
  1005. MGB_MARK_USED_VAR(dest);
  1006. MGB_MARK_USED_VAR(shpinfo);
  1007. mgb_assert(0);
  1008. }
  1009. void Split::add_input_layout_constraint() {
  1010. m_readonly_fwd_called = false;
  1011. auto cn = input(0)->comp_node();
  1012. for (auto i: output())
  1013. if (i->comp_node() != cn) {
  1014. input(0)->add_layout_constraint_contiguous();
  1015. return;
  1016. }
  1017. }
  1018. void Split::on_mem_status_changed() {
  1019. if (!m_readonly_fwd_called) {
  1020. init_subspec(false);
  1021. }
  1022. }
  1023. cg::OperatorNodeBase::OprEventCallback
  1024. Split::get_opr_event_callback() {
  1025. return {std::bind(&Split::on_mem_status_changed, this)};
  1026. }
  1027. void Split::on_output_comp_node_stream_changed() {
  1028. }
  1029. void Split::init_rt_force_dynamic_mem_alloc_imply_chain() {
  1030. auto inp = input(0);
  1031. auto cn0 = inp->comp_node();
  1032. for (auto i: output()) {
  1033. if (i->comp_node() == cn0) {
  1034. i->add_rt_force_dynamic_mem_alloc_imply_chain(inp);
  1035. inp->add_rt_force_dynamic_mem_alloc_imply_chain(i);
  1036. }
  1037. }
  1038. }
  1039. // f}}}
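// Usage sketch (illustrative, assuming x is a SymbolVar): split along axis 0
// into equal parts, or into explicitly sized parts:
//     auto parts = opr::Split::make(x, opr::Split::Options::make_average(0, 4));
//     auto ab    = opr::Split::make(
//             x, opr::Split::Options::make_partition(x, 0, {3, 5}));
// Each element of the returned SymbolVarArray is one output of the operator.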
  1040. /* f{{{ ======================= Concat ======================= */
  1041. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Concat);
  1042. Concat::Concat(const VarNodeArrayView &inp, int axis,
  1043. const OperatorNodeConfig &config):
  1044. Super{inp[0]->owner_graph(), config, "concat", inp},
  1045. m_axis(axis)
  1046. {
  1047. mgb_assert(!inp.empty());
  1048. for (auto &&i : inp) {
  1049. add_input({i});
  1050. }
  1051. add_equivalence_component<ScalarHash<size_t>>(m_axis);
  1052. add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  1053. }
  1054. void Concat::get_output_var_shape(
  1055. const TensorShapeArray &inp_shape,
  1056. TensorShapeArray &out_shape) const {
  1057. mgb_assert(inp_shape.size() == input().size());
  1058. mgb_assert(out_shape.size() == 1);
  1059. auto &&oshp = out_shape[0];
  1060. oshp = inp_shape[0];
  1061. mgb_throw_if(m_axis >= static_cast<int>(oshp.ndim) ||
  1062. m_axis < -static_cast<int>(oshp.ndim),
  1063. GraphError, "concat axis out of bound: input_ndim=%zu axis=%d",
  1064. oshp.ndim, m_axis);
  1065. auto real_axis = m_axis;
  1066. if (real_axis < 0)
  1067. real_axis += oshp.ndim;
  1068. for (size_t i = 1; i < inp_shape.size(); ++ i) {
  1069. auto &&tmp = inp_shape[i];
  1070. mgb_throw_if(oshp.ndim != tmp.ndim, GraphError,
  1071. "ndim mismatch: shape=%s inp[%zd]=%s",
  1072. oshp.to_string().c_str(), i, tmp.to_string().c_str());
  1073. for (int n = 0; n < static_cast<int>(tmp.ndim); ++ n) {
  1074. if (n == real_axis) {
  1075. oshp.shape[n] += tmp.shape[n];
  1076. } else {
  1077. mgb_throw_if(oshp.shape[n] != tmp.shape[n], GraphError,
  1078. "Concat input shapes mismatch: "
  1079. "accum_out_shape=%s cur_inp_shape=%s inp_idx=%zu"
  1080. " axis_concat=%d axis_mismatch=%d",
  1081. oshp.to_string().c_str(), tmp.to_string().c_str(), i,
  1082. real_axis, n);
  1083. }
  1084. }
  1085. }
  1086. }
  1087. SymbolVar Concat::make(const VarNodeArrayView& inp, int axis,
  1088. const OperatorNodeConfig& config) {
  1089. mgb_assert(!inp.empty());
  1090. if (inp.size() == 1)
  1091. return inp[0];
  1092. intl::BatchedDTypePromotion dtp{inp};
  1093. return SymbolVar{inp[0]}.insert_single_output_opr<Concat>(dtp.get_vars(),
  1094. axis, config);
  1095. }
  1096. #ifdef MGB_ENABLE_GRAD
  1097. MGB_IMPL_OPR_GRAD(Concat) {
  1098. auto axis = opr.axis();
  1099. mgb_assert(out_grad.size() == 1);
  1100. OperatorNodeConfig::CompNodeArray comp_node;
  1101. SymbolVarArray partition;
  1102. for (auto i : opr.input()) {
  1103. partition.push_back(GetVarShape::make(i, axis));
  1104. comp_node.push_back(i->comp_node());
  1105. }
  1106. auto ret = Split::make(out_grad[0],
  1107. Split::Options::make_partition(axis, partition),
  1108. OperatorNodeConfig().comp_node_arr(comp_node));
  1109. return cg::to_var_node_array(ret);
  1110. }
  1111. #endif
  1112. void Concat::scn_do_execute() {
  1113. auto&& out = output(0)->dev_tensor();
  1114. size_t end = 0;
  1115. for (auto&& input : this->input()) {
  1116. auto&& in = input->dev_tensor();
  1117. auto begin = end;
  1118. auto real_axis = m_axis;
  1119. if (real_axis < 0)
  1120. real_axis += in.shape().ndim;
  1121. end = begin + in.shape().shape[real_axis];
  1122. out.sub(Slice(begin, end).apply(out.layout(), real_axis)).
  1123. copy_from_fixlayout(in);
  1124. }
  1125. }
  1126. Concat::NodeProp* Concat::do_make_node_prop() const {
  1127. auto rst = Super::do_make_node_prop();
  1128. rst->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
  1129. for (auto i: input()) {
  1130. rst->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
  1131. }
  1132. return rst;
  1133. }
  1134. void Concat::init_output_static_infer_desc() {
  1135. Super::init_output_static_infer_desc();
  1136. using namespace cg::static_infer;
  1137. auto infer_value = [this](
  1138. DeviceTensorND &dest, const InpVal& inp) {
  1139. TensorShape oshp = inp.val[0].shape();
  1140. auto real_axis = m_axis;
  1141. if (real_axis < 0)
  1142. real_axis += oshp.ndim;
  1143. for (size_t i = 1; i < input().size(); ++ i)
  1144. oshp.shape[real_axis] += inp.val.at(i).shape().shape[real_axis];
  1145. dest.resize(oshp);
  1146. size_t end = 0;
  1147. for (size_t i = 0; i < input().size(); ++ i) {
  1148. auto begin = end;
  1149. end = begin + inp.val[i].shape().shape[real_axis];
  1150. dest.sub(Slice(begin, end).apply(dest.layout(), real_axis)).
  1151. copy_from_fixlayout(inp.val[i].value());
  1152. }
  1153. return true;
  1154. };
  1155. DepVal deps;
  1156. for (auto i: input())
  1157. deps.push_back({i, DepType::VALUE});
  1158. owner_graph()->static_infer_manager().register_value_infer(
  1159. output(0),
  1160. {SourceType::DEP, deps, infer_value});
  1161. }
  1162. void Concat::add_input_layout_constraint() {
  1163. auto cn = output(0)->comp_node();
  1164. for (auto i: input()) {
  1165. if (i->comp_node() != cn) {
  1166. i->add_layout_constraint_contiguous();
  1167. }
  1168. }
  1169. }
  1170. void Concat::init_output_comp_node() {
  1171. Super::init_output_comp_node();
  1172. auto dcn = output(0)->comp_node();
  1173. for (auto i: input()) {
  1174. if (i->comp_node().mem_node() != dcn.mem_node()) {
  1175. owner_graph()->seq_comp_node_optimizer().register_stream_var(
  1176. output(0),
  1177. {CompNode::Stream::COPY,
  1178. cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
  1179. return;
  1180. }
  1181. }
  1182. }
  1183. // f}}}
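// Usage sketch (illustrative, assuming a, b, c are SymbolVars in the same
// graph): concatenate along a (possibly negative) axis; a single input is
// returned unchanged and mixed dtypes are promoted by BatchedDTypePromotion
// in make():
//     auto y = opr::Concat::make({a, b, c}, 1);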
  1184. /* f{{{ ======================= ParamPackConcat ======================= */
  1185. MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);
  1186. ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table,
  1187. const std::vector<dt_int32> offsets_val,
  1188. const OperatorNodeConfig& config)
  1189. : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp),
  1190. m_offsets(offsets_val) {
  1191. CompNode cn = inp[0]->comp_node();
  1192. add_input({inp[0]});
  1193. for (size_t i = 1; i < inp.size(); i++) {
  1194. add_input({inp[i]});
  1195. mgb_assert(cn == inp[i]->comp_node(),
  1196. "input var for param pack must in same comp node");
  1197. }
  1198. add_input({table});
  1199. add_output(None);
  1200. cg::add_workspace_output(this);
  1201. m_opr = intl::create_megdnn_opr<megdnn::ParamPackConcat>(cn);
  1202. }
  1203. void ParamPackConcat::add_input_layout_constraint(){
  1204. for (auto i: input()) {
  1205. i->add_layout_constraint_contiguous();
  1206. }
  1207. }
  1208. SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp,
  1209. const SymbolVar& offsets,
  1210. const std::vector<dt_int32> offsets_val,
  1211. const OperatorNodeConfig& config) {
  1212. VarNodeArray array(inp.size());
  1213. for (size_t i = 0; i < inp.size(); i++) {
  1214. array[i] = inp[i].node();
  1215. }
  1216. return inp.front().insert_single_output_opr<ParamPackConcat>(
  1217. array, offsets.node(), offsets_val, config);
  1218. }
  1219. void ParamPackConcat::scn_do_execute() {
  1220. mgb_assert(m_opr.comp_node() == comp_node());
  1221. auto&& inputs = input();
  1222. m_inp_ptr.resize(inputs.size() - 1);
  1223. auto ptr = m_inp_ptr.data();
  1224. for (size_t i = 0; i < inputs.size() - 1; i++) {
  1225. ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
  1226. }
  1227. auto offsets = inputs.back()->dev_tensor().as_megdnn();
  1228. megdnn::TensorND srcs(
  1229. ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32()));
  1230. auto&& dst = output(0)->dev_tensor().as_megdnn();
  1231. m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1)));
  1232. }
  1233. void ParamPackConcat::init_output_dtype() {
  1234. output(0)->dtype(input(0)->dtype());
  1235. }
  1236. void ParamPackConcat::init_output_static_infer_desc(){
  1237. using namespace cg::static_infer;
  1238. auto &&mgr = owner_graph()->static_infer_manager();
  1239. auto infer_out = [this](TensorShape& dest, const InpVal& inp) {
  1240. dest = {static_cast<unsigned int>(m_offsets.back())};
  1241. return true;
  1242. };
  1243. DepVal shp_deps;
  1244. shp_deps.reserve(input().size());
  1245. for(auto&& inp : input()){
  1246. shp_deps.emplace_back(DepElement{inp, DepType::SHAPE});
  1247. }
  1248. auto infer_wk = [this](TensorShape &dest, const InpVal &inp) {
  1249. TensorShapeArray shapes;
  1250. auto vals = inp.val;
  1251. shapes.reserve(vals.size() - 1);
  1252. for(size_t i = 0; i < vals.size() - 1; i++){
  1253. shapes.push_back(vals[i].shape());
  1254. }
  1255. dest = {m_opr->get_workspace_in_bytes(shapes, vals.back().shape(),
  1256. dest)};
  1257. return true;
  1258. };
  1259. mgr.register_shape_infer(output(0), {SourceType::DEP, shp_deps, infer_out});
  1260. mgr.register_shape_infer(output(1), {SourceType::DEP, shp_deps, infer_wk});
  1261. }
  1262. void ParamPackConcat::on_output_comp_node_stream_changed(){
  1263. Super::on_output_comp_node_stream_changed();
  1264. m_opr = intl::create_megdnn_opr<megdnn::ParamPackConcat>(comp_node());
  1265. }
  1266. // f}}}
  1267. /* f{{{ ======================= ParamPackSplit ======================= */
  1268. MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
  1269. ParamPackSplit::ParamPackSplit(VarNode* src,
  1270. const std::vector<dt_int32> offsets,
  1271. TensorShapeArray& shapes,
  1272. const OperatorNodeConfig& config)
  1273. : Super{src->owner_graph(), config, "ParamPackSplit", {src}},
  1274. m_shapes(shapes), m_offsets(offsets) {
  1275. add_input({src});
  1276. for (size_t i = 0; i < shapes.size(); i++) {
  1277. mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!");
  1278. add_output(ssprintf("param_pack_o%zu", i))
  1279. ->dtype(src->dtype()).shape(shapes[i]);
  1280. }
  1281. }
  1282. void ParamPackSplit::add_input_layout_constraint(){
  1283. input(0)->add_layout_constraint_contiguous();
  1284. }
  1285. SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
  1286. const std::vector<dt_int32> offsets,
  1287. TensorShapeArray shapes,
  1288. const OperatorNodeConfig& config) {
  1289. auto&& out = src.node()
  1290. ->owner_graph()
  1291. ->insert_opr(std::make_unique<ParamPackSplit>(
  1292. src.node(), offsets,
  1293. shapes, config))
  1294. ->output();
  1295. SymbolVarArray ret;
  1296. ret.resize(out.size());
  1297. for (size_t i = 0; i < ret.size(); ++i) {
  1298. ret[i] = out[i];
  1299. }
  1300. return ret;
  1301. }
  1302. void ParamPackSplit::init_output_dtype() {
  1303. // already initialized in constructor
  1304. }
  1305. void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
  1306. mgb_assert(m_offsets.size() == output().size() * 2);
  1307. for (size_t i = 0; i < output().size(); i++) {
  1308. auto layout = output(i)->layout();
  1309. auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]);
  1310. mgb_assert(output(i)->set_fwd_in2out_readonly(input(0), spec));
  1311. }
  1312. }
  1313. bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest,
  1314. const cg::static_infer::InpVal& inp) {
  1315. dest = m_shapes[index];
  1316. return true;
  1317. }
  1318. void ParamPackSplit::init_output_static_infer_desc() {
  1319. using namespace cg::static_infer;
  1320. using namespace std::placeholders;
  1321. auto&& mgr = owner_graph()->static_infer_manager();
  1322. DepVal shp_deps{{input(0), DepType::SHAPE}};
  1323. for (size_t i = 0; i < output().size(); i++) {
  1324. auto ov = output(i);
  1325. mgr.register_shape_infer(
  1326. ov, {SourceType::DEP, shp_deps,
  1327. std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)});
  1328. }
  1329. }
  1330. #ifdef MGB_ENABLE_GRAD
  1331. MGB_IMPL_OPR_GRAD(ParamPackSplit) {
  1332. mgb_assert(out_grad.size() == opr.output().size());
  1333. SmallVector<SymbolVar> grad;
  1334. for (size_t i = 0; i < out_grad.size(); ++i) {
  1335. auto gval = out_grad[i];
  1336. if (!gval) {
  1337. gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node();
  1338. }
  1339. grad.emplace_back(gval);
  1340. }
  1341. auto offsets_val = opr.get_offsets();
  1342. auto cn = opr.input(0)->comp_node();
  1343. if (opr.config().has_comp_node_set()) {
  1344. cn = opr.config().get_single_comp_node();
  1345. }
  1346. HostTensorND hv{cn, TensorShape{offsets_val.size()}, dtype::Int32{}};
  1347. memcpy(hv.raw_ptr(), offsets_val.data(), offsets_val.size() * sizeof(int));
  1348. auto offsets = opr::ImmutableTensor::make(*opr.input(0)->owner_graph(), hv);
  1349. return ParamPackConcat::make(
  1350. grad, offsets, offsets_val,
  1351. OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
  1352. .node();
  1353. }
  1354. #endif
  1355. // f}}}
  1356. /* f{{{ ======================= RelayoutFormat ======================= */
  1357. MGB_DYN_TYPE_OBJ_FINAL_IMPL(RelayoutFormat);
  1358. MEGDNN_OPR_INIT1(RelayoutFormat, "relayout_format")
  1359. void RelayoutFormat::init_output_format() {
  1360. TensorFormat src_fmt = input(0)->format(), dst_fmt;
  1361. megdnn_opr()->deduce_format(src_fmt, dst_fmt);
  1362. mgb_assert(output().size() == 2);
  1363. output(0)->format(dst_fmt);
  1364. output(1)->format({}); // default format
  1365. }
  1366. // f}}}
  1367. //
  1368. /* f{{{ ===================== WinogradFilterPreprocess ===================== */
  1369. MGB_DYN_TYPE_OBJ_FINAL_IMPL(WinogradFilterPreprocess);
  1370. MEGDNN_OPR_INIT1(WinogradFilterPreprocess, "winograd_filter_preprocess")
  1371. void WinogradFilterPreprocess::init_output_dtype() {
  1372. TensorLayout dst;
  1373. TensorLayout src{input(0)->shape(), input(0)->dtype(), input(0)->format()};
  1374. megdnn_opr()->deduce_layout(src, dst);
  1375. output(0)->dtype(dst.dtype);
  1376. }
  1377. // f}}}
  1378. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
