You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

comp_graph_tools.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import collections
  9. from collections import OrderedDict
  10. from typing import Dict, List, Optional
  11. import numpy
  12. from ..core import _imperative_rt
  13. from ..core._imperative_rt import OperatorNode, VarNode
  14. from ..core.tensor import megbrain_graph as G
  15. from ..core.tensor.megbrain_graph import set_priority_to_id
  16. from ..tensor import Tensor
# Public API of this module, consumed by ``from ... import *`` and the docs.
__all__ = [
    "get_dep_vars",
    "get_owner_opr_inputs",
    "get_owner_opr_type",
    "get_opr_type",
    "graph_traversal",
    "get_oprs_seq",
    "replace_vars",
    "replace_oprs",
    "set_priority_to_id",
    "GraphInference",
]
  29. def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
  30. """
  31. Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
  32. depands on. If ``var_type`` is None, returns all types.
  33. """
  34. outputs = []
  35. memo = set()
  36. if isinstance(var, VarNode):
  37. var = [var]
  38. if isinstance(var_type, str):
  39. var_type = [var_type]
  40. q = list(var)
  41. while q:
  42. v = q.pop(0)
  43. if v in memo:
  44. continue
  45. memo.add(v)
  46. q.extend(get_owner_opr_inputs(v))
  47. if var_type is not None:
  48. if get_owner_opr_type(v) in var_type:
  49. outputs.append(v)
  50. else:
  51. outputs.append(v)
  52. return outputs
  53. def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
  54. """
  55. Gets the inputs of owner opr of a variable.
  56. """
  57. assert isinstance(var, VarNode)
  58. return var.owner.inputs
  59. def get_owner_opr_type(var: VarNode) -> str:
  60. """
  61. Gets the type of owner opr of a variable.
  62. """
  63. assert isinstance(var, VarNode)
  64. return var.owner.type
  65. def get_opr_type(opr: OperatorNode) -> str:
  66. """
  67. Gets the type of an opr.
  68. """
  69. assert isinstance(opr, OperatorNode)
  70. return opr.type
