You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

convolution.cpp 66 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641
  1. /**
  2. * \file src/opr/impl/dnn/convolution.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/system.h"
#include "megbrain/utils/hash_ct.h"
#include "megbrain/utils/timer.h"
#include "megdnn/oprs/utils.h"
#include "../internal/invoke.h"
#include "../internal/megdnn_opr_wrapper.inl"
#include "../search_policy/workspace_need_limit_getter.inl"
#include <array>
#include <chrono>
#include <cstring>
#include <memory>
#include <thread>
  28. using namespace mgb;
  29. using namespace opr;
  30. using namespace cg::static_infer;
  31. using intl::WorkspaceLimitGetter;
  32. /* ==================== misc impl ==================== */
/*!
 * \brief register static shape inference for both outputs (the grad tensor
 *      and the workspace) of a backward-data style convolution operator.
 *
 * \tparam MgbOpr the concrete graph operator type (e.g.
 *      ConvolutionBackwardData); must provide megdnn_opr()
 * \tparam MegDNNOpr the matching megdnn operator type, used for algorithm
 *      selection via AlgoChooser
 * \param self the operator whose output(0)/output(1) infer descs are set up
 */
template <class MgbOpr, class MegDNNOpr>
void mixin::ConvolutionBackwardDataMixin::
        init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) {
    using namespace cg::static_infer;
    auto&& mgr = self->owner_graph()->static_infer_manager();

    // shape deps on the two value inputs (filter and diff); reserve 4 since
    // output shape and possibly the workspace-limit var are appended below
    DepVal inp_deps;
    inp_deps.reserve(4);
    for (int i = 0; i < 2; ++i) {
        inp_deps.push_back({self->input(i), DepType::SHAPE});
    }

    // output shape
    if (self->input().size() == 3) {
        // a third input is present: the output shape is exactly its shape
        mgr.register_shape_infer(self->output(0),
                                 ShapeInferDesc::make_identity(self->input(2)));
    } else {
        // no explicit shape var: deduce output layout from filter/diff shapes
        // through the megdnn operator
        auto infer_shp = [self](TensorShape& dest, const InpVal& inp) {
            TensorLayout ol{self->output(0)->dtype()};
            static_cast<MgbOpr*>(self)->megdnn_opr()->deduce_layout(
                    {inp.val.at(0).shape(), self->input(0)->dtype()},
                    {inp.val.at(1).shape(), self->input(1)->dtype()}, ol);
            dest = ol;
            return true;
        };
        mgr.register_shape_infer(self->output(0),
                                 {SourceType::DEP, inp_deps, infer_shp});
    }

    // workspace size: expressed as a 1-dim "shape" whose single element is
    // the byte count returned by algorithm selection
    auto infer_wk = [self](TensorShape& dest, const InpVal& inp) {
        auto&& iv = inp.val;
        dest.ndim = 1;
        dest.shape[0] = AlgoChooser<MegDNNOpr>::setup_algo(
                {TensorLayout{iv[0].shape(), self->input(0)->dtype(),
                              self->input(0)->format()},
                 {iv[1].shape(), self->input(1)->dtype(),
                  self->input(1)->format()},
                 {iv.at(2).shape(), self->output(0)->dtype(),
                  self->output(0)->format()}},
                static_cast<MgbOpr*>(self)->megdnn_opr(),
                static_cast<MgbOpr*>(self));
        return true;
    };
    // workspace also depends on the inferred output shape (iv.at(2) above)
    inp_deps.push_back({self->output(0), DepType::SHAPE});
    // and on the graph-wide workspace limit var, if the graph has one
    auto workspace_dep_var =
            intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph());
    if (workspace_dep_var) {
        inp_deps.push_back({workspace_dep_var, DepType::VALUE});
    }
    mgr.register_shape_infer(self->output(1),
                             {SourceType::DEP, inp_deps, infer_wk});
}
  83. #define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls)
//! Exec dependency that keeps a preprocessed filter descriptor and the
//! device tensors backing it alive while the compiled executable may still
//! reference them.
class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final
        : public cg::GraphExecutable::ExecDependency {
    std::unique_ptr<PreprocessedFilter> m_pf;
    SmallVector<DeviceTensorND> m_filter_storage;

public:
    //! takes ownership of both the filter descriptor and its storage
    explicit PreprocessedFilterExecDep(
            std::unique_ptr<PreprocessedFilter> preprocessed_filter,
            SmallVector<DeviceTensorND> filter_storage)
            : m_pf(std::move(preprocessed_filter)),
              m_filter_storage(std::move(filter_storage)) {}
};
/*!
 * \brief lazily (re)build the preprocessed filter for \p opr.
 *
 * On the first call this allocates storage matching the layouts deduced by
 * deduce_preprocessed_filter_layout() and runs scn_do_execute_preprocess();
 * on later calls it only asserts that the layouts have not changed.
 * Returns early (doing nothing) when weight preprocess is not allowed or the
 * deduced layouts indicate no preprocessing is needed.
 */
void mixin::WeightPreprocessExecutor::mixin_update_preprocessed_filter(
        cg::OperatorNodeBase& opr) {
    if (!mixin_allow_weight_preprocess(opr)) {
        return;
    }
    auto new_layout = deduce_preprocessed_filter_layout();
    size_t new_size = new_layout.size();
    //! No preprocess layout means no need weight preprocess
    if (new_layout.empty()) {
        return;
    }
    //! all layouts are empty means no need weight preprocess
    bool layout_valid = false;
    for (auto&& layout : new_layout) {
        if (!layout.is_empty()) {
            layout_valid = true;
        }
    }
    if (!layout_valid) {
        return;
    }
    // already built once: layouts must match, since the stored tensors were
    // sized for the old layouts
    if (m_preprocessed_filter) {
        for (size_t i = 0; i < new_size; i++) {
            mgb_assert(new_layout[i].eq_layout(
                               m_preprocessed_filter->tensors[i].layout),
                       "weight preprocess layout changed, please keep input "
                       "shape unchanged when weight preprocess is enabled");
        }
        return;
    }
    // first build: allocate one device tensor per deduced layout and expose
    // them to megdnn through m_preprocessed_filter
    m_preprocessed_filter.reset(new PreprocessedFilter{});
    m_preprocessed_filter->tensors.resize(new_size);
    m_filter_storage.resize(new_size);
    m_preprocessed_filter->algorithm_id = nullptr;
    for (size_t i = 0; i < new_size; i++) {
        m_filter_storage[i] = {opr.output(0)->comp_node(), new_layout[i],
                               new_layout[i].dtype, new_layout[i].format};
        m_preprocessed_filter->tensors[i] = m_filter_storage[i].as_megdnn();
    }
    scn_do_execute_preprocess();
}
  136. void mixin::WeightPreprocessExecutor::record_preprocessed_weight(
  137. cg::GraphExecutable::ExecDependencyArray& deps) {
  138. deps.emplace_back(new PreprocessedFilterExecDep{
  139. std::move(m_preprocessed_filter), std::move(m_filter_storage)});
  140. }
/*!
 * \brief whether weight preprocessing may be applied for \p opr.
 *
 * Requires the graph option graph_opt.weight_preprocess to be on, input(1)
 * (the filter) to be a persistent device value, and the filter to be
 * effectively constant (const var value, a device-tensor holder opr, or a
 * SharedDeviceTensor[WithFormat] marked const).
 */
bool mixin::WeightPreprocessExecutor::mixin_allow_weight_preprocess(
        const cg::OperatorNodeBase& opr) const {
    if (!opr.owner_graph()->options().graph_opt.weight_preprocess) {
        return false;
    }
    if (!opr.input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE))
        return false;
    if (cg::is_const_var_value(opr.input(1)))
        return true;
    auto* input_opr = opr.input(1)->owner_opr();
    if (input_opr->same_type<opr::MultipleDeviceTensorHolder>() ||
        input_opr->same_type<opr::MultipleDeviceTensorWithFormatHolder>())
        return true;
    auto* sdt = input_opr->try_cast_final<opr::SharedDeviceTensor>();
    if (sdt && sdt->const_value())
        return true;
    auto* sdtf = input_opr->try_cast_final<opr::SharedDeviceTensorWithFormat>();
    if (sdtf && sdtf->const_value())
        return true;
    return false;
}
  162. /* ==================== ConvolutionForward ==================== */
  163. IMPL_CONV(ConvolutionForward);
