You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

comp_graph_tools.py 9.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import collections
  9. from typing import Dict, List
  10. import numpy
  11. from ..core import _imperative_rt
  12. from ..core._imperative_rt import OperatorNode, VarNode
  13. from ..core.tensor import megbrain_graph as G
  14. from ..core.tensor.raw_tensor import as_raw_tensor
  15. def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
  16. """return :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
  17. depands on. If ``var_type`` is None, return all types.
  18. """
  19. outputs = []
  20. memo = set()
  21. if isinstance(var, VarNode):
  22. var = [var]
  23. if isinstance(var_type, str):
  24. var_type = [var_type]
  25. q = list(var)
  26. while q:
  27. v = q.pop()
  28. if v in memo:
  29. continue
  30. memo.add(v)
  31. q.extend(get_owner_opr_inputs(v))
  32. if var_type is not None:
  33. if get_owner_opr_type(v) in var_type:
  34. outputs.append(v)
  35. else:
  36. outputs.append(v)
  37. return outputs
  38. def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
  39. """get the inputs of owner opr of a variable
  40. """
  41. assert isinstance(var, VarNode)
  42. return var.owner.inputs
  43. def get_owner_opr_type(var: VarNode) -> str:
  44. """get the type of owner opr of a variable
  45. """
  46. assert isinstance(var, VarNode)
  47. return var.owner.type
  48. def get_opr_type(opr: OperatorNode) -> str:
  49. """get the type of a opr
  50. """
  51. assert isinstance(opr, OperatorNode)
  52. return opr.type
  53. def graph_traversal(outputs: VarNode):
  54. """helper function to traverse the computing graph and return enough useful information
  55. :param outputs: model outputs
  56. :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
  57. WHERE
  58. map_oprs is dict from opr_id to actual opr
  59. map_vars is dict from var_id to actual var
  60. var2oprs is dict from var to dest oprs along with index
  61. opr2receivers is dict from current opr to next opr
  62. indegree2opr is dict from in_degree to opr in computing graph
  63. opr2indegree is dict from opr in computing graph to in_degree
  64. (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
  65. """
  66. # meta information for comp graph
  67. map_oprs = collections.defaultdict(set)
  68. map_vars = collections.defaultdict(set)
  69. var2oprs = collections.defaultdict(list)
  70. opr2receivers = collections.defaultdict(list)
  71. queue = list(map(lambda x: x.owner, outputs))
  72. visited = set(map(lambda x: x.id, queue))
  73. # iterate through whole comp_graph, fill in meta information
  74. indegree2opr = collections.defaultdict(set)
  75. opr2indegree = {}
  76. idx = 0
  77. while idx < len(queue):
  78. cur_opr = queue[idx]
  79. map_oprs[cur_opr.id] = cur_opr
  80. idx += 1
  81. indegree = 0
  82. for var_idx, var in enumerate(cur_opr.inputs):
  83. map_vars[var.id] = var
  84. var2oprs[var.id].append((cur_opr.id, var_idx))
  85. pre_opr = var.owner
  86. if pre_opr.id not in visited:
  87. visited.add(pre_opr.id)
  88. queue.append(pre_opr)
  89. indegree += 1
  90. opr2receivers[pre_opr.id].append(cur_opr.id)
  91. indegree2opr[indegree].add(cur_opr.id)
  92. opr2indegree[cur_opr.id] = indegree
  93. return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
  94. def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
  95. """get oprs in some topological order for a dumped model
  96. :param outputs: model outputs
  97. :param prune_reshape: whether to prune the operators useless during inference
  98. :return: opr list with some correct execution order
  99. """
  100. def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
  101. # generate an execution order with topological sort algorithm
  102. oprs_seq = []
  103. nr_remain = len(map_oprs)
  104. while indegree2opr[0]:
  105. opr_id = indegree2opr[0].pop()
  106. opr = map_oprs[opr_id]
  107. nr_remain -= 1
  108. # skip const value generation operator
  109. if get_opr_type(opr) != "ImmutableTensor":
  110. oprs_seq.append(opr)
  111. for post_id in opr2receivers[opr_id]:
  112. indegree = opr2indegree[post_id]
  113. indegree2opr[indegree].remove(post_id)
  114. indegree -= 1
  115. indegree2opr[indegree].add(post_id)
  116. opr2indegree[post_id] = indegree
  117. assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
  118. nr_remain
  119. )
  120. return oprs_seq
  121. # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
  122. # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
  123. def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
  124. def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
  125. useless = True
  126. for oup in cur_opr.outputs:
  127. if "workspace" not in oup.name:
  128. var_idx = post_opr.inputs.index(oup)
  129. var2oprs[oup.id].remove((post_opr.id, var_idx))
  130. useless = useless and (len(var2oprs[oup.id]) == 0)
  131. if useless:
  132. marked_opr_ids.append(cur_opr.id)
  133. for inp in cur_opr.inputs:
  134. iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
  135. reshape_vars = get_dep_vars(outputs, "Reshape")
  136. reshape_oprs = [var.owner for var in reshape_vars]
  137. marked_opr_ids = []
  138. for reshape_opr in reshape_oprs:
  139. iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
  140. # filter out all marked oprs
  141. return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
  142. map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
  143. outputs
  144. )
  145. oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
  146. if prune_reshape is True:
  147. oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
  148. return oprs_seq
  149. def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
  150. """replace vars in the graph
  151. :param dst: target vars representing the graph
  152. :param varmap: the map that specifies how to replace the vars
  153. :return: new vars that correspond to ``dst`` with all the dependencies
  154. replaced
  155. """
  156. dst_vec = []
  157. repl_src_vec = []
  158. repl_dst_vec = []
  159. for i in dst:
  160. assert isinstance(i, VarNode)
  161. dst_vec.append(i)
  162. for i, j in getattr(varmap, "items", lambda: varmap)():
  163. assert isinstance(i, VarNode)
  164. assert isinstance(j, VarNode)
  165. repl_src_vec.append(i)
  166. repl_dst_vec.append(j)
  167. return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
  168. def replace_oprs(
  169. dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
  170. ) -> List[VarNode]:
  171. """Replace operators in the graph.
  172. :param dst: target vars representing the graph
  173. :param oprmap: the map that specifies how to replace the operators
  174. :return: new vars that correspond to ``dst`` with all the dependencies
  175. replaced
  176. """
  177. dst_vec = []
  178. repl_src_vec = []
  179. repl_dst_vec = []
  180. for i in dst:
  181. assert isinstance(i, VarNode)
  182. dst_vec.append(i)
  183. for i, j in getattr(oprmap, "items", lambda: oprmap)():
  184. assert isinstance(i, OperatorNode)
  185. assert isinstance(j, OperatorNode)
  186. repl_src_vec.append(i)
  187. repl_dst_vec.append(j)
  188. return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
  189. def set_priority_to_id(dest_vars):
  190. """For all oprs in the subgraph constructed by dest_vars
  191. set its priority to id if its original priority is zero
  192. :param dest_vars: target vars representing the graph
  193. """
  194. dest_vec = []
  195. for i in dest_vars:
  196. assert isinstance(i, VarNode)
  197. dest_vec.append(i)
  198. _imperative_rt.graph._set_priority_to_id(dest_vec)
  199. def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]:
  200. """Load a serialized computing graph and run inference with input data.
  201. :param file: Path or Handle of the input file.
  202. :param inp_data_list: List of input data.
  203. :return: List of inference results.
  204. """
  205. *_, out_list = G.load_graph(file)
  206. inputs = get_dep_vars(out_list, "Host2DeviceCopy")
  207. replace_dict = {}
  208. inp_node_list = []
  209. for i in inputs:
  210. inp_node = G.InputNode(
  211. device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
  212. )
  213. replace_dict[i] = inp_node.outputs[0]
  214. inp_node_list.append(inp_node)
  215. new_out = replace_vars(out_list, replace_dict)
  216. out_node_list = [G.OutputNode(i) for i in new_out]
  217. new_out_list = [i.outputs[0] for i in out_node_list]
  218. cg = new_out_list[0].graph
  219. func = cg.compile(new_out_list)
  220. for node, value in zip(inp_node_list, inp_data_list):
  221. node.set_value(as_raw_tensor(value)._dev_tensor())
  222. func.execute()
  223. out_data_list = [o.get_value().numpy() for o in out_node_list]
  224. return out_data_list

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台