def graph_traversal(outputs: List[VarNode]):
    """
    Helper function to traverse the computing graph and return enough useful information.

    :param outputs: model outputs.
    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
        WHERE
        map_oprs is dict from opr_id to actual opr
        map_vars is dict from var_id to actual var
        var2oprs is dict from var to dest oprs along with index
        opr2receivers is dict from current opr to next opr
        indegree2opr is dict from in_degree to opr in computing graph
        opr2indegree is dict from opr in computing graph to in_degree
        (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)
    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    # Seed the worklist with the owner oprs of the outputs; ``visited`` tracks
    # opr ids so each opr is enqueued at most once.
    queue = list(map(lambda x: x.owner, outputs))
    visited = set(map(lambda x: x.id, queue))

    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    opr2indegree = {}

    # BFS implemented with an index cursor instead of popping, so ``queue``
    # also serves as a record of discovery order.
    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr
        idx += 1

        # indegree counts how many *newly discovered* producer oprs feed
        # cur_opr; already-visited producers do not add to it.
        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner
            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)
                indegree += 1

            opr2receivers[pre_opr.id].append(cur_opr.id)

        indegree2opr[indegree].add(cur_opr.id)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
    """
    Gets oprs in some topological order for a dumped model.

    :param outputs: model outputs.
    :param prune_reshape: whether to prune the useless operators during inference.
    :return: opr list with some correct execution order.
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        # (Kahn's algorithm: repeatedly emit an opr with in-degree 0 and
        # decrement the in-degree of its receivers).
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1

            # skip const value generation operator
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)

            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        # any remaining node means the graph contained a cycle
        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
            # Detach cur_opr's outputs from post_opr; cur_opr is useless once
            # no consumer remains for any of its (non-workspace) outputs.
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                # recurse into producers; ``visited`` holds (producer, consumer)
                # edges so each edge is pruned at most once
                for opr in set([var.owner for var in cur_opr.inputs]):
                    if (opr.id, cur_opr.id) not in visited:
                        visited.add((opr.id, cur_opr.id))
                        iterative_pruning(opr, cur_opr, marked_opr_ids, visited)

        # Start pruning from the dest_shape input (inputs[1]) of every Reshape.
        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        visited = set()
        for reshape_opr in reshape_oprs:
            iterative_pruning(
                reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape is True:
        # work on a shallow copy so the caller-visible traversal data survives
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
  174. def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
  175. """
  176. Replaces vars in the graph.
  177. :param dst: target vars representing the graph.
  178. :param varmap: the map that specifies how to replace the vars.
  179. :return: new vars that correspond to ``dst`` with all the dependencies
  180. replaced.
  181. """
  182. dst_vec = []
  183. repl_src_vec = []
  184. repl_dst_vec = []
  185. for i in dst:
  186. assert isinstance(i, VarNode)
  187. dst_vec.append(i)
  188. for i, j in getattr(varmap, "items", lambda: varmap)():
  189. assert isinstance(i, VarNode)
  190. assert isinstance(j, VarNode)
  191. repl_src_vec.append(i)
  192. repl_dst_vec.append(j)
  193. return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
  194. def replace_oprs(
  195. dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
  196. ) -> List[VarNode]:
  197. """
  198. Replaces operators in the graph.
  199. :param dst: target vars representing the graph.
  200. :param oprmap: the map that specifies how to replace the operators.
  201. :return: new vars that correspond to ``dst`` with all the dependencies
  202. replaced.
  203. """
  204. dst_vec = []
  205. repl_src_vec = []
  206. repl_dst_vec = []
  207. for i in dst:
  208. assert isinstance(i, VarNode)
  209. dst_vec.append(i)
  210. for i, j in getattr(oprmap, "items", lambda: oprmap)():
  211. assert isinstance(i, OperatorNode)
  212. assert isinstance(j, OperatorNode)
  213. repl_src_vec.append(i)
  214. repl_dst_vec.append(j)
  215. return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
class GraphInference:
    """
    Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
    The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.

    :param file: could be file object or filename.
    :param outputs: only compile the subgraph with outputs as its endpoints.
    """

    def __init__(self, file, outputs: Optional[List[str]] = None):
        # load_graph returns several values; only the output var nodes are needed here
        *_, output_nodes = G.load_graph(file)
        if outputs is not None:
            # Resolve the requested output names against every var reachable
            # from the loaded graph's outputs.
            output_name = outputs.copy()
            all_vars = get_dep_vars(output_nodes) + output_nodes
            new_outputs = {}
            for i in all_vars:
                if i.name in output_name:
                    new_outputs[i.name] = i
                    output_name.remove(i.name)
            assert (
                len(output_name) == 0
            ), "Can not find varnode {} in this model".format(output_name)
            # preserve the caller-specified output order
            output_nodes = [new_outputs[i] for i in outputs]

        # Host2DeviceCopy oprs mark the graph's input vars.
        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")

        # Replace every original input var with a fresh InputNode so values
        # can be fed in at run time; keyed by input name, insertion-ordered.
        self._inp_dict = OrderedDict()
        replace_dict = {}
        for idx, i in enumerate(inputs):
            # NOTE(review): all InputNodes share inputs[0].graph — presumably
            # every input belongs to the same graph; confirm upstream.
            inp_node = G.InputNode(
                device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
            )
            self._inp_dict[i.name] = inp_node
            replace_dict[i] = inp_node.outputs[0]

        new_output_nodes = replace_vars(output_nodes, replace_dict)
        # keep the original output names on the replaced vars
        for old, new in zip(output_nodes, new_output_nodes):
            new.name = old.name
        self._out_dict = OrderedDict(
            [(i.name, G.OutputNode(i)) for i in new_output_nodes]
        )
        new_out_list = [i.outputs[0] for i in self._out_dict.values()]
        cg = new_out_list[0].graph
        # compile once; run() re-executes the compiled function
        self._func = cg.compile(new_out_list)

    def run(
        self,
        *inp_args: numpy.ndarray,
        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
    ) -> Dict[str, numpy.ndarray]:
        """
        Executes the compiled graph with the given inputs.

        :param inp_args: positional input arrays, matched to the model inputs
            in their declared order.
        :param inp_dict: mapping from input name to array; merged on top of
            ``inp_args`` and may override positional values of the same name.
        :return: ordered dict mapping each output name to its numpy value.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects {} inputs".format(len(self._inp_dict))
        inputs = {}
        inp_keys = list(self._inp_dict.keys())
        # positional args fill input slots in declaration order
        for ind, data in enumerate(inp_args):
            inputs[inp_keys[ind]] = data
        if inp_dict is not None:
            inputs.update(inp_dict)
        # every declared input must be supplied, and nothing extra
        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )
        for key in self._inp_dict:
            self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())

        self._func.execute()

        result = OrderedDict()
        for key in self._out_dict:
            result[key] = self._out_dict[key].get_value().numpy()
        return result

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台