|
- import collections
- import heapq
- from collections import OrderedDict
- from typing import Dict, List, Tuple, Union
-
- import numpy as np
-
- from ..core import _imperative_rt
- from ..core._imperative_rt import GraphProfiler
- from ..core._imperative_rt import OperatorNode as _OpNode
- from ..core._imperative_rt import VarNode as _VarNode
- from ..core.tensor import megbrain_graph as G
- from ..core.tensor.megbrain_graph import set_priority_to_id
- from ..tensor import Tensor
-
# Public API of this module.
__all__ = [
    "get_dep_vars",
    "get_owner_opr_inputs",
    "get_owner_opr_type",
    "get_opr_type",
    "graph_traversal",
    "get_oprs_seq",
    "replace_vars",
    "replace_oprs",
    "set_priority_to_id",
    "GraphInference",
]
-
-
def get_dep_vars(
    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
    r"""Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
    depends on. If ``var_type`` is None, returns all types.

    Args:
        var: a var node, or a list of var nodes, to start the traversal from.
        var_type: owner-opr type name (or list of type names) used to filter
            the result; ``None`` keeps every reachable var.

    Returns:
        list of var nodes that ``var`` transitively depends on, in BFS order.
    """
    outputs = []
    memo = set()

    if isinstance(var, _VarNode):
        var = [var]

    if isinstance(var_type, str):
        var_type = [var_type]

    # BFS over the graph.  ``deque.popleft`` is O(1), unlike ``list.pop(0)``
    # which shifts the whole queue on every pop.
    q = collections.deque(var)
    while q:
        v = q.popleft()
        if v in memo:
            continue
        memo.add(v)
        q.extend(get_owner_opr_inputs(v))
        # keep the var when no filter is given, or when its owner opr matches
        if var_type is None or get_owner_opr_type(v) in var_type:
            outputs.append(v)

    return outputs
-
-
def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
    r"""Gets the inputs of owner opr of a variable."""
    owner_opr = var.owner
    return owner_opr.inputs
-
-
def get_owner_opr_type(var: _VarNode) -> str:
    r"""Gets the type of owner opr of a variable."""
    owner_opr = var.owner
    return owner_opr.type
-
-
def get_opr_type(opr: _OpNode) -> str:
    r"""Gets the type of an opr."""
    # reject anything that is not an operator node
    assert isinstance(opr, _OpNode)
    opr_type = opr.type
    return opr_type
-
-
- class _OprStableOrderHeapq:
- r"""heap implementation for operator comparison in stable order"""
-
- _list = None
- _extra_priority = None
- _used_id_name_pairs = None
-
- def __init__(self, extra_priority):
- assert isinstance(extra_priority, collections.abc.Callable)
- self._list = []
- self._extra_priority = extra_priority
- self._used_id_name_pairs = {}
-
- def pop_min(self):
- return heapq.heappop(self._list)[-1]
-
- def add(self, opr):
- # named as add to mimic set() interface
-
- id_ = opr.id
- name = opr.name
-
- other = self._used_id_name_pairs.setdefault((id_, name), opr)
- if other is not opr:
- raise RuntimeError(
- "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
- )
-
- item = self._extra_priority(opr) + (id_, name, opr)
- heapq.heappush(self._list, item)
-
- def __bool__(self):
- return bool(self._list)
-
-
def graph_traversal(outputs: _VarNode):
    r"""Helper function to traverse the computing graph and return enough useful information.

    Args:
        outputs: model outputs.

    Returns:
        tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)

        WHERE

        * map_oprs is dict from opr_id to actual opr
        * map_vars is dict from var_id to actual var
        * var2oprs is dict from var to dest oprs along with index
        * opr2receivers is dict from current opr to next opr
        * indegree2opr is dict from in_degree to opr in computing graph
        * opr2indegree is dict from opr in computing graph to in_degree

        (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)

    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    # seed the BFS queue with the owner oprs of the outputs, deduplicated by
    # opr id (replaces the old O(n^2) ``o not in queue`` list scan)
    queue = []
    visited = set()
    for out_var in outputs:
        owner = out_var.owner
        if owner.id not in visited:
            visited.add(owner.id)
            queue.append(owner)

    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    # the in-degree-0 bucket is a stable-order heap so that ready oprs pop in
    # (priority, id, name) order instead of arbitrary set order
    indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
    opr2indegree = {}

    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr

        idx += 1

        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner

            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)

        # the heap at in-degree 0 stores opr objects; other buckets store ids
        opr = cur_opr if indegree == 0 else cur_opr.id
        indegree2opr[indegree].add(opr)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
-
-
def get_oprs_seq(
    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_OpNode]:
    r"""Gets oprs in some topological order for a dumped model.

    Args:
        outputs: model outputs.
        prune_reshape: whether to prune the useless operators used by Reshape opr during inference.
        prune_immtensor: whether to prune the ImmutableTensor opr.

    Returns:
        opr list with some correct execution order.
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            # ready bucket is a stable-order heap of opr objects (built in
            # graph_traversal); other buckets are plain sets of opr ids
            opr = indegree2opr[0].pop_min()
            opr_id = opr.id
            nr_remain -= 1
            if opr.type != "ImmutableTensor" or not prune_immtensor:
                oprs_seq.append(opr)

            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                if indegree == 0:
                    # promote to the ready heap, which stores opr objects
                    indegree2opr[indegree].add(map_oprs[post_id])
                else:
                    indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
            # an opr becomes useless once none of its (non-workspace) outputs
            # feeds any remaining consumer
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

            # recurse upstream; ``visited`` holds (producer, consumer) edges
            for opr in set([var.owner for var in cur_opr.inputs]):
                if (opr.id, cur_opr.id) not in visited:
                    visited.add((opr.id, cur_opr.id))
                    iterative_pruning(opr, cur_opr, marked_opr_ids, visited)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        visited = set()
        for reshape_opr in reshape_oprs:
            # inputs[1] of Reshape is the dest_shape input (see comment above)
            iterative_pruning(
                reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    # adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs.
    def reorder_oprs_seq(oprs):
        rst = []
        param_or_data_provider_oprs = []
        other_oprs = []

        for o in oprs:
            if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                param_or_data_provider_oprs.append(o)
            else:
                other_oprs.append(o)

        for o in other_oprs:
            # emit each provider right before its first consumer
            for inp in o.inputs:
                if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                    if inp.owner in param_or_data_provider_oprs:
                        rst.append(inp.owner)
                        param_or_data_provider_oprs.remove(inp.owner)
            rst.append(o)
        # providers never consumed by other_oprs go last
        rst = rst + param_or_data_provider_oprs
        assert len(rst) == len(oprs)
        return rst

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    oprs_seq = reorder_oprs_seq(oprs_seq)
    # idiomatic truthiness test instead of ``is True``
    if prune_reshape:
        # shallow copy keeps the caller's top-level var2oprs mapping intact
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
-
-
def replace_vars(
    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
    r"""Replaces vars in the graph.

    Args:
        dst: target vars representing the graph.
        varmap: the map that specifies how to replace the vars.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    dst_vec = []
    for node in dst:
        assert isinstance(node, _VarNode)
        dst_vec.append(node)

    # ``varmap`` may be a dict or a plain iterable of (old, new) pairs
    pairs = getattr(varmap, "items", lambda: varmap)()
    repl_src_vec = []
    repl_dst_vec = []
    for old_var, new_var in pairs:
        assert isinstance(old_var, _VarNode)
        assert isinstance(new_var, _VarNode)
        repl_src_vec.append(old_var)
        repl_dst_vec.append(new_var)

    return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
-
-
def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
    """Replaces operators in the graph.

    Args:
        dst: target vars representing the graph.
        oprmap: the map that specifies how to replace the operators.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    dst_vec = []
    for node in dst:
        assert isinstance(node, _VarNode)
        dst_vec.append(node)

    # ``oprmap`` may be a dict or a plain iterable of (old, new) pairs
    pairs = getattr(oprmap, "items", lambda: oprmap)()
    repl_src_vec = []
    repl_dst_vec = []
    for old_opr, new_opr in pairs:
        assert isinstance(old_opr, _OpNode)
        assert isinstance(new_opr, _OpNode)
        repl_src_vec.append(old_opr)
        repl_dst_vec.append(new_opr)

    return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
-
-
def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
    r"""Gets VarNode list by names in the graph.

    Args:
        dst: target vars representing the graph.
        names: name list for target VarNode.

    Returns:
        results found by names.
    """
    remaining = names.copy()
    candidates = get_dep_vars(dst) + dst
    # collect matches into a dict so results follow ``names`` order below
    found = {}
    for node in candidates:
        if node.name in remaining:
            found[node.name] = node
            remaining.remove(node.name)
    assert len(remaining) == 0, "Can not find varnode {} in this model".format(
        remaining
    )
    return [found[name] for name in names]
-
-
def convert_inputs(
    dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Replaces ``Host2DeviceCopy`` with :class:`~.InputNode` in the graph
    to :meth:`~.InputNode.set_value` and run.

    Args:
        dst: target vars representing the graph.
        inputs: indicates which inputs to be replaced. All
            inputs(``Host2DeiceCopy``) will be replaced if not specified.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")

    input_dict = OrderedDict()
    replace_dict = {}
    for orig_inp in inputs:
        # build an InputNode mirroring the original input's placement and type
        node = G.InputNode(
            device=orig_inp.comp_node,
            dtype=orig_inp.dtype,
            shape=orig_inp.shape,
            graph=orig_inp.graph,
        )
        node.name = orig_inp.name
        input_dict[orig_inp.name] = node
        replace_dict[orig_inp] = node.outputs[0]

    new_outputs = replace_vars(dst, replace_dict)
    # keep the original output names on the replacement vars
    for old_var, new_var in zip(dst, new_outputs):
        new_var.name = old_var.name
    return new_outputs, input_dict
-
-
def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Wraps ``dst`` with :class:`~.OutputNode` in the graph to get outputs
    with :meth:`~.OutputNode.get_value`.

    Args:
        dst: target vars representing the graph.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and outputs dict.
    """
    output_dict = OrderedDict()
    for var in dst:
        output_dict[var.name] = G.OutputNode(var)
    new_output_nodes = [node.outputs[0] for node in output_dict.values()]
    return new_output_nodes, output_dict
-
-
def embed_inputs(
    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Embeds ``data`` to the graph's inputs of ``dst``.

    Args:
        dst: target vars representing the graph.
        data: data to be embeded.
        inputs: indicates which inputs to be replaced. All
            inputs(``Host2DeiceCopy``) will be replaced if not specified.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and new inputs dict.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")
    assert len(data) == len(inputs)

    input_dict = OrderedDict()
    replace_dict = {}
    for inp_var, value in zip(inputs, data):
        # bake the value into the graph as a shared (constant) node
        const_node = _imperative_rt.make_shared(
            inp_var.graph, Tensor(value)._dev_tensor()
        )
        const_node.name = inp_var.name
        input_dict[inp_var.name] = const_node
        replace_dict[inp_var] = const_node

    new_outputs = replace_vars(dst, replace_dict)
    # keep the original output names on the replacement vars
    for old_var, new_var in zip(dst, new_outputs):
        new_var.name = old_var.name
    return new_outputs, input_dict
-
-
class GraphInference:
    r"""Loads a serialized computing graph as a GraphInference object which can be used
    to execute the computing graph.

    Args:
        file: could be file object or filename.
        outputs: only compile the subgraph with outputs as its endpoints.
    """

    def __init__(
        self,
        file,
        outputs: List[str] = None,
        profiling: bool = False,
        optimize_for_inference: bool = False,
        **kwargs
    ):
        ret = G.load_graph(file)
        self._graph = ret.graph
        output_nodes = ret.output_vars_list
        if outputs is not None:
            # restrict compilation to the requested subgraph endpoints
            output_nodes = find_vars_by_name(output_nodes, outputs)
        self._origin_outputs = output_nodes

        # replace inputs with `InputNode`
        output_nodes, self._inp_dict = convert_inputs(output_nodes)

        # replace outputs with `OutputNode`
        output_nodes, self._oup_dict = convert_outputs(output_nodes)

        self._func = self._graph.compile(output_nodes)

    def run(
        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
    ) -> Dict[str, np.ndarray]:
        r"""

        Args:
            inp_args: list of input datas.
            inp_dict: dict of named input datas.

        Returns:
            a dict {output_name: output_value}.

        Note:
            Note that the order of the Graph's input nodes may be different from the order of the origin traced function's arguments.
            It is recommended to use ``inp_dict`` to provide input data by name.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects {} inputs".format(len(self._inp_dict))
        # positional data is matched to input names in declaration order
        inputs = dict(zip(self._inp_dict, inp_args))
        if inp_dict is not None:
            inputs.update(inp_dict)
        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )
        for name, node in self._inp_dict.items():
            node.set_value(Tensor(inputs[name], device=node.device)._dev_tensor())

        self._func.execute()
        self._func.wait()

        return OrderedDict(
            (name, node.get_value().numpy())
            for name, node in self._oup_dict.items()
        )
|