You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

comp_graph_tools.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. """tools for graph manipulation"""
  10. import collections
  11. from . import mgb as _mgb
  12. def get_dep_vars(var, var_type=None):
  13. """return :class:`.SymbolVar` of type ``var_type`` that input ``var``
  14. depands on. If ``var_type`` is None, return all types.
  15. :type var: an instance or iterable of :class:`.SymbolVar`
  16. :type var_type: ``str`` or an iterable of ``str``
  17. "rtype: list of :class:`.SymbolVar`
  18. """
  19. outputs = []
  20. memo = set()
  21. if isinstance(var, _mgb.SymbolVar):
  22. var = [var]
  23. if isinstance(var_type, str):
  24. var_type = [var_type]
  25. q = list(var)
  26. while q:
  27. v = q.pop()
  28. if v in memo:
  29. continue
  30. memo.add(v)
  31. q.extend(get_inputs(v))
  32. if var_type is not None:
  33. if get_type(v) in var_type:
  34. outputs.append(v)
  35. else:
  36. outputs.append(v)
  37. return outputs
  38. def get_inputs(var):
  39. """get the inputs of owner opr of a variable
  40. :type var: :class:`.SymbolVar`
  41. :rtype: list of :class:`.SymbolVar`
  42. """
  43. assert isinstance(var, _mgb.SymbolVar)
  44. return _mgb._get_owner_opr_inputs(var)
  45. def get_type(var):
  46. """get the type of owner opr of a variable
  47. :type var: :class:`.SymbolVar`
  48. :rtype: ``str``
  49. """
  50. assert isinstance(var, _mgb.SymbolVar)
  51. return _mgb._get_owner_opr_type(var)
  52. def get_opr_type(opr):
  53. """get the type of a opr
  54. :type var: :class:`.Operator`
  55. :rtype: ``str``
  56. """
  57. assert isinstance(opr, _mgb.Operator)
  58. return _mgb._get_opr_type(opr)
  59. def graph_traversal(outputs):
  60. """helper function to traverse the computing graph and reeturn enough useful information
  61. :param outputs: model outputs
  62. :type outputs: :class:`.Symbolvar`
  63. :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
  64. WHERE
  65. map_oprs is dict from opr_id to actual opr
  66. map_vars is dict from var_id to actual var
  67. var2oprs is dict from var to dest oprs along with index
  68. opr2receivers is dict from current opr to next opr
  69. indegree2opr is dict from in_degree to opr in computing graph
  70. opr2indegree is dict from opr in computing graph to in_degree
  71. (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
  72. """
  73. # meta information for comp graph
  74. map_oprs = collections.defaultdict(set)
  75. map_vars = collections.defaultdict(set)
  76. var2oprs = collections.defaultdict(list)
  77. opr2receivers = collections.defaultdict(list)
  78. queue = list(map(lambda x: x.owner_opr, outputs))
  79. visited = set(map(lambda x: x.id, queue))
  80. # iterate through whole comp_graph, fill in meta information
  81. indegree2opr = collections.defaultdict(set)
  82. opr2indegree = {}
  83. idx = 0
  84. while idx < len(queue):
  85. cur_opr = queue[idx]
  86. map_oprs[cur_opr.id] = cur_opr
  87. idx += 1
  88. indegree = 0
  89. for var_idx, var in enumerate(cur_opr.inputs):
  90. map_vars[var.id] = var
  91. var2oprs[var.id].append((cur_opr.id, var_idx))
  92. pre_opr = var.owner_opr
  93. if pre_opr.id not in visited:
  94. visited.add(pre_opr.id)
  95. queue.append(pre_opr)
  96. indegree += 1
  97. opr2receivers[pre_opr.id].append(cur_opr.id)
  98. indegree2opr[indegree].add(cur_opr.id)
  99. opr2indegree[cur_opr.id] = indegree
  100. return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs, prune_reshape=False):
    """get oprs in some topological order for a dumped model

    :param outputs: model outputs
    :param prune_reshape: whether to prune the operators useless during inference
    :return: opr list with some correct execution order
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        # (Kahn's algorithm: repeatedly take an opr whose in-degree is 0)
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1

            # skip const value generation operator
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)

            # releasing this opr consumes one in-degree unit of each receiver;
            # move each receiver from its old in-degree bucket to the new one
            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        # any node never released implies a cycle in the graph
        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can
    # prune some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
            # an opr becomes useless once none of its non-workspace outputs
            # is consumed by any remaining opr
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                # producers of a pruned opr may become useless in turn
                for inp in cur_opr.inputs:
                    iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner_opr for var in reshape_vars]

        marked_opr_ids = []
        for reshape_opr in reshape_oprs:
            # inputs[1] is the dest_shape input of the reshape opr (see the
            # reshape op definition comment above)
            iterative_pruning(
                reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape is True:
        # pass a copy because pruning mutates the var2oprs mapping
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
  158. def replace_vars(dst, varmap):
  159. """replace vars in the graph
  160. :param dst: target vars representing the graph
  161. :type dst: list of :class:`.SymbolVar`
  162. :param varmap: the map that specifies how to replace the vars
  163. :type varmap: dict that maps from src var to dst var
  164. :return: new vars that correspond to ``dst`` with all the dependencies
  165. replaced
  166. :rtype: list of :class:`.SymbolVar`
  167. """
  168. dst_vec = _mgb._VectorSymbolVar()
  169. repl_src_vec = _mgb._VectorSymbolVar()
  170. repl_dst_vec = _mgb._VectorSymbolVar()
  171. for i in dst:
  172. assert isinstance(i, _mgb.SymbolVar)
  173. dst_vec.push_back(i)
  174. for i, j in getattr(varmap, "items", lambda: varmap)():
  175. assert isinstance(i, _mgb.SymbolVar)
  176. assert isinstance(j, _mgb.SymbolVar)
  177. repl_src_vec.push_back(i)
  178. repl_dst_vec.push_back(j)
  179. return _mgb._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
  180. def replace_oprs(dst, oprmap):
  181. """Replace operators in the graph. Roughly equivalent to
  182. :param dst: target vars representing the graph
  183. :type dst: list of :class:`.SymbolVar`
  184. :param oprmap: the map that specifies how to replace the operators
  185. :type oprmap: dict that maps from src operator to dst operator
  186. :return: new vars that correspond to ``dst`` with all the dependencies
  187. replaced
  188. :rtype: list of :class:`.SymbolVar`
  189. """
  190. dst_vec = _mgb._VectorSymbolVar()
  191. repl_src_vec = _mgb._VectorOperator()
  192. repl_dst_vec = _mgb._VectorOperator()
  193. for i in dst:
  194. assert isinstance(i, _mgb.SymbolVar)
  195. dst_vec.push_back(i)
  196. for i, j in getattr(oprmap, "items", lambda: oprmap)():
  197. assert isinstance(i, _mgb.Operator)
  198. assert isinstance(j, _mgb.Operator)
  199. repl_src_vec.push_back(i)
  200. repl_dst_vec.push_back(j)
  201. return _mgb._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
  202. def set_priority_to_id(dest_vars):
  203. """For all oprs in the subgraph constructed by dest_vars
  204. set its priority to id if its original priority is zero
  205. :param dest_vars: target vars representing the graph
  206. """
  207. dest_vec = _mgb._VectorSymbolVar()
  208. for i in dest_vars:
  209. assert isinstance(i, _mgb.SymbolVar)
  210. dest_vec.push_back(i)
  211. _mgb._set_priority_to_id(dest_vec)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台