
# comp_graph_tools.py

import collections
import heapq
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import numpy as np

from ..core import _imperative_rt
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import OperatorNode as _OpNode
from ..core._imperative_rt import VarNode as _VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor

__all__ = [
    "get_dep_vars",
    "get_owner_opr_inputs",
    "get_owner_opr_type",
    "get_opr_type",
    "graph_traversal",
    "get_oprs_seq",
    "replace_vars",
    "replace_oprs",
    "set_priority_to_id",
    "GraphInference",
]


def get_dep_vars(
    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
    r"""Returns the list of :class:`.core.tensor.megbrain_graph.VarNode` of type ``var_type``
    that the input ``var`` depends on. If ``var_type`` is None, returns vars of all types.
    """
    outputs = []
    memo = set()

    if isinstance(var, _VarNode):
        var = [var]
    if isinstance(var_type, str):
        var_type = [var_type]

    q = list(var)
    while q:
        v = q.pop(0)
        if v in memo:
            continue
        memo.add(v)
        q.extend(get_owner_opr_inputs(v))
        if var_type is not None:
            if get_owner_opr_type(v) in var_type:
                outputs.append(v)
        else:
            outputs.append(v)

    return outputs
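
# Example (a minimal usage sketch, written as comments so the module stays
# import-safe; "model.mge" is a hypothetical dumped model):
#
#     net = G.load_graph("model.mge")
#     # all Host2DeviceCopy vars (i.e. the model inputs) that the outputs depend on
#     inputs = get_dep_vars(net.output_vars_list, "Host2DeviceCopy")
#     # all dependency vars, regardless of type
#     all_deps = get_dep_vars(net.output_vars_list)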


def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
    r"""Gets the inputs of the owner opr of a variable."""
    return var.owner.inputs


def get_owner_opr_type(var: _VarNode) -> str:
    r"""Gets the type of the owner opr of a variable."""
    return var.owner.type


def get_opr_type(opr: _OpNode) -> str:
    r"""Gets the type of an opr."""
    assert isinstance(opr, _OpNode)
    return opr.type


class _OprStableOrderHeapq:
    r"""Heap that pops operators in a stable order: ties on the extra priority
    are broken deterministically by the (id, name) pair."""

    _list = None
    _extra_priority = None
    _used_id_name_pairs = None

    def __init__(self, extra_priority):
        assert isinstance(extra_priority, collections.abc.Callable)
        self._list = []
        self._extra_priority = extra_priority
        self._used_id_name_pairs = {}

    def pop_min(self):
        return heapq.heappop(self._list)[-1]

    def add(self, opr):
        # named ``add`` to mimic the set() interface
        id_ = opr.id
        name = opr.name
        other = self._used_id_name_pairs.setdefault((id_, name), opr)
        if other is not opr:
            raise RuntimeError(
                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
            )

        # the heap key is (extra_priority..., id, name, opr); comparison is
        # resolved by the id/name tie-breakers before reaching the opr object
        item = self._extra_priority(opr) + (id_, name, opr)
        heapq.heappush(self._list, item)

    def __bool__(self):
        return bool(self._list)


def graph_traversal(outputs: List[_VarNode]):
    r"""Helper function that traverses the computing graph and collects useful metadata.

    Args:
        outputs: model outputs.

    Returns:
        tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)

        WHERE

        * map_oprs is a dict from opr_id to the actual opr
        * map_vars is a dict from var_id to the actual var
        * var2oprs is a dict from a var to its destination oprs along with the input index
        * opr2receivers is a dict from an opr to the oprs consuming its outputs
        * indegree2opr is a dict from in_degree to the oprs with that in_degree
        * opr2indegree is a dict from an opr to its in_degree

        (indegree2opr, opr2indegree) are only used by the topological sort in
        :func:`get_oprs_seq`.
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)

    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    queue = []
    for opr in (x.owner for x in outputs):
        if opr not in queue:
            queue.append(opr)
    visited = set(map(lambda x: x.id, queue))

    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
    opr2indegree = {}

    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr

        idx += 1

        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner

            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)

        # zero-indegree oprs go into the stable-order heap; others are tracked by id
        opr = cur_opr if indegree == 0 else cur_opr.id
        indegree2opr[indegree].add(opr)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
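
# Example (a minimal sketch; ``net`` as in the example above, ``var_id`` is a
# hypothetical var id taken from map_vars):
#
#     (map_oprs, map_vars, var2oprs,
#      opr2receivers, indegree2opr, opr2indegree) = graph_traversal(net.output_vars_list)
#     # oprs consuming that var, as (opr_id, input_index) pairs
#     consumers = var2oprs[var_id]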


def get_oprs_seq(
    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_OpNode]:
    r"""Gets oprs in some topological order for a dumped model.

    Args:
        outputs: model outputs.
        prune_reshape: whether to prune the operators that are only used to compute
            the shape input of a Reshape opr during inference.
        prune_immtensor: whether to prune ImmutableTensor oprs.

    Returns:
        opr list in a correct execution order.
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with a topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr = indegree2opr[0].pop_min()
            opr_id = opr.id
            nr_remain -= 1
            if opr.type != "ImmutableTensor" or not prune_immtensor:
                oprs_seq.append(opr)

            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                if indegree == 0:
                    # move into the zero-indegree heap as the actual opr object
                    indegree2opr[indegree].add(map_oprs[post_id])
                else:
                    indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape opr definition: reshape(input_tensor, dest_shape) -> output_tensor
    # at inference time the shape of output_tensor is already known, so the
    # operators that only produce dest_shape can be pruned from the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                for opr in set([var.owner for var in cur_opr.inputs]):
                    if (opr.id, cur_opr.id) not in visited:
                        visited.add((opr.id, cur_opr.id))
                        iterative_pruning(opr, cur_opr, marked_opr_ids, visited)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        visited = set()
        for reshape_opr in reshape_oprs:
            iterative_pruning(
                reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    # adjust the order of oprs so that param/data provider oprs stay close to
    # the oprs that use them as inputs
    def reorder_oprs_seq(oprs):
        rst = []
        param_or_data_provider_oprs = []
        other_oprs = []

        for o in oprs:
            if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                param_or_data_provider_oprs.append(o)
            else:
                other_oprs.append(o)

        for o in other_oprs:
            for inp in o.inputs:
                if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                    if inp.owner in param_or_data_provider_oprs:
                        rst.append(inp.owner)
                        param_or_data_provider_oprs.remove(inp.owner)
            rst.append(o)

        rst = rst + param_or_data_provider_oprs
        assert len(rst) == len(oprs)
        return rst

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    oprs_seq = reorder_oprs_seq(oprs_seq)
    if prune_reshape:
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
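
# Example (a minimal sketch; ``net`` as in the examples above):
#
#     oprs = get_oprs_seq(net.output_vars_list, prune_reshape=True)
#     for opr in oprs:
#         print(opr.id, opr.type, opr.name)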


def replace_vars(
    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
    r"""Replaces vars in the graph.

    Args:
        dst: target vars representing the graph.
        varmap: the map that specifies how to replace the vars.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, _VarNode)
        dst_vec.append(i)

    # ``varmap`` may be a dict or a list of (src, dst) pairs
    for i, j in getattr(varmap, "items", lambda: varmap)():
        assert isinstance(i, _VarNode)
        assert isinstance(j, _VarNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
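
# Example (a minimal sketch; ``old_var`` / ``new_var`` are hypothetical VarNodes
# from the same graph). ``convert_inputs`` below uses this same pattern:
#
#     new_outputs = replace_vars(net.output_vars_list, {old_var: new_var})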


def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
    r"""Replaces operators in the graph.

    Args:
        dst: target vars representing the graph.
        oprmap: the map that specifies how to replace the operators.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, _VarNode)
        dst_vec.append(i)

    # ``oprmap`` may be a dict or a list of (src, dst) pairs
    for i, j in getattr(oprmap, "items", lambda: oprmap)():
        assert isinstance(i, _OpNode)
        assert isinstance(j, _OpNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)


def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
    r"""Gets a VarNode list by names in the graph.

    Args:
        dst: target vars representing the graph.
        names: name list of the target VarNodes.

    Returns:
        VarNodes found by names.
    """
    output_names = names.copy()
    all_vars = get_dep_vars(dst) + dst

    # use a dict to keep the order of outputs the same as ``names``
    output_dict = {}
    for i in all_vars:
        if i.name in output_names:
            output_dict[i.name] = i
            output_names.remove(i.name)

    assert len(output_names) == 0, "can not find VarNode {} in this model".format(
        output_names
    )
    return [output_dict[i] for i in names]
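
# Example (a minimal sketch; assumes the dumped graph contains a var named "fc1"):
#
#     (fc1,) = find_vars_by_name(net.output_vars_list, ["fc1"])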


def convert_inputs(
    dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Replaces ``Host2DeviceCopy`` oprs in the graph with :class:`~.InputNode`,
    so that input values can be fed via :meth:`~.InputNode.set_value` before running.

    Args:
        dst: target vars representing the graph.
        inputs: indicates which inputs to replace. All inputs
            (``Host2DeviceCopy``) will be replaced if not specified.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and the new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")
    input_dict = OrderedDict()
    replace_dict = {}
    for inp in inputs:
        inp_node = G.InputNode(
            device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
        )
        inp_node.name = inp.name
        input_dict[inp.name] = inp_node
        replace_dict[inp] = inp_node.outputs[0]

    new_output_nodes = replace_vars(dst, replace_dict)
    for old, new in zip(dst, new_output_nodes):
        new.name = old.name
    return new_output_nodes, input_dict


def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Wraps ``dst`` with :class:`~.OutputNode` in the graph, so that output values
    can be fetched via :meth:`~.OutputNode.get_value` after running.

    Args:
        dst: target vars representing the graph.

    Returns:
        new vars that correspond to ``dst`` with all outputs wrapped, and the outputs dict.
    """
    output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
    new_output_nodes = [i.outputs[0] for i in output_dict.values()]
    return new_output_nodes, output_dict


def embed_inputs(
    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Embeds ``data`` into the graph's inputs of ``dst``.

    Args:
        dst: target vars representing the graph.
        data: data to be embedded.
        inputs: indicates which inputs to replace. All inputs
            (``Host2DeviceCopy``) will be replaced if not specified.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and the new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")

    assert len(data) == len(inputs)
    input_dict = OrderedDict()
    replace_dict = {}
    for inp, d in zip(inputs, data):
        new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
        new_inp.name = inp.name
        input_dict[inp.name] = new_inp
        replace_dict[inp] = new_inp

    new_output_nodes = replace_vars(dst, replace_dict)
    for old, new in zip(dst, new_output_nodes):
        new.name = old.name
    return new_output_nodes, input_dict
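
# Example (a minimal sketch; the shape is hypothetical and must match the model input):
#
#     data = [np.zeros((1, 3, 224, 224), dtype="float32")]
#     new_outputs, inp_dict = embed_inputs(net.output_vars_list, data)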


class GraphInference:
    r"""Loads a serialized computing graph as a GraphInference object which can be used
    to execute the computing graph.

    Args:
        file: could be a file object or filename.
        outputs: only compile the subgraph with outputs as its endpoints.
    """

    def __init__(
        self,
        file,
        outputs: List[str] = None,
        profiling: bool = False,
        optimize_for_inference: bool = False,
        **kwargs
    ):
        ret = G.load_graph(file)
        self._graph, output_nodes = ret.graph, ret.output_vars_list
        if outputs is not None:
            output_nodes = find_vars_by_name(output_nodes, outputs)
        self._origin_outputs = output_nodes

        # replace inputs with `InputNode`
        output_nodes, self._inp_dict = convert_inputs(output_nodes)

        # replace outputs with `OutputNode`
        output_nodes, self._oup_dict = convert_outputs(output_nodes)

        self._func = self._graph.compile(output_nodes)

    def run(
        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
    ) -> Dict[str, np.ndarray]:
        r"""
        Args:
            inp_args: list of positional input data.
            inp_dict: dict of named input data.

        Returns:
            a dict {output_name: output_value}.

        Note:
            The order of the graph's input nodes may differ from the argument order
            of the originally traced function, so it is recommended to use ``inp_dict``
            to provide input data by name.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects {} inputs".format(len(self._inp_dict))
        inputs = {}
        inp_keys = list(self._inp_dict.keys())
        for ind, data in enumerate(inp_args):
            inputs[inp_keys[ind]] = data
        if inp_dict is not None:
            inputs.update(inp_dict)
        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )
        for key in self._inp_dict:
            self._inp_dict[key].set_value(
                Tensor(inputs[key], device=self._inp_dict[key].device)._dev_tensor()
            )

        self._func.execute()
        self._func.wait()

        result = OrderedDict()
        for key in self._oup_dict:
            result[key] = self._oup_dict[key].get_value().numpy()
        return result
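
# Example usage (a minimal sketch; "model.mge" and the input name "data" are
# hypothetical and depend on how the model was dumped):
#
#     net = GraphInference("model.mge")
#     out = net.run(inp_dict={"data": np.random.rand(1, 3, 224, 224).astype("float32")})
#     for name, value in out.items():
#         print(name, value.shape)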