//! \param src convolution input var
//! \param filter convolution kernel var
ConvolutionForward::ConvolutionForward(VarNode* src, VarNode* filter,
                                       const Param& param,
                                       const ExecutionPolicy& policy,
                                       const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "conv", {src, filter}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter});
}
  173. SymbolVar ConvolutionForward::make(SymbolVar src, SymbolVar filter,
  174. const Param& param,
  175. const ExecutionPolicy& policy,
  176. const OperatorNodeConfig& config) {
  177. return src.insert_single_output_opr<ConvolutionForward>(
  178. src.node(), filter.node(), param, policy, config);
  179. }
//! deduce output dtype from the input dtypes, honoring an explicitly
//! configured output dtype if one is set in the node config
void ConvolutionForward::init_output_dtype() {
    DType output_dtype = config().output_dtype();
    megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
                               output_dtype);
    output(0)->dtype(output_dtype);
}
#if MGB_ENABLE_GRAD
//! grad of ConvolutionForward: backward-data w.r.t. input 0, backward-filter
//! w.r.t. input 1 (`opr`, `wrt_idx`, `out_grad` are provided by the macro)
MGB_IMPL_OPR_GRAD(ConvolutionForward) {
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    mgb_assert(wrt_idx == 0 || wrt_idx == 1);
    mgb_assert(out_grad.size() == 2);
    if (wrt_idx == 0) {
        // data
        SymbolVar grad = ConvolutionBackwardData::make(
                opr.input(1), out_grad[0], opr.input(0), opr.param(),
                opr.execution_policy());
        return grad.node();
    } else {
        // filter
        SymbolVar grad = ConvolutionBackwardFilter::make(
                opr.input(0), out_grad[0], opr.input(1), opr.param(),
                opr.execution_policy());
        return grad.node();
    }
}
#endif
//! run algorithm selection for the given shapes and return the workspace
//! size in bytes the chosen algorithm needs
size_t ConvolutionForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::ConvolutionForward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this, allow_weight_preprocess());
}
//! the output tensor format follows the src input (input 0)
void ConvolutionForward::init_output_format() {
    mgb_assert(output().size() == 2);
    output(0)->format(input(0)->format());
}
//! execute the convolution; the preprocessed filter is refreshed first so
//! megdnn can use it when available (output().back() is the workspace var)
void ConvolutionForward::scn_do_execute() {
    update_preprocessed_filter();
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       preprocessed_filter(),
                       intl::get_megdnn_workspace_from_var(output().back()));
}
//! require all inputs to be contiguous
void ConvolutionForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
//! the last output is the workspace: let Super manage all but the workspace
//! var, then register workspace size inference separately
void ConvolutionForward::init_output_static_infer_desc() {
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<
                    megdnn::ConvolutionForward>::val);
}
//! deduce the output shape from input shapes via the megdnn operator
void ConvolutionForward::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    TensorLayout input_layout{inp_shape[0], input(0)->dtype(),
                              input(0)->format()};
    TensorLayout filter_layout{inp_shape[1], input(1)->dtype(),
                               input(1)->format()};
    TensorLayout dst_layout{output(0)->dtype(), output(0)->format()};
    megdnn_opr()->deduce_layout(input_layout, filter_layout, dst_layout);
    out_shape[0] = dst_layout;
}
//! keep the megdnn opr and the preprocessed weight alive with the executable
void ConvolutionForward::record_execute_deps(
        cg::GraphExecutable::ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
    record_preprocessed_weight(deps);
}
//! ask megdnn for the layouts of the preprocessed filter tensors; an empty
//! result means no preprocessing is needed (see mixin_update_preprocessed_filter)
SmallVector<TensorLayout>
ConvolutionForward::deduce_preprocessed_filter_layout() {
    return megdnn_opr()->deduce_preprocessed_filter_layout(
            input(0)->layout(), input(1)->layout(), output(0)->layout());
}
//! run megdnn weight preprocessing, filling the tensors prepared in
//! mixin_update_preprocessed_filter
void ConvolutionForward::scn_do_execute_preprocess() {
    megdnn_opr()->exec_preprocess(
            input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
            output(0)->layout(), preprocessed_filter(),
            intl::get_megdnn_workspace_from_var(output().back()));
    //! Flag the input(1) no use later, which can be freed when no other
    //! var depend on its dev_value, host_value and shape.
    auto receiver_info =
            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
    if (receiver_info.dev_value == 1 && receiver_info.host_value == 0 &&
        receiver_info.shape == 0) {
        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
    }
}
  274. /* ==================== ConvolutionBackwardData ==================== */
  275. IMPL_CONV(ConvolutionBackwardData);
//! \param src_for_shp optional var whose shape determines the output (grad)
//!     shape; may be null, in which case the shape is deduced from
//!     filter/diff (see init_output_static_infer_desc)
ConvolutionBackwardData::ConvolutionBackwardData(
        VarNode* filter, VarNode* diff, VarNode* src_for_shp,
        const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super{filter->owner_graph(),
                config,
                "conv_bwd_data",
                {filter, diff}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({filter, diff});
    // optional third input contributes only its shape
    if (src_for_shp) {
        add_input({src_for_shp});
    }
}
//! make with an explicit src var providing the output shape
SymbolVar ConvolutionBackwardData::make(SymbolVar filter, SymbolVar diff,
                                        SymbolVar src, const Param& param,
                                        const ExecutionPolicy& policy,
                                        const OperatorNodeConfig& config) {
    return filter.insert_single_output_opr<ConvolutionBackwardData>(
            filter.node(), diff.node(), src.node(), param, policy, config);
}
  298. SymbolVar ConvolutionBackwardData::make(SymbolVar filter, SymbolVar data,
  299. const Param& param,
  300. const ExecutionPolicy& policy,
  301. const OperatorNodeConfig& config) {
  302. return make(filter, data, {}, param, policy, config);
  303. }
//! require all inputs to be contiguous
void ConvolutionBackwardData::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
//! delegate to the shared bwd-data mixin (handles grad shape + workspace)
void ConvolutionBackwardData::init_output_static_infer_desc() {
    init_output_static_infer_desc_for_bwd_data<ConvolutionBackwardData,
                                               megdnn::ConvolutionBackwardData>(
            this);
}
//! deduce grad dtype from filter/diff dtypes, honoring a configured override
void ConvolutionBackwardData::init_output_dtype() {
    DType output_dtype = config().output_dtype();
    megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(),
                               output_dtype);
    output(0)->dtype(output_dtype);
}
//! the grad output follows the format of the diff input (input 1), unlike
//! the forward opr which follows input 0
void ConvolutionBackwardData::init_output_format() {
    mgb_assert(output().size() == 2);
    output(0)->format(input(1)->format());
}
//! when a src-for-shape input exists, it is only a SHAPE dependency: its
//! device value is never read
cg::OperatorNodeBase::NodeProp* ConvolutionBackwardData::do_make_node_prop()
        const {
    auto prop = Super::Super::do_make_node_prop();
    if (input().size() == 3) {
        using D = NodeProp::DepType;
        prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
    }
    return prop;
}
//! run megdnn backward-data: (filter, diff) -> grad, with workspace output(1)
void ConvolutionBackwardData::scn_do_execute() {
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output(1)));
}
#if MGB_ENABLE_GRAD
//! grad of backward-data: bwd-filter w.r.t. the filter input, a forward
//! convolution w.r.t. the diff input; no grad flows through the shape input
MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) {
    mgb_assert(!out_grad[1]);
    if (wrt_idx == 0) {
        return ConvolutionBackwardFilter::make(out_grad[0], opr.input(1),
                                               opr.input(0), opr.param(),
                                               opr.execution_policy())
                .node();
    }
    if (wrt_idx == 1) {
        return Convolution::make(out_grad[0], opr.input(0), opr.param(),
                                 opr.execution_policy())
                .node();
    }
    return nullptr;
}
#endif
  354. /* ==================== ConvolutionBackwardFilter ==================== */
  355. IMPL_CONV(ConvolutionBackwardFilter);
