You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

megdnn_opr_wrapper.inl 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
/**
 * \file src/opr/impl/internal/megdnn_opr_wrapper.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
  13. namespace mgb {
  14. namespace opr {
  15. namespace intl {
  16. /*!
  17. * \brief template that can be specialized so inputs of an operator could be
  18. * modified in-place
  19. *
  20. * Invoked by MEGDNN_OPR_INIT* macros
  21. *
  22. * \tparam Opr an megbrain opr final class
  23. */
  24. template<class Opr>
  25. struct MegDNNOprInitInputsModifier {
  26. static inline void apply(const typename Opr::Param &param,
  27. std::initializer_list<SymbolVar*> inputs) {
  28. MGB_MARK_USED_VAR(param);
  29. MGB_MARK_USED_VAR(inputs);
  30. }
  31. };
  32. /*!
  33. * \brief template that can be specialized to be called in opr constructor
  34. *
  35. * Invoked by MEGDNN_OPR_INIT* macros
  36. */
  37. template<class Opr>
  38. struct MegDNNOprInitPostCtor {
  39. static inline void apply(cg::OperatorNodeBase &opr) {
  40. MGB_MARK_USED_VAR(opr);
  41. }
  42. };
  43. //! get megdnn Workspace object from a workspace var
  44. megdnn::Workspace get_megdnn_workspace_from_var(VarNode *var);
/*!
 * \brief A UserData object associated with the computing graph to get
 *      maximal usable workspace.
 *
 * It works by first limiting workspace to 0 and running allocation to get
 * free memory, then assuming the workspace can use all free memory.
 * It produces a var node, which should be taken as a value dep for
 * workspace static-infer functors so the memory manager can re-allocate.
 */
class WorkspaceLimitGetter {
    class Impl;

    //! fetch the Impl instance attached to \p graph (implementation detail)
    static Impl* get_impl(ComputingGraph *graph);

public:
    /*!
     * \brief get usable workspace size in bytes for a comp node
     *
     * Can only be called after is_prealloc_run() returns false.
     *
     * \param old_limit workspace limit set by user, which would be an
     *      upper bound for the return value
     */
    static size_t get_workspace_limit(
            ComputingGraph *graph, CompNode cn, size_t old_limit);

    //! return whether the current run is the pre-allocation pass, in which
    //! case workspace sizes should be reported as 0
    static bool is_prealloc_run(ComputingGraph *graph);

    /*!
     * \brief register WorkspaceLimitGetter in a graph
     * \return a var to be added as extra value dep for workspace
     *      infer; it would be null if WorkspaceLimitGetter is disabled
     *      at compile time
     */
    static VarNode* register_to_graph(ComputingGraph *graph);
};
/*!
 * \brief trait template that can be specialized to indicate whether
 *      WorkspaceLimitGetter is needed for an operator class
 *
 * The primary template answers "no"; \c val is forwarded to
 * init_output_static_infer_desc_workspace() by the wrapper classes below.
 *
 * \tparam MegDNNOpr a megdnn opr class
 */
template <class MegDNNOpr>
struct AutoAddWorkspaceNeedLimitGetter {
    //! whether workspace-size inference must consult WorkspaceLimitGetter
    static constexpr bool val = false;
};
/*!
 * \brief implement megdnn::DynOutMallocPolicy using the memory management
 *      system in megbrain
 */
class MegDNNDynOutMallocImpl final : public megdnn::DynOutMallocPolicy {
    cg::OperatorNodeBase *m_opr;  //!< operator on whose behalf memory is allocated
    CompNode m_cn;                //!< comp node used for the allocations

public:
    MegDNNDynOutMallocImpl(cg::OperatorNodeBase *opr, CompNode cn)
            : m_opr{opr}, m_cn{cn} {}

    //! allocate storage for output var \p id with given dtype/shape
    megdnn::TensorND alloc_output(
            size_t id, DType dtype, const TensorShape &shape,
            void *user_data) override;

    //! allocate \p sz bytes of temporary workspace
    void* alloc_workspace(size_t sz, void *user_data) override;

    //! release a pointer previously returned by alloc_workspace()
    void free_workspace(void *ptr, void *user_data) override;
};
  106. /* ======================= MegDNNOprMethInvoker ======================= */
