
comp_graph_tools.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import heapq
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import numpy as np

from ..core import _imperative_rt
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import OperatorNode as _OpNode
from ..core._imperative_rt import VarNode as _VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor

__all__ = [
    "get_dep_vars",
    "get_owner_opr_inputs",
    "get_owner_opr_type",
    "get_opr_type",
    "graph_traversal",
    "get_oprs_seq",
    "replace_vars",
    "replace_oprs",
    "set_priority_to_id",
    "GraphInference",
]

def get_dep_vars(
    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
    """
    Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
    depends on. If ``var_type`` is None, returns all types.
    """
    outputs = []
    memo = set()

    if isinstance(var, _VarNode):
        var = [var]
    if isinstance(var_type, str):
        var_type = [var_type]

    q = list(var)
    while q:
        v = q.pop(0)
        if v in memo:
            continue
        memo.add(v)
        q.extend(get_owner_opr_inputs(v))
        if var_type is not None:
            if get_owner_opr_type(v) in var_type:
                outputs.append(v)
        else:
            outputs.append(v)

    return outputs

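# A hedged usage sketch for get_dep_vars. The model path is hypothetical; a
# graph dumped by MegEngine and loaded via megbrain_graph.load_graph is assumed:
#
#     ret = G.load_graph("model.mge")                     # hypothetical file
#     outputs = ret.output_vars_list
#     h2d_vars = get_dep_vars(outputs, "Host2DeviceCopy")  # all network inputs
#     all_deps = get_dep_vars(outputs)                     # every dependency var
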
def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
    """
    Gets the inputs of the owner opr of a variable.
    """
    return var.owner.inputs

def get_owner_opr_type(var: _VarNode) -> str:
    """
    Gets the type of the owner opr of a variable.
    """
    return var.owner.type

def get_opr_type(opr: _OpNode) -> str:
    """
    Gets the type of an opr.
    """
    assert isinstance(opr, _OpNode)
    return opr.type

class _OprStableOrderHeapq:
    """Heap implementation for comparing operators in a stable order."""

    _list = None
    _extra_priority = None
    _used_id_name_pairs = None

    def __init__(self, extra_priority):
        # collections.Callable was removed in Python 3.10; use collections.abc
        assert isinstance(extra_priority, collections.abc.Callable)
        self._list = []
        self._extra_priority = extra_priority
        self._used_id_name_pairs = {}

    def pop_min(self):
        return heapq.heappop(self._list)[-1]

    def add(self, opr):
        # named `add` to mimic the set() interface
        id_ = opr.id
        name = opr.name
        other = self._used_id_name_pairs.setdefault((id_, name), opr)
        if other is not opr:
            raise RuntimeError(
                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
            )
        item = self._extra_priority(opr) + (id_, name, opr)
        heapq.heappush(self._list, item)

    def __bool__(self):
        return bool(self._list)

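# Note on ordering: each heap item is `extra_priority(opr) + (id, name, opr)`,
# so operators compare first by the caller-supplied priority tuple and then by
# the unique (id, name) pair; the comparison never falls through to the opr
# object itself, which makes the pop order deterministic. A minimal sketch
# with a hypothetical opr stand-in:
#
#     OprStub = collections.namedtuple("OprStub", "priority id name")
#     h = _OprStableOrderHeapq(lambda op: (op.priority,))
#     h.add(OprStub(priority=0, id=2, name="b"))
#     h.add(OprStub(priority=0, id=1, name="a"))
#     assert h.pop_min().id == 1   # equal priority: lower id pops first
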
def graph_traversal(outputs: _VarNode):
    """
    Helper function that traverses the computing graph and gathers the
    metadata needed for topological sorting.

    :param outputs: model outputs.
    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
        WHERE
        map_oprs is a dict from opr_id to the actual opr
        map_vars is a dict from var_id to the actual var
        var2oprs is a dict from var to its dest oprs along with the input index
        opr2receivers is a dict from an opr to the oprs that consume its outputs
        indegree2opr is a dict from in_degree to oprs in the computing graph
        opr2indegree is a dict from an opr in the computing graph to its in_degree
        (indegree2opr, opr2indegree) are only used by the topological sort in get_oprs_seq
    """
    # meta information for the comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)

    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    queue = []
    for o in (x.owner for x in outputs):
        if o not in queue:
            queue.append(o)
    visited = set(map(lambda x: x.id, queue))

    # iterate through the whole comp_graph, filling in the meta information
    indegree2opr = collections.defaultdict(set)
    indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
    opr2indegree = {}

    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr

        idx += 1

        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner

            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)

        # zero-indegree oprs go into the stable-order heap; others are tracked by id
        opr = cur_opr if indegree == 0 else cur_opr.id
        indegree2opr[indegree].add(opr)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree

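# A hedged sketch of consuming the traversal results (assuming `outputs` holds
# the output VarNodes of a loaded graph):
#
#     map_oprs, map_vars, var2oprs, opr2receivers, _, opr2indegree = \
#         graph_traversal(outputs)
#     for opr_id, opr in map_oprs.items():
#         print(opr_id, opr.type, "indegree:", opr2indegree[opr_id])
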
def get_oprs_seq(
    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_OpNode]:
    """
    Gets oprs in some topological order for a dumped model.

    :param outputs: model outputs.
    :param prune_reshape: whether to prune operators that are only used to compute
        Reshape shapes and are useless during inference.
    :param prune_immtensor: whether to prune ImmutableTensor oprs.
    :return: opr list in a correct execution order.
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with the topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr = indegree2opr[0].pop_min()
            opr_id = opr.id
            nr_remain -= 1
            if opr.type != "ImmutableTensor" or not prune_immtensor:
                oprs_seq.append(opr)

            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                if indegree == 0:
                    indegree2opr[indegree].add(map_oprs[post_id])
                else:
                    indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # during inference the shape of output_tensor is already known, so the
    # operators that only produce dest_shape can be pruned from the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                for opr in set([var.owner for var in cur_opr.inputs]):
                    if (opr.id, cur_opr.id) not in visited:
                        visited.add((opr.id, cur_opr.id))
                        iterative_pruning(opr, cur_opr, marked_opr_ids, visited)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        visited = set()
        for reshape_opr in reshape_oprs:
            iterative_pruning(
                reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    # adjust the order of oprs so that param/data provider oprs sit close to
    # the oprs that consume them as inputs
    def reorder_oprs_seq(oprs):
        rst = []
        param_or_data_provider_oprs = []
        other_oprs = []

        for o in oprs:
            if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                param_or_data_provider_oprs.append(o)
            else:
                other_oprs.append(o)

        for o in other_oprs:
            for inp in o.inputs:
                if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                    if inp.owner in param_or_data_provider_oprs:
                        rst.append(inp.owner)
                        param_or_data_provider_oprs.remove(inp.owner)
            rst.append(o)
        rst = rst + param_or_data_provider_oprs
        assert len(rst) == len(oprs)
        return rst

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    oprs_seq = reorder_oprs_seq(oprs_seq)
    if prune_reshape:
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq

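# A hedged usage sketch for get_oprs_seq (hypothetical model file):
#
#     ret = G.load_graph("model.mge")                     # hypothetical file
#     seq = get_oprs_seq(ret.output_vars_list, prune_reshape=True)
#     for opr in seq:
#         print(opr.id, opr.type)
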
def replace_vars(
    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
    """
    Replaces vars in the graph.

    :param dst: target vars representing the graph.
    :param varmap: the map that specifies how to replace the vars.
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced.
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, _VarNode)
        dst_vec.append(i)

    for i, j in getattr(varmap, "items", lambda: varmap)():
        assert isinstance(i, _VarNode)
        assert isinstance(j, _VarNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)

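# A hedged sketch: rebind the graph endpoints onto replacement vars. The names
# old_var/new_var are hypothetical VarNodes from the same graph:
#
#     new_outputs = replace_vars(outputs, {old_var: new_var})
#
# replace_oprs (defined below) follows the same pattern, keyed by operator
# nodes instead of variable nodes.
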
def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
    """
    Replaces operators in the graph.

    :param dst: target vars representing the graph.
    :param oprmap: the map that specifies how to replace the operators.
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced.
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, _VarNode)
        dst_vec.append(i)

    for i, j in getattr(oprmap, "items", lambda: oprmap)():
        assert isinstance(i, _OpNode)
        assert isinstance(j, _OpNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)

def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
    """
    Gets a list of VarNodes in the graph by name.

    :param dst: target vars representing the graph.
    :param names: names of the target VarNodes.
    :return: the VarNodes found by name.
    """
    output_names = names.copy()
    all_vars = get_dep_vars(dst) + dst
    # use a dict to keep the outputs in the same order as names
    output_dict = {}
    for i in all_vars:
        if i.name in output_names:
            output_dict[i.name] = i
            output_names.remove(i.name)
    assert len(output_names) == 0, "cannot find varnode(s) {} in this model".format(
        output_names
    )
    return [output_dict[i] for i in names]

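# A hedged example; the var name "conv1" is hypothetical:
#
#     ret = G.load_graph("model.mge")                     # hypothetical file
#     (conv1_var,) = find_vars_by_name(ret.output_vars_list, ["conv1"])
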
def convert_inputs(
    dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    """
    Replaces ``Host2DeviceCopy`` oprs in the graph with :class:`~.InputNode`,
    so that inputs can be fed via :meth:`~.InputNode.set_value` before running.

    :param dst: target vars representing the graph.
    :param inputs: indicates which inputs to replace. All
        inputs (``Host2DeviceCopy``) will be replaced if not specified.
    :return: new vars that correspond to ``dst`` with all inputs
        replaced, and a new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")
    input_dict = OrderedDict()
    replace_dict = {}
    for inp in inputs:
        inp_node = G.InputNode(
            device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
        )
        inp_node.name = inp.name
        input_dict[inp.name] = inp_node
        replace_dict[inp] = inp_node.outputs[0]

    new_output_nodes = replace_vars(dst, replace_dict)
    for old, new in zip(dst, new_output_nodes):
        new.name = old.name
    return new_output_nodes, input_dict

def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    """
    Wraps ``dst`` with :class:`~.OutputNode` in the graph so that outputs can
    be fetched with :meth:`~.OutputNode.get_value`.

    :param dst: target vars representing the graph.
    :return: new vars that correspond to ``dst`` with all outputs
        wrapped, and an outputs dict.
    """
    output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
    new_output_nodes = [i.outputs[0] for i in output_dict.values()]
    return new_output_nodes, output_dict

def embed_inputs(
    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    """
    Embeds ``data`` into the graph's inputs of ``dst``.

    :param dst: target vars representing the graph.
    :param data: data to be embedded.
    :param inputs: indicates which inputs to replace. All
        inputs (``Host2DeviceCopy``) will be replaced if not specified.
    :return: new vars that correspond to ``dst`` with all inputs
        replaced, and a new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")
    assert len(data) == len(inputs)
    input_dict = OrderedDict()
    replace_dict = {}
    for inp, d in zip(inputs, data):
        new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
        new_inp.name = inp.name
        input_dict[inp.name] = new_inp
        replace_dict[inp] = new_inp

    new_output_nodes = replace_vars(dst, replace_dict)
    for old, new in zip(dst, new_output_nodes):
        new.name = old.name
    return new_output_nodes, input_dict

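# A hedged sketch contrasting the two input strategies (shapes hypothetical):
#
#     # convert_inputs: keep inputs dynamic, feed them per run via InputNode
#     outs, inp_dict = convert_inputs(output_vars)
#     # embed_inputs: bake fixed numpy arrays into the graph as shared tensors
#     outs, inp_dict = embed_inputs(
#         output_vars, [np.ones((1, 3, 224, 224), dtype="float32")]
#     )
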
class GraphInference:
    """
    Loads a serialized computing graph as a GraphInference object which can be
    used to execute the computing graph.

    :param file: a file object or a filename.
    :param outputs: only compile the subgraph with these outputs as its endpoints.
    """

    def __init__(
        self,
        file,
        outputs: List[str] = None,
        profiling: bool = False,
        optimize_for_inference: bool = False,
        **kwargs
    ):
        ret = G.load_graph(file)
        self._graph, output_nodes = ret.graph, ret.output_vars_list
        if outputs is not None:
            output_nodes = find_vars_by_name(output_nodes, outputs)
        self._origin_outputs = output_nodes

        # replace inputs with `InputNode`
        output_nodes, self._inp_dict = convert_inputs(output_nodes)

        # replace outputs with `OutputNode`
        output_nodes, self._oup_dict = convert_outputs(output_nodes)

        self._func = self._graph.compile(output_nodes)

    def run(
        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
    ) -> Dict[str, np.ndarray]:
        """
        :param inp_args: positional input data.
        :param inp_dict: dict of named input data.
        :return: a dict {output_name: output_value}.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects at most {} inputs".format(len(self._inp_dict))
        inputs = {}
        inp_keys = list(self._inp_dict.keys())
        for ind, data in enumerate(inp_args):
            inputs[inp_keys[ind]] = data
        if inp_dict is not None:
            inputs.update(inp_dict)
        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but got inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )
        for key in self._inp_dict:
            self._inp_dict[key].set_value(
                Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor()
            )

        self._func.execute()
        self._func.wait()

        result = OrderedDict()
        for key in self._oup_dict:
            result[key] = self._oup_dict[key].get_value().numpy()
        return result

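# A hedged end-to-end usage sketch (model file and input shape hypothetical):
#
#     net = GraphInference("model.mge")                   # hypothetical path
#     out = net.run(np.random.rand(1, 3, 224, 224).astype("float32"))
#     for name, value in out.items():
#         print(name, value.shape)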