//! \param filter var whose shape gives the filter-grad shape; its value is
//!     not used for computation
ConvolutionBackwardFilter::ConvolutionBackwardFilter(
        VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
        const ExecutionPolicy& policy, const OperatorNodeConfig& config)
        : Super({src->owner_graph(),
                 config,
                 "conv_bwd_filter",
                 {src, diff, filter}},
                2, false) {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, diff, filter});
}
//! create a ConvolutionBackwardFilter operator and return its output var
SymbolVar ConvolutionBackwardFilter::make(SymbolVar src, SymbolVar diff,
                                          SymbolVar filter, const Param& param,
                                          const ExecutionPolicy& policy,
                                          const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<ConvolutionBackwardFilter>(
            src.node(), diff.node(), filter.node(), param, policy, config);
}
//! run algorithm selection for the given shapes and return the required
//! workspace size in bytes
size_t ConvolutionBackwardFilter::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::ConvolutionBackwardFilter>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}
#if MGB_ENABLE_GRAD
//! grad of backward-filter: bwd-data w.r.t. src, forward conv w.r.t. diff;
//! no grad flows through the shape-only filter input
MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) {
    mgb_assert(!out_grad[1]);
    if (wrt_idx == 0) {
        return ConvolutionBackwardData::make(out_grad[0], opr.input(1),
                                             opr.input(0), opr.param(),
                                             opr.execution_policy())
                .node();
    }
    if (wrt_idx == 1) {
        return Convolution::make(opr.input(0), out_grad[0], opr.param(),
                                 opr.execution_policy())
                .node();
    }
    return nullptr;
}
#endif
  403. /* ==================== Convolution3DForward ==================== */
  404. IMPL_CONV(Convolution3DForward);
//! 3D convolution forward: same construction pattern as ConvolutionForward
Convolution3DForward::Convolution3DForward(VarNode* src, VarNode* filter,
                                           const Param& param,
                                           const ExecutionPolicy& policy,
                                           const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "conv3d", {src, filter}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter});
}
//! create a Convolution3DForward operator and return its output var
SymbolVar Convolution3DForward::make(SymbolVar src, SymbolVar filter,
                                     const Param& param,
                                     const ExecutionPolicy& policy,
                                     const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<Convolution3DForward>(
            src.node(), filter.node(), param, policy, config);
}
//! output dtype follows input dtype; for FLOAT_IO16xC32 the input must be
//! float16 (computation dtype handling is internal to megdnn)
void Convolution3DForward::init_output_dtype() {
    switch (param().data_type) {
        case Param::DataType::FLOAT:
            output(0)->dtype(input(0)->dtype());
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case Param::DataType::FLOAT_IO16xC32:
            mgb_assert(input(0)->dtype() == dtype::Float16(),
                       "invalid input dtype %s", input(0)->name().c_str());
            output(0)->dtype(input(0)->dtype());
            break;
#endif
        default:
            mgb_throw(MegBrainError, "bad data_type enum");
    }
}
#if MGB_ENABLE_GRAD
//! grad of Convolution3DForward: bwd-data for input 0, bwd-filter for
//! input 1; only plain FLOAT data type supports grad
MGB_IMPL_OPR_GRAD(Convolution3DForward) {
    mgb_assert(opr.param().data_type ==
                       Convolution3DForward::Param::DataType::FLOAT,
               "only float data type supported for grad");
    mgb_assert(wrt_idx == 0 || wrt_idx == 1);
    mgb_assert(out_grad.size() == 2);
    if (wrt_idx == 0) {
        // data
        SymbolVar grad = Convolution3DBackwardData::make(
                opr.input(1), out_grad[0], opr.input(0), opr.param(),
                opr.execution_policy());
        return grad.node();
    } else {
        // filter
        SymbolVar grad = Convolution3DBackwardFilter::make(
                opr.input(0), out_grad[0], opr.input(1), opr.param(),
                opr.execution_policy());
        return grad.node();
    }
}
#endif
//! run algorithm selection for the given shapes and return the required
//! workspace size in bytes
size_t Convolution3DForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::Convolution3DForward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}
  470. /* ==================== Convolution3DBackwardData ==================== */
  471. IMPL_CONV(Convolution3DBackwardData);
//! \param src_for_shp optional var whose shape determines the output (grad)
//!     shape; may be null (shape then deduced from filter/diff)
Convolution3DBackwardData::Convolution3DBackwardData(
        VarNode* filter, VarNode* diff, VarNode* src_for_shp,
        const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super{filter->owner_graph(),
                config,
                "conv3d_bwd_data",
                {filter, diff}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({filter, diff});
    // optional third input contributes only its shape
    if (src_for_shp) {
        add_input({src_for_shp});
    }
}
//! make with an explicit src var providing the output shape
SymbolVar Convolution3DBackwardData::make(SymbolVar filter, SymbolVar diff,
                                          SymbolVar src, const Param& param,
                                          const ExecutionPolicy& policy,
                                          const OperatorNodeConfig& config) {
    return filter.insert_single_output_opr<Convolution3DBackwardData>(
            filter.node(), diff.node(), src.node(), param, policy, config);
}
  494. SymbolVar Convolution3DBackwardData::make(SymbolVar filter, SymbolVar data,
  495. const Param& param,
  496. const ExecutionPolicy& policy,
  497. const OperatorNodeConfig& config) {
  498. return make(filter, data, {}, param, policy, config);
  499. }
//! require all inputs to be contiguous
void Convolution3DBackwardData::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}
//! delegate to the shared bwd-data mixin (handles grad shape + workspace)
void Convolution3DBackwardData::init_output_static_infer_desc() {
    init_output_static_infer_desc_for_bwd_data<
            Convolution3DBackwardData, megdnn::Convolution3DBackwardData>(this);
}
//! when a src-for-shape input exists, it is only a SHAPE dependency: its
//! device value is never read
cg::OperatorNodeBase::NodeProp* Convolution3DBackwardData::do_make_node_prop()
        const {
    auto prop = Super::Super::do_make_node_prop();
    if (input().size() == 3) {
        using D = NodeProp::DepType;
        prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
    }
    return prop;
}
//! run megdnn backward-data: (filter, diff) -> grad, with workspace output(1)
void Convolution3DBackwardData::scn_do_execute() {
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output(1)));
}
#if MGB_ENABLE_GRAD
//! grad of 3D backward-data: bwd-filter w.r.t. filter, forward conv3d
//! w.r.t. diff; no grad flows through the shape input
MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) {
    mgb_assert(!out_grad[1]);
    if (wrt_idx == 0) {
        return Convolution3DBackwardFilter::make(out_grad[0], opr.input(1),
                                                 opr.input(0), opr.param(),
                                                 opr.execution_policy())
                .node();
    }
    if (wrt_idx == 1) {
        return Convolution3D::make(out_grad[0], opr.input(0), opr.param(),
                                   opr.execution_policy())
                .node();
    }
    return nullptr;
}
#endif
  539. /* ==================== Convolution3DBackwardFilter ==================== */
  540. IMPL_CONV(Convolution3DBackwardFilter);
