
helper.h 13 kB

/**
 * \file src/core/include/megbrain/graph/helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/graph/cg.h"

#include <vector>

namespace mgb {
namespace cg {

class OperatorNodeBase;
class VarNode;

/*!
 * \brief get the involved comp nodes of an operator; the operator must have
 *      been compiled
 */
CompNode::UnorderedSet get_opr_comp_node_set(OperatorNodeBase *opr);

/*!
 * \brief whether var shape could be statically inferred
 */
static inline bool is_static_var_shape(VarNode *var) {
    using IT = static_infer::InferType;
    auto it = var->owner_graph()->static_infer_manager().get_infer_type(var);
    return it.shape & (IT::CONST | IT::RT_STATIC);
}

/*!
 * \brief whether var shape is constant
 */
static inline bool is_const_var_shape(VarNode *var) {
    using IT = static_infer::InferType;
    auto it = var->owner_graph()->static_infer_manager().get_infer_type(var);
    return it.shape & IT::CONST;
}

/*!
 * \brief whether var value could be statically inferred
 */
static inline bool is_static_var_value(VarNode *var) {
    using IT = static_infer::InferType;
    auto it = var->owner_graph()->static_infer_manager().get_infer_type(var);
    return it.value & (IT::CONST | IT::RT_STATIC);
}

/*!
 * \brief whether var value is constant
 */
static inline bool is_const_var_value(VarNode* var) {
    using IT = static_infer::InferType;
    auto&& mgr = var->owner_graph()->static_infer_manager();
    auto infer_type = mgr.get_infer_type(var);
    if (!(infer_type.value & IT::CONST))
        return false;
    mgb_assert(infer_type.shape & IT::CONST,
               "var(%s) has const value infer but non-const shape infer",
               var->cname());
    return true;
}

/*!
 * \brief whether var storage would be statically allocated by system
 */
static inline bool is_static_var_storage(VarNode *var) {
    using F = VarNode::Flag;
    if (var->contain_flag(F::PERSISTENT_DEVICE_VALUE))
        return true;
    if (var->contain_flag(F::RT_FORCE_DYNAMIC_MEM_ALLOC | F::NO_SYS_MEM_ALLOC |
                          F::NO_SYS_STATIC_MEM_ALLOC))
        return false;
    return is_static_var_shape(var);
}

/*!
 * \brief whether device computing is needed for given input var and dep type
 *      of an operator
 *
 * See the code for precise definition
 */
static inline bool need_device_computing_on_var(
        VarNode *var, OperatorNodeBase::NodeProp::DepType dt) {
    using DT = OperatorNodeBase::NodeProp::DepType;
    return !var->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) &&
           ((dt & (DT::DEV_VALUE | DT::DEV_COMP_ORDER)) ||
            ((dt & DT::HOST_VALUE) && !is_static_var_value(var)) ||
            ((dt & DT::SHAPE) && is_static_var_shape(var)));
}

/*!
 * \brief whether all input vars of an operator have static storage
 */
bool is_all_input_static_storage(OperatorNodeBase* opr);

/*!
 * \brief transform a SymbolVarArray to a VarNodeArray
 */
VarNodeArray to_var_node_array(const SymbolVarArray& symbol_var_array);

/*!
 * \brief transform a VarNodeArray to a SymbolVarArray
 */
SymbolVarArray to_symbol_var_array(const VarNodeArray& var_node_array);

/*!
 * \brief return a string to describe the list of variables
 */
std::string dump_var_info(const VarNodeArrayView &vars);

/*!
 * \brief compute grad of target w.r.t. wrt (i.e. d(target)/d(wrt))
 * \param warn_mid_wrt whether to give warning on wrt not being end-point var
 * \param return_zero_for_nodep if *target* does not depend on *wrt*, return a
 *      zero-valued var rather than a null var
 * \return the var representing grad, or nullptr if target does not depend on
 *      wrt
 */
SymbolVar grad(SymbolVar target, SymbolVar wrt,
               bool warn_mid_wrt = true, bool return_zero_for_nodep = true);

/*!
 * \brief equivalent to calling grad(grad, wrt) one by one if symbolic;
 * since cache in grad manager would be cleared each time, this method is more
 * efficient if eager.
 */
SymbolVarArray grad(SymbolVar target, SymbolVarArray wrts,
                    bool warn_mid_wrt = true,
                    bool return_zero_for_nodep = true);

/*!
 * \brief get current grad target, which must be called inside
 *      OperatorNodeBase::grad() implementations
 */
SymbolVar current_grad_target(ComputingGraph &graph);

struct SpecialOprStat {
    bool has_virtual_grad = false;
    bool has_shape_hint = false;
};

/*!
 * \brief replace variables in a graph
 * \param dest target vars to describe the graph
 * \param varmap map that describes how to replace an old var with a new var
 * \return a list of vars corresponding to \p dest whose dependencies have been
 *      replaced according to \p varmap
 */
SymbolVarArray replace_vars(const SymbolVarArray &dest,
                            const ThinHashMap<SymbolVar, SymbolVar>& varmap);

/*!
 * \brief replace operator in a graph
 * \param dest target vars to describe the graph
 * \param oprmap map that describes how to replace an old operator with a new
 *      operator
 * \return a list of vars corresponding to \p dest whose dependencies have been
 *      replaced according to \p oprmap
 */
SymbolVarArray replace_oprs(
        const SymbolVarArray& dest,
        const ThinHashMap<OperatorNodeBase*, OperatorNodeBase*>& oprmap);

/*!
 * \brief replace computing graph which owns all variables to another graph
 * \param dest target vars to describe the graph
 * \param new_graph target computing graph
 * \return a list of vars corresponding to \p dest whose owner_graph have been
 *      replaced with \p new_graph
 */
SymbolVarArray replace_vars_comp_graph(
        const SymbolVarArray &dest, ComputingGraph* new_graph);

SymbolVarArray find_h2d(const SymbolVarArray& dest);

/*!
 * \brief go through OperatorNodeBase::NodeProp::Attribute::src_opr until it
 *      becomes nullptr
 *
 * This function also performs path compression
 */
OperatorNodeBase* get_opr_root_source_opr(OperatorNodeBase *opr);

//! describes how two mem plans intersect
enum class MemPlanIntersectionType {
    DISJOINT,   //!< no intersection
    IDENTICAL,  //!< completely same
    OVERLAP     //!< intersects but not identical
};
MemPlanIntersectionType get_mem_plan_intersection_type(VarNode* a, VarNode *b);

/*!
 * \brief request output var to writable forward input var if no mem plan of
 *      other input vars intersects with this input var
 */
void request_fwd_in2out_writable_if_no_mem_ovelap(
        OperatorNodeBase *opr, size_t inp, size_t out);

/*!
 * \brief update shapes of output vars; set to empty if not statically
 *      inferable
 *
 * This method must always be called if a new operator is inserted (currently
 * used in ComputingGraph::insert_opr and copy_opr_shallow)
 *
 * Note: implemented in cg_impl.cpp, since it is used during graph init
 */
void update_output_var_shapes(OperatorNodeBase *opr);

/*!
 * \brief add an output to be used as the workspace for an operator
 *
 * The workspace var would have dtype Byte.
 *
 * This helper is usually called from an opr constructor and used for adding
 * the last output.
 */
void add_workspace_output(OperatorNodeBase *opr);

/*!
 * \brief copy a raw tensor shape into a host tensor
 */
void copy_shape_to_tensor_value(DeviceTensorND &dest, const TensorShape &shp);

/*!
 * \brief copy value of a host tensor into a raw tensor shape
 */
void copy_tensor_value_to_shape(TensorShape &dest, const DeviceTensorND &val);

/*!
 * \brief get a symbolvar whose value is tensor shape, used for other
 *      operators
 *
 * \param opr_name operator that invokes this function; used in the error
 *      message if *config* is invalid
 */
SymbolVar var_from_tensor_shape(
        ComputingGraph &graph, const OperatorNodeConfig &config,
        const char *opr_name, const TensorShape &shape);

/*!
 * \brief get a symbolvar whose value is tensor shape
 *
 * \param inp used to determine the computing graph, which can be any
 *      symbolvar belonging to the same computing graph.
 */
static inline SymbolVar var_from_tensor_shape(
        SymbolVar inp, const TensorShape &shape) {
    return var_from_tensor_shape(*inp.node()->owner_graph(),
                                 OperatorNodeConfig().follow_comp_node(inp),
                                 nullptr, shape);
}

/*!
 * \brief iterate over all dependency oprs in topological order
 * \param cb callback to be invoked when a new operator is discovered
 */
class DepOprIter {
public:
    using Callback = thin_function<void(OperatorNodeBase*)>;
    using ExtraDep = ThinHashMap<OperatorNodeBase*, SmallVector<VarNode*>>;

    explicit DepOprIter(Callback cb,
                        std::shared_ptr<ExtraDep> extra_dep = nullptr)
            : m_cb{std::move(cb)}, m_extra_dep(std::move(extra_dep)) {}

    //! add an operator whose deps should be discovered
    void add(OperatorNodeBase *dest);

    void add(SymbolVar var) { add(var.node()->owner_opr()); }

    //! graph of all the oprs
    ComputingGraph* owner_graph() const {
        return m_owner_graph;
    }

    //! check if an opr has been visited
    bool visited(OperatorNodeBase *opr) const {
        return m_visited.count(opr);
    }

    //! set an opr to have been visited
    DepOprIter& set_visited(OperatorNodeBase* opr) {
        m_visited.insert(opr);
        return *this;
    }

private:
    //! a single stack frame to avoid recursion
    struct Frame {
        OperatorNodeBase *opr;
        VarNode * const *inputs;
        VarNode * const *extra_deps;
        size_t inp_idx, nr_input, nr_extra_dep;
    };

    ComputingGraph *m_owner_graph = nullptr;
    std::vector<Frame> m_stack;
    ThinHashSet<OperatorNodeBase*> m_visited;
    Callback m_cb;
    const std::shared_ptr<ExtraDep> m_extra_dep;

    inline void push_stack(OperatorNodeBase *opr);
};

/*!
 * \brief a user data associated with ComputingGraph::Options::user_data
 *
 * When a graph A is copied as a new graph B, the module that initiates the
 * copy may associate an instance of InterGraphVarTransformer with user data
 * of B, so when B is extended (e.g. by constructing a grad graph), others can
 * know how to transform a var in A into its equivalent var in B.
 */
class InterGraphVarTransformer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    InterGraphVarTransformer() = default;

public:
    /*!
     * var transforming function to be defined by copier; the input var has
     * been checked to be in src graph.
     */
    using TransFunc = thin_function<VarNode*(VarNode*)>;

    /*!
     * \brief register a transformer to *dest* graph that takes var in *src*
     *      and outputs a corresponding var in *dest*
     *
     * This function should be called only once on a graph
     */
    static void register_to(ComputingGraph *dest,
                            const ComputingGraph *src, const TransFunc &trans);

    /*!
     * \brief get the transformer associated with a graph
     * \return previously registered transformer on given graph or nullptr
     *      if none registered
     */
    static const InterGraphVarTransformer* get(const ComputingGraph &graph);

    /*!
     * \brief transform a var into this graph
     */
    VarNode* trans(VarNode *src) const;

private:
    ComputingGraph *m_graph_dest;
    const ComputingGraph *m_graph_src;
    TransFunc m_trans_func;
};

/*!
 * \brief find extra dependency of vars
 *      (ComputingGraph::Options::extra_vardeps) and merge into a var list
 */
class ExtraDependencyMerger {
    SpecialOprStat* const m_sopr_stat;
    VarNodeArray m_new_deps;
    DepOprIter m_opr_iter;
    SymbolVarArray m_result;
    ComputingGraph* m_owner_graph = nullptr;

    void on_opr(OperatorNodeBase* opr);

public:
    explicit ExtraDependencyMerger(SpecialOprStat* sopr_stat = nullptr);
    ~ExtraDependencyMerger();

    /*!
     * \brief add a new set of vars
     * \return current var list after adding these vars; it keeps growing.
     *
     * Note: \p vars given here would always be added to the result list, even
     * if they duplicate existing vars.
     *
     * \return vars with extra dependency; the returned list can be modified
     */
    SymbolVarArray& add(const SymbolVarArray& vars);
};

//! shortcut for calling ExtraDependencyMerger
SymbolVarArray get_dest_vars_with_extra_deps(
        const SymbolVarArray& dest_vars, SpecialOprStat* sopr_stat = nullptr);

}  // namespace cg
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
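
The gradient and graph-rewrite helpers declared above are easiest to read alongside a small usage sketch. The following example is hypothetical and not part of helper.h: it assumes `loss`, `w` and `w_new` are SymbolVars that already live on the same, fully built ComputingGraph, and it only exercises declarations from this header.

// Hypothetical usage sketch (not part of helper.h); assumes loss, w and w_new
// are SymbolVars built on the same ComputingGraph.
#include "megbrain/graph/helper.h"

namespace example {

using namespace mgb;
using namespace mgb::cg;

SymbolVar grad_with_replaced_input(SymbolVar loss, SymbolVar w, SymbolVar w_new) {
    // d(loss)/d(w); with the default return_zero_for_nodep=true this yields a
    // zero-valued var rather than nullptr when loss does not depend on w
    SymbolVar g = grad(loss, w);

    // rebuild the gradient expression with w replaced by w_new; the returned
    // array corresponds position-wise to the dest argument
    ThinHashMap<SymbolVar, SymbolVar> varmap;
    varmap[w] = w_new;
    SymbolVarArray dest;
    dest.push_back(g);
    SymbolVarArray rebuilt = replace_vars(dest, varmap);
    return rebuilt[0];
}

}  // namespace example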
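
DepOprIter can likewise be used outside of graph internals. A minimal hypothetical sketch (again assuming an already-built graph) that counts the operators a given output var depends on:

// Hypothetical usage sketch for DepOprIter (not part of helper.h).
#include "megbrain/graph/helper.h"

namespace example {

using namespace mgb;
using namespace mgb::cg;

size_t count_dep_oprs(SymbolVar out) {
    size_t nr = 0;
    DepOprIter iter{[&nr](OperatorNodeBase*) {
        // the callback fires exactly once per newly discovered operator,
        // in topological order of dependencies
        ++nr;
    }};
    iter.add(out);  // the SymbolVar overload forwards to its owner opr
    return nr;
}

}  // namespace example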

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.