namespace {

// Primary template; one specialization per (input count, output count) pair
// is generated below by repeatedly including the *_meth_invoker_impl.inl file.
template<int nr_in, int nr_out>
struct _MegDNNOprMethInvoker;

// Select the specialization matching a megdnn opr's compile-time arity.
template<class Opr>
using MegDNNOprMethInvoker =
        _MegDNNOprMethInvoker<Opr::NR_INPUTS, Opr::NR_OUTPUTS>;

// X-macro style instantiation: before each include, _NR_INPUTS/_NR_OUTPUTS
// give the arity and _FOREACH_IO expands to the input/output argument list.
// NOTE(review): the macros are re-#defined before every include, so the
// included impl file presumably #undefs them — confirm in that file.
// NOTE(review): the unnamed namespace in this header-like .inl gives every
// including TU its own copy of these templates; appears intentional.
#define _NR_INPUTS 1
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 1
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 1
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 2
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 2
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 3
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 4
#define _NR_OUTPUTS 1
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _o(0)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 2
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
#define _NR_INPUTS 5
#define _NR_OUTPUTS 3
#define _FOREACH_IO(_i, _o) _i(0), _i(1), _i(2), _i(3), _i(4), _o(0), _o(1), _o(2)
#include "./megdnn_opr_wrapper_megdnn_opr_meth_invoker_impl.inl"
} // anonymous namespace
  158. /* ======================= MegDNNOprWrapperFwd ======================= */
//! set up static shape/value inference for all outputs plus the workspace
template<class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::init_output_static_infer_desc() {
    // all outputs except one are managed by Super; presumably the excluded
    // one is the trailing workspace var — confirm against Super's contract
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::init_output_static_infer_desc();
    // register workspace-size inference; the flag tells it whether this
    // megdnn opr needs the WorkspaceLimitGetter machinery
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}
//! execute the wrapped megdnn opr on the current comp node
template<class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::scn_do_execute() {
    // dispatch through the invoker specialized on this opr's I/O arity
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}
//! query workspace requirement by asking the wrapped megdnn opr
template<class MegDNNOpr>
size_t MegDNNOprWrapperFwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray &input_shapes,
        const TensorShapeArray &output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}
//! deduce output shapes from input shapes via the megdnn opr's deduce_layout
template<class MegDNNOpr>
void MegDNNOprWrapperFwd<MegDNNOpr>::get_output_var_shape(
        const TensorShapeArray &inp_shape,
        TensorShapeArray &out_shape) const {
    MegDNNOprMethInvoker<MegDNNOpr>::deduce_layout(
            this->megdnn_opr(), this, inp_shape, out_shape);
}
  184. /* ======================= MegDNNOprWrapperBwd ======================= */
//! set up static inference for bwd outputs plus the workspace var
template<class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::init_output_static_infer_desc() {
    this->mixin_init_output_static_infer_desc_bwd(*this);
    // register workspace-size inference, with the per-opr limit-getter flag
    this->init_output_static_infer_desc_workspace(
            AutoAddWorkspaceNeedLimitGetter<MegDNNOpr>::val);
}
//! execute the wrapped megdnn bwd opr on the current comp node
template<class MegDNNOpr>
void MegDNNOprWrapperBwd<MegDNNOpr>::scn_do_execute() {
    // dispatch through the invoker specialized on this opr's I/O arity
    MegDNNOprMethInvoker<MegDNNOpr>::exec(this->megdnn_opr(), this);
}
//! query workspace requirement by asking the wrapped megdnn opr
template<class MegDNNOpr>
size_t MegDNNOprWrapperBwd<MegDNNOpr>::get_workspace_size_bytes(
        const TensorShapeArray &input_shapes,
        const TensorShapeArray &output_shapes) const {
    return this->mixin_get_workspace_size_bytes_by_megdnn(
            *this, input_shapes, output_shapes);
}
  202. template<class MegDNNOpr>
  203. typename MegDNNOprWrapperBwd<MegDNNOpr>::Super::NodeProp*
  204. MegDNNOprWrapperBwd<MegDNNOpr>::do_make_node_prop() const {
  205. auto prop = Super::do_make_node_prop();
  206. this->mixin_update_node_prop(*this, prop);
  207. return prop;
  208. }
} // namespace intl
  210. namespace mixin {
  211. /* ======================= MegDNNOprHolderImpl ======================= */
//! query the wrapped megdnn opr for its required workspace size in bytes
template<class MegDNNOpr, bool add_workspace, class OprHolder>
size_t MegDNNOprHolderImpl<MegDNNOpr, add_workspace, OprHolder>::
        mixin_get_workspace_size_bytes_by_megdnn(
                const OperatorNodeBase &opr,
                const TensorShapeArray &input_shapes,
                const TensorShapeArray &output_shapes) const {
    // only valid for holders that actually manage a workspace output;
    // enforce at instantiation time
    static_assert(add_workspace, "must add_workspace");
    // forward to the arity-matched invoker from the intl namespace
    return intl::MegDNNOprMethInvoker<MegDNNOpr>::get_workspace_in_bytes(
            this->megdnn_opr(), &opr, input_shapes, output_shapes);
}
  222. }
  223. } // namespace opr
  224. } // namespace mgb