//! \param filter var whose shape gives the filter-grad shape; its value is
//!     not used for computation
Convolution3DBackwardFilter::Convolution3DBackwardFilter(
        VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
        const ExecutionPolicy& policy, const OperatorNodeConfig& config)
        : Super({src->owner_graph(),
                 config,
                 "conv3d_bwd_filter",
                 {src, diff, filter}},
                2, false) {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, diff, filter});
}
//! create a Convolution3DBackwardFilter operator and return its output var
SymbolVar Convolution3DBackwardFilter::make(SymbolVar src, SymbolVar diff,
                                            SymbolVar filter,
                                            const Param& param,
                                            const ExecutionPolicy& policy,
                                            const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<Convolution3DBackwardFilter>(
            src.node(), diff.node(), filter.node(), param, policy, config);
}
//! run algorithm selection for the given shapes and return the required
//! workspace size in bytes
size_t Convolution3DBackwardFilter::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::Convolution3DBackwardFilter>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}
  572. /* ========================== MaskConvolution ========================== */
  573. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskConvolution);
//! convolution with an additional integer mask input (input 2)
MaskConvolution::MaskConvolution(VarNode* src, VarNode* filter, VarNode* mask,
                                 const Param& param,
                                 const OperatorNodeConfig& config)
        : Super(src->owner_graph(), config, "mask_conv_fwd",
                {src, filter, mask}) {
    init_megdnn_opr(*this, param);
    add_input({src, filter, mask});
}
//! create a MaskConvolution operator and return its output var
SymbolVar MaskConvolution::make(SymbolVar src, SymbolVar filter, SymbolVar mask,
                                const Param& param,
                                const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<MaskConvolution>(
            src.node(), filter.node(), mask.node(), param, config);
}
//! the mask (input 2) must be an integer dtype; the output dtype follows src
void MaskConvolution::init_output_dtype() {
    auto dtype = input(2)->dtype();
    mgb_assert(dtype == dtype::Int32() || dtype == dtype::Int16() ||
                       dtype == dtype::Int8(),
               "dtype must be int8, int16 or int32, while get %s",
               dtype.name());
    output(0)->dtype(input(0)->dtype());
}
  596. MGB_DYN_TYPE_OBJ_FINAL_IMPL(MaskPropagate);
//! propagate a mask tensor through a convolution-like geometry
MaskPropagate::MaskPropagate(VarNode* src, const Param& param,
                             const OperatorNodeConfig& config)
        : Super(src->owner_graph(), config, "mask_propagate", {src}) {
    init_megdnn_opr(*this, param);
    add_input({src});
}
//! input must be an integer dtype; the output keeps the same dtype
void MaskPropagate::init_output_dtype() {
    auto dtype = input(0)->dtype();
    mgb_assert(dtype == dtype::Int32() || dtype == dtype::Int16() ||
               dtype == dtype::Int8());
    output(0)->dtype(dtype);
}
  609. SymbolVar MaskPropagate::make(SymbolVar src, const Param& param,
  610. const OperatorNodeConfig& config) {
  611. return src.insert_single_output_opr<MaskPropagate>(src.node(), param,
  612. config);
  613. }
  614. /* ==================== ConvBiasForward ==================== */
  615. IMPL_CONV(ConvBiasForward);
//! conv-bias without bias/z inputs (plain fused convolution)
ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter,
                                 const Param& param,
                                 const ExecutionPolicy& policy,
                                 const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "conv_bias", {src, filter}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter});
}
//! conv-bias with a bias input (input 2)
ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter, VarNode* bias,
                                 const Param& param,
                                 const ExecutionPolicy& policy,
                                 const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "conv_bias", {src, filter, bias}} {
    m_policy = policy;
    init_megdnn_opr(*this, param);
    add_input({src, filter, bias});
}
//! conv-bias with both bias (input 2) and residual-add z (input 3) inputs
ConvBiasForward::ConvBiasForward(VarNode* src, VarNode* filter, VarNode* bias,
                                 VarNode* z, const Param& param,
                                 const ExecutionPolicy& policy,
                                 const OperatorNodeConfig& config)
        : Super{src->owner_graph(),
                config,
                "conv_bias",
                {src, filter, bias, z}} {
    m_policy = policy;
    init_megdnn_opr(*this, param);
    add_input({src, filter, bias, z});
}
  646. void ConvBiasForward::add_input_layout_constraint() {
  647. mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
  648. }
  649. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter,
  650. const Param& param,
  651. const ExecutionPolicy& policy,
  652. const OperatorNodeConfig& config) {
  653. return src.insert_single_output_opr<ConvBiasForward>(
  654. src.node(), filter.node(), param, policy, config);
  655. }
  656. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter, SymbolVar bias,
  657. const Param& param,
  658. const ExecutionPolicy& policy,
  659. const OperatorNodeConfig& config) {
  660. return src.insert_single_output_opr<ConvBiasForward>(
  661. src.node(), filter.node(), bias.node(), param, policy, config);
  662. }
  663. SymbolVar ConvBiasForward::make(SymbolVar src, SymbolVar filter, SymbolVar bias,
  664. SymbolVar z, const Param& param,
  665. const ExecutionPolicy& policy,
  666. const OperatorNodeConfig& config) {
  667. return src.insert_single_output_opr<ConvBiasForward>(
  668. src.node(), filter.node(), bias.node(), z.node(), param, policy,
  669. config);
  670. }
  671. void ConvBiasForward::init_output_dtype() {
  672. DType output_dtype = config().output_dtype();
  673. DType i0, i1, i2, i3;
  674. mgb_assert(input().size() >= 2 && input().size() <= 4);
  675. i0 = input(0)->dtype();
  676. i1 = input(1)->dtype();
  677. if (input().size() >= 3)
  678. i2 = input(2)->dtype();
  679. if (input().size() == 4)
  680. i3 = input(3)->dtype();
  681. megdnn_opr()->deduce_dtype(i0, i1, i2, i3, output_dtype);
  682. output(0)->dtype(output_dtype);
  683. }
// Query the workspace size for the given shapes. Absent bias/z inputs are
// modeled as empty layouts: the bias dtype is deduced by megdnn, and the
// z placeholder takes the output dtype/format.
size_t ConvBiasForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    auto mo = megdnn_opr();
    TensorLayout i0, i1, i2, i3;
    mgb_assert(input_shapes.size() >= 2 && input_shapes.size() <= 4);
    i0 = {input_shapes[0], input(0)->dtype(), input(0)->format()};
    i1 = {input_shapes[1], input(1)->dtype(), input(1)->format()};
    if (input_shapes.size() >= 3)
        i2 = {input_shapes[2], input(2)->dtype(), input(2)->format()};
    else {
        // no bias input: ask megdnn which bias dtype it would expect
        DType dtype;
        mo->deduce_dtype(input(0)->dtype(), input(1)->dtype(), DType{}, DType{},
                         dtype);
        i2 = {{}, dtype};
    }
    if (input_shapes.size() == 4)
        i3 = {input_shapes[3], input(3)->dtype(), input(3)->format()};
    else
        i3 = {{}, output(0)->dtype(), output(0)->format()};
    // setup_algo selects the algorithm under m_policy and returns the
    // workspace it needs; weight preprocessing is taken into account
    return AlgoChooser<megdnn::ConvBias>::setup_algo(
            {i0,
             i1,
             i2,
             i3,
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            mo, this, allow_weight_preprocess());
}
// Execute the conv_bias kernel. Missing bias/z inputs are passed to megdnn
// as null TensorNDs carrying ndim-0 layouts; the workspace is the last
// output var.
void ConvBiasForward::scn_do_execute() {
    update_preprocessed_filter();
    auto&& inp = input();
    auto mo = megdnn_opr();
    if (inp.size() == 2) {
        // no bias: for a QuantizedS8 output the placeholder bias dtype is
        // QuantizedS32 with the same scale, otherwise the output dtype
        TensorLayout bias_layout;
        bias_layout.ndim = 0;
        if (output(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
            bias_layout.dtype = dtype::QuantizedS32(
                    output(0)->dtype().param<dtype::QuantizedS8>().scale);
        } else {
            bias_layout.dtype = output(0)->dtype();
        }
        TensorLayout z_layout;
        z_layout.ndim = 0;
        z_layout.dtype = output(0)->dtype();
        megdnn::TensorND bias_tensor{nullptr, bias_layout};
        megdnn::TensorND z_tensor{nullptr, z_layout};
        mo->exec(inp[0]->dev_tensor().as_megdnn(),
                 inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
                 output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
                 intl::get_megdnn_workspace_from_var(output().back()));
    } else if (inp.size() == 3) {
        // bias present, z absent
        TensorLayout z_layout;
        z_layout.ndim = 0;
        z_layout.dtype = output(0)->dtype();
        megdnn::TensorND z_tensor{nullptr, z_layout};
        mo->exec(inp[0]->dev_tensor().as_megdnn(),
                 inp[1]->dev_tensor().as_megdnn(),
                 inp[2]->dev_tensor().as_megdnn(), z_tensor,
                 output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
                 intl::get_megdnn_workspace_from_var(output().back()));
    } else {
        // src, filter, bias and z all present
        mgb_assert(inp.size() == 4);
        mo->exec(inp[0]->dev_tensor().as_megdnn(),
                 inp[1]->dev_tensor().as_megdnn(),
                 inp[2]->dev_tensor().as_megdnn(),
                 inp[3]->dev_tensor().as_megdnn(),
                 output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
                 intl::get_megdnn_workspace_from_var(output().back()));
    }
}
  754. void ConvBiasForward::get_output_var_shape(const TensorShapeArray& inp_shape,
  755. TensorShapeArray& out_shape) const {
  756. auto mo = megdnn_opr();
  757. TensorLayout dst;
  758. mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
  759. {inp_shape[1], input(1)->dtype(), input(0)->format()}, {},
  760. {}, dst);
  761. out_shape[0] = dst;
  762. }
void ConvBiasForward::init_output_static_infer_desc() {
    // the last output is the workspace var; it is handled separately below
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<
                    megdnn::ConvBiasForward>::val);
}

void ConvBiasForward::init_output_format() {
    // two outputs (value + workspace); the value keeps the src format
    mgb_assert(output().size() == 2);
    output(0)->format(input(0)->format());
}
// Validate a winograd param against the compute dtype: channel_block_size
// must be 1, 4 or 8; non-fp32 dtypes are additionally restricted to
// fp16 (when MEGDNN_FLOAT16_SELECT enables it), qint8 or quint8-asymm.
void ConvBiasForward::check_winograd_param_valid(
        const megdnn::ConvBias::WinogradParam& param, const DType& dtype) {
    if (dtype.enumv() == DTypeEnum::Float32) {
        mgb_assert(param.channel_block_size == 1 ||
                   param.channel_block_size == 4 ||
                   param.channel_block_size == 8,
                   "only support 1/4/8 for the channel_block_size of "
                   "winograd param, got %u",
                   param.channel_block_size);
    } else {
        mgb_assert((MEGDNN_FLOAT16_SELECT(dtype.enumv() == DTypeEnum::Float16,
                                          false) ||
                    dtype.enumv() == DTypeEnum::QuantizedS8 ||
                    dtype.enumv() == DTypeEnum::Quantized8Asymm) &&
                   (param.channel_block_size == 1 ||
                    param.channel_block_size == 4 ||
                    param.channel_block_size == 8),
                   "only support 1/4/8 for the channel_block_size of "
                   "winograd param, got %u",
                   param.channel_block_size);
    }
}
  796. megdnn::param::MatrixMul::Format ConvBiasForward::get_matmul_format(
  797. const megdnn::ConvBias::WinogradParam& param) {
  798. switch (param.channel_block_size) {
  799. case 1:
  800. return megdnn::param::MatrixMul::Format::DEFAULT;
  801. break;
  802. case 4:
  803. return megdnn::param::MatrixMul::Format::MK4;
  804. break;
  805. case 8:
  806. return megdnn::param::MatrixMul::Format::MK8;
  807. break;
  808. default:
  809. mgb_throw(InternalError,
  810. "Only Support 1/4/8 for "
  811. "channel_block_size, got: %u",
  812. param.channel_block_size);
  813. }
  814. }
  815. SmallVector<TensorLayout> ConvBiasForward::deduce_preprocessed_filter_layout() {
  816. TensorLayout i2, i3;
  817. if (input().size() > 2) {
  818. i2 = input(2)->layout();
  819. }
  820. if (input().size() > 3) {
  821. i3 = input(3)->layout();
  822. }
  823. return megdnn_opr()->deduce_preprocessed_filter_layout(
  824. input(0)->layout(), input(1)->layout(), i2, i3,
  825. output(0)->layout());
  826. }
// Run filter (and possibly bias) preprocessing, then flag the weight/bias
// vars as MEMORY_NO_NEED when nothing else in the computing sequence
// consumes their values.
void ConvBiasForward::scn_do_execute_preprocess() {
    TensorLayout bias_layout(output(0)->dtype()), z_layout(output(0)->dtype());
    if (input().size() > 2) {
        bias_layout = input(2)->layout();
    }
    if (input().size() > 3) {
        z_layout = input(3)->layout();
    }
    if (input().size() > 2) {
        megdnn_opr()->exec_preprocess(
                input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
                input(2)->dev_tensor().as_megdnn(), z_layout,
                output(0)->layout(), preprocessed_filter(),
                intl::get_megdnn_workspace_from_var(output().back()));
    } else {
        // no bias input: pass a null bias tensor with the placeholder layout
        megdnn::TensorND bias_tensor{nullptr, bias_layout};
        megdnn_opr()->exec_preprocess(
                input(0)->layout(), input(1)->dev_tensor().as_megdnn(),
                bias_tensor, z_layout, output(0)->layout(),
                preprocessed_filter(),
                intl::get_megdnn_workspace_from_var(output().back()));
    }
    //! Flag the weight and bias no use later, which can be freed when no other
    //! var depend on its dev_value, host_value and shape.
    auto receiver_info_weight =
            input(1)->owner_graph()->var_receiver_in_current_comp_seq(input(1));
    if (receiver_info_weight.dev_value == 1 &&
        receiver_info_weight.host_value == 0 &&
        receiver_info_weight.shape == 0) {
        input(1)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
    }
    //! if bias is preprocessed
    if (input().size() > 2) {
        auto preprocessed_layouts =
                megdnn_opr()->deduce_preprocessed_filter_layout(
                        input(0)->layout(), input(1)->layout(), bias_layout,
                        z_layout, output(0)->layout());
        // a non-empty layout in slot 1 indicates the bias was folded into the
        // preprocessed data, so the original bias var may also be released
        if (preprocessed_layouts.size() > 1 &&
            !preprocessed_layouts[1].is_empty()) {
            auto receiver_info_bias =
                    input(2)->owner_graph()->var_receiver_in_current_comp_seq(
                            input(2));
            if (receiver_info_bias.dev_value == 1 &&
                receiver_info_bias.host_value == 0 &&
                receiver_info_bias.shape == 0) {
                input(2)->add_flag(VarNode::Flag::MEMORY_NO_NEED);
            }
        }
    }
}
/* ===================== LocalShareForward ==================== */
IMPL_CONV(LocalShareForward);

LocalShareForward::LocalShareForward(VarNode* src, VarNode* filter,
                                     const Param& param,
                                     const ExecutionPolicy& policy,
                                     const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "local_share", {src, filter}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter});
}

SymbolVar LocalShareForward::make(SymbolVar src, SymbolVar filter,
                                  const Param& param,
                                  const ExecutionPolicy& policy,
                                  const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<LocalShareForward>(
            src.node(), filter.node(), param, policy, config);
}

void LocalShareForward::init_output_dtype() {
    // only float32 output supported; a configured output dtype must agree
    DType output_dtype = config().output_dtype();
    mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
    output_dtype = dtype::Float32();
    output(0)->dtype(output_dtype);
}