//! generate opr constructor with 1 input var; extra __VA_ARGS__ are
//! forwarded to the Super constructor after OperatorNodeBaseCtorParam
#define MEGDNN_OPR_CTOR_INIT1(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0,                                               \
            const Param &param, const OperatorNodeConfig &config):          \
            Super(OperatorNodeBaseCtorParam{                                \
                i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__) \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0});                                                    \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make (which runs the InputsModifier hook
//! before inserting the opr), with 1 input var
#define MEGDNN_OPR_INIT1(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT1(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0,                                     \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0});      \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), param, config);                                  \
    }
//! generate opr constructor with 2 input vars; extra __VA_ARGS__ are
//! forwarded to the Super constructor after OperatorNodeBaseCtorParam
#define MEGDNN_OPR_CTOR_INIT2(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1,                                  \
            const Param &param, const OperatorNodeConfig &config):          \
            Super(OperatorNodeBaseCtorParam{                                \
                i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__) \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1});                                                \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make (which runs the InputsModifier hook
//! before inserting the opr), with 2 input vars
#define MEGDNN_OPR_INIT2(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT2(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1,                       \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1}); \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), param, config);                       \
    }
//! generate opr constructor with 3 input vars; extra __VA_ARGS__ are
//! forwarded to the Super constructor after OperatorNodeBaseCtorParam
#define MEGDNN_OPR_CTOR_INIT3(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1, VarNode *i2,                     \
            const Param &param, const OperatorNodeConfig &config):          \
            Super(OperatorNodeBaseCtorParam{                                \
                i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__) \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1, i2});                                            \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make (which runs the InputsModifier hook
//! before inserting the opr), with 3 input vars
#define MEGDNN_OPR_INIT3(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT3(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1, SymbolVar i2,         \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(param, {&i0, &i1, &i2}); \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), i2.node(), param, config);            \
    }
//! generate opr constructor with 4 input vars; extra __VA_ARGS__ are
//! forwarded to the Super constructor after OperatorNodeBaseCtorParam
#define MEGDNN_OPR_CTOR_INIT4(_name, _node_name, ...)                       \
    _name::_name(VarNode *i0, VarNode *i1, VarNode *i2, VarNode *i3,        \
            const Param &param, const OperatorNodeConfig &config):          \
            Super(OperatorNodeBaseCtorParam{                                \
                i0->owner_graph(), config, _node_name, {i0}} ,##__VA_ARGS__) \
    {                                                                       \
        init_megdnn_opr(*this, param);                                      \
        add_input({i0, i1, i2, i3});                                        \
        intl::MegDNNOprInitPostCtor<_name>::apply(*this);                   \
    }
//! generate opr constructor and ::make (which runs the InputsModifier hook
//! before inserting the opr), with 4 input vars
#define MEGDNN_OPR_INIT4(_name, _node_name, ...)                            \
    MEGDNN_OPR_CTOR_INIT4(_name, _node_name ,##__VA_ARGS__)                 \
    SymbolVar _name::make(SymbolVar i0, SymbolVar i1, SymbolVar i2, SymbolVar i3, \
            const Param &param, const OperatorNodeConfig &config) {         \
        intl::MegDNNOprInitInputsModifier<_name>::apply(                    \
                param, {&i0, &i1, &i2, &i3});                               \
        return i0.insert_single_output_opr<_name>(                          \
                i0.node(), i1.node(), i2.node(), i3.node(), param, config); \
    }
  306. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台