void LocalShareForward::init_output_format() {
    // two outputs (value + workspace); the value keeps the src format
    mgb_assert(output().size() == 2);
    output(0)->format(input(0)->format());
}

size_t LocalShareForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::LocalShareForward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareForward) {
    mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
               "only float data type supported for grad");
    mgb_assert(wrt_idx == 0 || wrt_idx == 1);
    mgb_assert(out_grad.size() == 2);
    if (wrt_idx == 0) {
        // grad w.r.t. data
        SymbolVar grad = LocalShareBackwardData::make(opr.input(1), out_grad[0],
                                                      opr.input(0), opr.param(),
                                                      opr.execution_policy());
        return grad.node();
    } else {
        // grad w.r.t. filter
        SymbolVar grad = LocalShareBackwardFilter::make(
                opr.input(0), out_grad[0], opr.input(1), opr.param(),
                opr.execution_policy());
        return grad.node();
    }
}
#endif
/* ===================== LocalShareBackwardData ==================== */
IMPL_CONV(LocalShareBackwardData);

//! \p src_for_shp is optional; when given it becomes input 2 and is used
//! only for its shape (see do_make_node_prop)
LocalShareBackwardData::LocalShareBackwardData(VarNode* filter, VarNode* diff,
                                               VarNode* src_for_shp,
                                               const Param& param,
                                               const ExecutionPolicy& policy,
                                               const OperatorNodeConfig& config)
        : Super{filter->owner_graph(),
                config,
                "local_share_bwd_data",
                {filter, diff}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({filter, diff});
    if (src_for_shp) {
        add_input({src_for_shp});
    }
}

SymbolVar LocalShareBackwardData::make(SymbolVar filter, SymbolVar diff,
                                       SymbolVar src, const Param& param,
                                       const ExecutionPolicy& policy,
                                       const OperatorNodeConfig& config) {
    return filter.insert_single_output_opr<LocalShareBackwardData>(
            filter.node(), diff.node(), src.node(), param, policy, config);
}

void LocalShareBackwardData::init_output_static_infer_desc() {
    init_output_static_infer_desc_for_bwd_data<LocalShareBackwardData,
                                               megdnn::LocalShareBackwardData>(
            this);
}

void LocalShareBackwardData::init_output_dtype() {
    // only float32 output supported; a configured output dtype must agree
    DType output_dtype = config().output_dtype();
    mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
    output_dtype = dtype::Float32();
    output(0)->dtype(output_dtype);
}

void LocalShareBackwardData::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

cg::OperatorNodeBase::NodeProp* LocalShareBackwardData::do_make_node_prop()
        const {
    auto prop = Super::Super::do_make_node_prop();
    mgb_assert(input().size() == 3);
    using D = NodeProp::DepType;
    // input 2 (src) contributes only its shape, not its device value
    prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::SHAPE});
    return prop;
}

void LocalShareBackwardData::scn_do_execute() {
    // inputs: filter, diff; output(1) holds the workspace
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
                       input(1)->dev_tensor().as_megdnn(),
                       output(0)->dev_tensor().as_megdnn(),
                       intl::get_megdnn_workspace_from_var(output(1)));
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardData) {
    mgb_assert(!out_grad[1]);
    if (wrt_idx == 0) {
        // grad w.r.t. input 0 (filter)
        return LocalShareBackwardFilter::make(out_grad[0], opr.input(1),
                                              opr.input(0), opr.param(),
                                              opr.execution_policy())
                .node();
    }
    if (wrt_idx == 1) {
        // grad w.r.t. input 1 (diff)
        return LocalShare::make(out_grad[0], opr.input(0), opr.param(),
                                opr.execution_policy())
                .node();
    }
    return nullptr;
}
#endif
/* ==================== LocalShareBackwardFilter ==================== */
IMPL_CONV(LocalShareBackwardFilter);

LocalShareBackwardFilter::LocalShareBackwardFilter(
        VarNode* src, VarNode* diff, VarNode* filter, const Param& param,
        const ExecutionPolicy& policy, const OperatorNodeConfig& config)
        : Super({src->owner_graph(),
                 config,
                 "local_share_bwd_filter",
                 {src, diff, filter}},
                2, false) {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, diff, filter});
}

SymbolVar LocalShareBackwardFilter::make(SymbolVar src, SymbolVar diff,
                                         SymbolVar filter, const Param& param,
                                         const ExecutionPolicy& policy,
                                         const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<LocalShareBackwardFilter>(
            src.node(), diff.node(), filter.node(), param, policy, config);
}

size_t LocalShareBackwardFilter::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 3 && output_shapes.size() == 1);
    // algo query uses src, diff and the filter-grad (output) layouts
    return AlgoChooser<megdnn::LocalShareBackwardFilter>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) {
    mgb_assert(!out_grad[1]);
    if (wrt_idx == 0) {
        // grad w.r.t. input 0 (src)
        return LocalShareBackwardData::make(out_grad[0], opr.input(1),
                                            opr.input(0), opr.param(),
                                            opr.execution_policy())
                .node();
    }
    if (wrt_idx == 1) {
        // grad w.r.t. input 1 (diff)
        return LocalShare::make(opr.input(0), out_grad[0], opr.param(),
                                opr.execution_policy())
                .node();
    }
    return nullptr;
}
#endif
/* ===================== DeformableConvForward ==================== */
IMPL_CONV(DeformableConvForward);

//! all four inputs (src, filter, offset, mask) must be float32
DeformableConvForward::DeformableConvForward(VarNode* src, VarNode* filter,
                                             VarNode* offset, VarNode* mask,
                                             const Param& param,
                                             const ExecutionPolicy& policy,
                                             const OperatorNodeConfig& config)
        : Super{src->owner_graph(),
                config,
                "deformable_conv",
                {src, filter, offset, mask}} {
    mgb_assert(src->dtype() == dtype::Float32() &&
               filter->dtype() == dtype::Float32() &&
               offset->dtype() == dtype::Float32() &&
               mask->dtype() == dtype::Float32(),
               "input should be float32, got %s, %s, %s, %s",
               src->dtype().name(), filter->dtype().name(),
               offset->dtype().name(), mask->dtype().name());
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter, offset, mask});
}

SymbolVar DeformableConvForward::make(SymbolVar src, SymbolVar filter,
                                      SymbolVar offset, SymbolVar mask,
                                      const Param& param,
                                      const ExecutionPolicy& policy,
                                      const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<DeformableConvForward>(
            src.node(), filter.node(), offset.node(), mask.node(), param,
            policy, config);
}

void DeformableConvForward::init_output_dtype() {
    // only float32 output supported; a configured output dtype must agree
    DType output_dtype = config().output_dtype();
    mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
    output_dtype = dtype::Float32();
    output(0)->dtype(output_dtype);
}

void DeformableConvForward::init_output_format() {
    // two outputs (value + workspace); the value keeps the src format
    mgb_assert(output().size() == 2);
    output(0)->format(input(0)->format());
}

size_t DeformableConvForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 4 && output_shapes.size() == 1);
    return AlgoChooser<megdnn::DeformableConvForward>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {input_shapes[2], input(2)->dtype(), input(2)->format()},
             {input_shapes[3], input(3)->dtype(), input(3)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DeformableConvForward) {
    mgb_assert(opr.input(0)->dtype() == dtype::Float32(),
               "only float data type supported for grad");
    mgb_assert(wrt_idx < 4);
    mgb_assert(!out_grad[1]);
    mgb_assert(out_grad.size() == 2);
    // data, offset and mask grads come from one backward-data opr
    auto grad_arr = DeformableConvBackwardData::make_all(
            opr.input(0), opr.input(1), opr.input(2), opr.input(3), out_grad[0],
            opr.param(), opr.execution_policy(), opr.config());
    // filter grad comes from the backward-filter opr
    auto filter_grad = DeformableConvBackwardFilter::make(
            opr.input(0), opr.input(1), opr.input(2), opr.input(3), out_grad[0],
            opr.param(), opr.execution_policy(), opr.config());
    // reorder to match input order: src, filter, offset, mask
    SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]};
    return grads[wrt_idx].node();
}
#endif
/* ==================== DeformableConvBackwardData ==================== */
IMPL_CONV(DeformableConvBackwardData);

//! computes grads w.r.t. src, offset and mask in one pass; all five
//! inputs must be float32
DeformableConvBackwardData::DeformableConvBackwardData(
        VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
        VarNode* diff, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super{filter->owner_graph(),
                config,
                "deformable_conv_backward_data",
                {src, filter, offset, mask, diff}} {
    mgb_assert(src->dtype() == dtype::Float32() and
               filter->dtype() == dtype::Float32() and
               offset->dtype() == dtype::Float32() and
               mask->dtype() == dtype::Float32() and
               diff->dtype() == dtype::Float32(),
               "input should be float32, got %s, %s, %s, %s %s",
               src->dtype().name(), filter->dtype().name(),
               offset->dtype().name(), mask->dtype().name(),
               diff->dtype().name());
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter, offset, mask, diff});
}

//! returns {src_grad, offset_grad, mask_grad}
SymbolVarArray DeformableConvBackwardData::make_all(
        SymbolVar src, SymbolVar filter, SymbolVar offset, SymbolVar mask,
        SymbolVar diff, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config) {
    auto graph = src.node()->owner_graph();
    auto back_node =
            graph->insert_opr(std::make_unique<DeformableConvBackwardData>(
                    src.node(), filter.node(), offset.node(), mask.node(),
                    diff.node(), param, policy, config));
    return {back_node->output(0), back_node->output(1), back_node->output(2)};
}

//! convenience wrapper returning only the src grad
SymbolVar DeformableConvBackwardData::make(SymbolVar src, SymbolVar filter,
                                           SymbolVar offset, SymbolVar mask,
                                           SymbolVar diff, const Param& param,
                                           const ExecutionPolicy& policy,
                                           const OperatorNodeConfig& config) {
    auto&& all =
            make_all(src, filter, offset, mask, diff, param, policy, config);
    return all[0];
}

void DeformableConvBackwardData::scn_do_execute() {
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), // src
                       input(1)->dev_tensor().as_megdnn(), // filter
                       input(2)->dev_tensor().as_megdnn(), // offset
                       input(3)->dev_tensor().as_megdnn(), // mask
                       input(4)->dev_tensor().as_megdnn(), // diff
                       output(0)->dev_tensor().as_megdnn(), // src_grad
                       output(1)->dev_tensor().as_megdnn(), // offset_grad
                       output(2)->dev_tensor().as_megdnn(), // mask_grad
                       intl::get_megdnn_workspace_from_var(output(3)));
}

// each grad output takes the shape of the corresponding input
void DeformableConvBackwardData::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    TensorShape im_shp = inp_shape[0];
    TensorShape offset_shp = inp_shape[2];
    TensorShape mask_shp = inp_shape[3];
    mgb_assert(im_shp.ndim == 4, "invalid src shape: %s",
               im_shp.to_string().c_str());
    mgb_assert(offset_shp.ndim == 4, "invalid offset shape: %s",
               offset_shp.to_string().c_str());
    mgb_assert(mask_shp.ndim == 4, "invalid mask shape: %s",
               mask_shp.to_string().c_str());
    mgb_assert(out_shape.size() == 3);
    out_shape[0] = im_shp;
    out_shape[1] = offset_shp;
    out_shape[2] = mask_shp;
}

size_t DeformableConvBackwardData::get_workspace_size_bytes(
        const TensorShapeArray& inp_shape,
        const TensorShapeArray& out_shape) const {
    size_t ws = AlgoChooser<megdnn::DeformableConvBackwardData>::setup_algo(
            {TensorLayout{inp_shape[0], input(0)->dtype(), input(0)->format()},
             {inp_shape[1], input(1)->dtype(), input(1)->format()},
             {inp_shape[2], input(2)->dtype(), input(2)->format()},
             {inp_shape[3], input(3)->dtype(), input(3)->format()},
             {inp_shape[4], input(4)->dtype(), input(4)->format()},
             {out_shape[0], output(0)->dtype(), output(0)->format()},
             {out_shape[1], output(1)->dtype(), output(1)->format()},
             {out_shape[2], output(2)->dtype(), output(2)->format()}},
            megdnn_opr(), this);
    return ws;
}

void DeformableConvBackwardData::init_output_dtype() {
    // all three grad outputs are float32
    DType output_dtype = config().output_dtype();
    mgb_assert(!output_dtype.valid() || output_dtype == dtype::Float32());
    output_dtype = dtype::Float32();
    output(0)->dtype(output_dtype);
    output(1)->dtype(output_dtype);
    output(2)->dtype(output_dtype);
}

void DeformableConvBackwardData::init_output_format() {
    // outputs: src_grad, offset_grad, mask_grad, workspace; each grad keeps
    // the format of the corresponding input
    mgb_assert(output().size() == 4);
    output(0)->format(input(0)->format());
    output(1)->format(input(2)->format());
    output(2)->format(input(3)->format());
}

cg::OperatorNodeBase::NodeProp* DeformableConvBackwardData::do_make_node_prop()
        const {
    auto prop = Super::Super::do_make_node_prop();
    using D = NodeProp::DepType;
    mgb_assert(input().size() == 5);
    // all five inputs are consumed by device value
    prop->reset_dep_type(input(), {D::DEV_VALUE, D::DEV_VALUE, D::DEV_VALUE,
                                   D::DEV_VALUE, D::DEV_VALUE});
    return prop;
}

void DeformableConvBackwardData::init_output_static_infer_desc() {
    // the last output is the workspace var; handled separately below
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<
                    megdnn::DeformableConvBackwardData>::val);
}
/* ==================== DeformableConvBackwardFilter ==================== */
IMPL_CONV(DeformableConvBackwardFilter);

//! computes the grad w.r.t. filter; all five inputs must be float32
DeformableConvBackwardFilter::DeformableConvBackwardFilter(
        VarNode* src, VarNode* filter, VarNode* offset, VarNode* mask,
        VarNode* diff, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super({src->owner_graph(),
                 config,
                 "deformable_conv_backward_filter",
                 {src, filter, offset, mask, diff}},
                1, false) {
    mgb_assert(src->dtype() == dtype::Float32() and
               filter->dtype() == dtype::Float32() and
               offset->dtype() == dtype::Float32() and
               mask->dtype() == dtype::Float32() and
               diff->dtype() == dtype::Float32(),
               "input should be float32, got %s, %s, %s, %s %s",
               src->dtype().name(), filter->dtype().name(),
               offset->dtype().name(), mask->dtype().name(),
               diff->dtype().name());
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter, offset, mask, diff});
}

SymbolVar DeformableConvBackwardFilter::make(SymbolVar src, SymbolVar filter,
                                             SymbolVar offset, SymbolVar mask,
                                             SymbolVar diff, const Param& param,
                                             const ExecutionPolicy& policy,
                                             const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<DeformableConvBackwardFilter>(
            src.node(), filter.node(), offset.node(), mask.node(), diff.node(),
            param, policy, config);
}

// note: input(1) (the filter value) is not passed to the kernel — only its
// shape determines the output
void DeformableConvBackwardFilter::scn_do_execute() {
    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), // src
                       input(2)->dev_tensor().as_megdnn(), // offset
                       input(3)->dev_tensor().as_megdnn(), // mask
                       input(4)->dev_tensor().as_megdnn(), // diff
                       output(0)->dev_tensor().as_megdnn(), // filter_diff
                       intl::get_megdnn_workspace_from_var(output(1)));
}

size_t DeformableConvBackwardFilter::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    mgb_assert(input_shapes.size() == 5 && output_shapes.size() == 1);
    // algo query uses src, offset, mask, diff and the filter-grad layouts
    return AlgoChooser<megdnn::DeformableConvBackwardFilter>::setup_algo(
            {TensorLayout{input_shapes[0], input(0)->dtype(),
                          input(0)->format()},
             {input_shapes[2], input(2)->dtype(), input(2)->format()},
             {input_shapes[3], input(3)->dtype(), input(3)->format()},
             {input_shapes[4], input(4)->dtype(), input(4)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this);
}
/* ==================== BatchConvBiasForward ==================== */
IMPL_CONV(BatchConvBiasForward);

//! batch_conv_bias with src and filter only (no bias, no z term)
BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
                                           const Param& param,
                                           const ExecutionPolicy& policy,
                                           const OperatorNodeConfig& config)
        : Super{src->owner_graph(), config, "batch_conv_bias", {src, filter}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({src, filter});
}

//! batch_conv_bias with src, filter and bias
BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
                                           VarNode* bias, const Param& param,
                                           const ExecutionPolicy& policy,
                                           const OperatorNodeConfig& config)
        : Super{src->owner_graph(),
                config,
                "batch_conv_bias",
                {src, filter, bias}} {
    m_policy = policy;
    init_megdnn_opr(*this, param);
    add_input({src, filter, bias});
}

//! batch_conv_bias with src, filter, bias and an additional z term
BatchConvBiasForward::BatchConvBiasForward(VarNode* src, VarNode* filter,
                                           VarNode* bias, VarNode* z,
                                           const Param& param,
                                           const ExecutionPolicy& policy,
                                           const OperatorNodeConfig& config)
        : Super{src->owner_graph(),
                config,
                "batch_conv_bias",
                {src, filter, bias, z}} {
    m_policy = policy;
    init_megdnn_opr(*this, param);
    add_input({src, filter, bias, z});
}

// all inputs are required to be contiguous for the megdnn kernel
void BatchConvBiasForward::add_input_layout_constraint() {
    mixin::megdnn_utils::add_input_layout_constraint_contig(*this);
}

SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
                                     const Param& param,
                                     const ExecutionPolicy& policy,
                                     const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<BatchConvBiasForward>(
            src.node(), filter.node(), param, policy, config);
}

SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
                                     SymbolVar bias, const Param& param,
                                     const ExecutionPolicy& policy,
                                     const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<BatchConvBiasForward>(
            src.node(), filter.node(), bias.node(), param, policy, config);
}

SymbolVar BatchConvBiasForward::make(SymbolVar src, SymbolVar filter,
                                     SymbolVar bias, SymbolVar z,
                                     const Param& param,
                                     const ExecutionPolicy& policy,
                                     const OperatorNodeConfig& config) {
    return src.insert_single_output_opr<BatchConvBiasForward>(
            src.node(), filter.node(), bias.node(), z.node(), param, policy,
            config);
}

void BatchConvBiasForward::init_output_dtype() {
    // let megdnn deduce the output dtype from the (2 to 4) input dtypes;
    // config().output_dtype() may pre-specify it
    DType output_dtype = config().output_dtype();
    DType i0, i1, i2, i3;
    mgb_assert(input().size() >= 2 && input().size() <= 4);
    i0 = input(0)->dtype();
    i1 = input(1)->dtype();
    if (input().size() >= 3)
        i2 = input(2)->dtype();
    if (input().size() == 4)
        i3 = input(3)->dtype();
    megdnn_opr()->deduce_dtype(i0, i1, i2, i3, output_dtype);
    output(0)->dtype(output_dtype);
}

// Query the workspace size; absent bias/z inputs are modeled as empty
// layouts (bias dtype deduced by megdnn, z takes the output dtype/format).
size_t BatchConvBiasForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    auto mo = megdnn_opr();
    TensorLayout i0, i1, i2, i3;
    mgb_assert(input_shapes.size() >= 2 && input_shapes.size() <= 4);
    i0 = {input_shapes[0], input(0)->dtype(), input(0)->format()};
    i1 = {input_shapes[1], input(1)->dtype(), input(1)->format()};
    if (input_shapes.size() >= 3)
        i2 = {input_shapes[2], input(2)->dtype(), input(2)->format()};
    else {
        // no bias input: ask megdnn which bias dtype it would expect
        DType dtype;
        mo->deduce_dtype(input(0)->dtype(), input(1)->dtype(), DType{}, DType{},
                         dtype);
        i2 = {{}, dtype};
    }
    if (input_shapes.size() == 4)
        i3 = {input_shapes[3], input(3)->dtype(), input(3)->format()};
    else
        i3 = {{}, output(0)->dtype(), output(0)->format()};
    return AlgoChooser<megdnn::BatchConvBias>::setup_algo(
            {i0,
             i1,
             i2,
             i3,
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            mo, this);
}
  1401. void BatchConvBiasForward::scn_do_execute() {
  1402. auto&& inp = input();
  1403. auto mo = megdnn_opr();
  1404. if (inp.size() == 2) {
  1405. TensorLayout bias_layout;
  1406. bias_layout.ndim = 0;
  1407. if (output(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
  1408. bias_layout.dtype = dtype::QuantizedS32(
  1409. output(0)->dtype().param<dtype::QuantizedS8>().scale);
  1410. } else {
  1411. bias_layout.dtype = output(0)->dtype();
  1412. }
  1413. TensorLayout z_layout;
  1414. z_layout.ndim = 0;
  1415. z_layout.dtype = output(0)->dtype();
  1416. megdnn::TensorND bias_tensor{nullptr, bias_layout};
  1417. megdnn::TensorND z_tensor{nullptr, z_layout};
  1418. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1419. inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
  1420. output(0)->dev_tensor().as_megdnn(),
  1421. intl::get_megdnn_workspace_from_var(output().back()));
  1422. } else if (inp.size() == 3) {
  1423. TensorLayout z_layout;
  1424. z_layout.ndim = 0;
  1425. z_layout.dtype = output(0)->dtype();
  1426. megdnn::TensorND z_tensor{nullptr, z_layout};
  1427. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1428. inp[1]->dev_tensor().as_megdnn(),
  1429. inp[2]->dev_tensor().as_megdnn(), z_tensor,
  1430. output(0)->dev_tensor().as_megdnn(),
  1431. intl::get_megdnn_workspace_from_var(output().back()));
  1432. } else {
  1433. mgb_assert(inp.size() == 4);
  1434. mo->exec(inp[0]->dev_tensor().as_megdnn(),
  1435. inp[1]->dev_tensor().as_megdnn(),
  1436. inp[2]->dev_tensor().as_megdnn(),
  1437. inp[3]->dev_tensor().as_megdnn(),
  1438. output(0)->dev_tensor().as_megdnn(),
  1439. intl::get_megdnn_workspace_from_var(output().back()));
  1440. }
  1441. }
  1442. void BatchConvBiasForward::get_output_var_shape(
  1443. const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
  1444. auto mo = megdnn_opr();
  1445. TensorLayout dst;
  1446. mo->deduce_layout({inp_shape[0], input(0)->dtype(), input(0)->format()},
  1447. {inp_shape[1], input(1)->dtype(), input(0)->format()}, {},
  1448. {}, dst);
  1449. out_shape[0] = dst;
  1450. }
// Register static shape/value inference. The last output is the workspace
// var, which is excluded from the default machinery and handled by the
// dedicated workspace-infer path with the auto size-limit getter.
void BatchConvBiasForward::init_output_static_infer_desc() {
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    this->init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<
                    megdnn::BatchConvBiasForward>::val);
}
// Propagate the tensor format of src to the value output. Exactly two
// outputs are expected: the value output and the workspace var.
void BatchConvBiasForward::init_output_format() {
    mgb_assert(output().size() == 2);
    output(0)->format(input(0)->format());
}
  1462. #undef IMPL_CONV
  1463. #undef MGB_FOREACH_FASTRUN_OPR
  1464. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台