@@ -7,12 +7,14 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import collections
from collections import OrderedDict
from collections import OrderedDict
from typing import Dict, List, Optional
from typing import Dict, List, Tuple, Union
import numpy
import numpy as np
from ..core import _imperative_rt
from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import OperatorNode as _OpNode
from ..core._imperative_rt import VarNode as _VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor
from ..tensor import Tensor
@@ -31,7 +33,9 @@ __all__ = [
]
]
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
def get_dep_vars(
var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
"""
"""
Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
depands on. If ``var_type`` is None, returns all types.
depands on. If ``var_type`` is None, returns all types.
@@ -39,7 +43,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
outputs = []
outputs = []
memo = set()
memo = set()
if isinstance(var, VarNode):
if isinstance(var, _ VarNode):
var = [var]
var = [var]
if isinstance(var_type, str):
if isinstance(var_type, str):
@@ -61,14 +65,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
return outputs
return outputs
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
def get_owner_opr_inputs(var: _ VarNode) -> List[_ VarNode]:
"""
"""
Gets the inputs of owner opr of a variable.
Gets the inputs of owner opr of a variable.
"""
"""
return var.owner.inputs
return var.owner.inputs
def get_owner_opr_type(var: VarNode) -> str:
def get_owner_opr_type(var: _ VarNode) -> str:
"""
"""
Gets the type of owner opr of a variable.
Gets the type of owner opr of a variable.
@@ -76,15 +80,15 @@ def get_owner_opr_type(var: VarNode) -> str:
return var.owner.type
return var.owner.type
def get_opr_type(opr: Operator Node) -> str:
def get_opr_type(opr: _ OpNode) -> str:
"""
"""
Gets the type of an opr.
Gets the type of an opr.
"""
"""
assert isinstance(opr, Operator Node)
assert isinstance(opr, _ OpNode)
return opr.type
return opr.type
def graph_traversal(outputs: VarNode):
def graph_traversal(outputs: _ VarNode):
"""
"""
Helper function to traverse the computing graph and return enough useful information.
Helper function to traverse the computing graph and return enough useful information.
@@ -142,8 +146,8 @@ def graph_traversal(outputs: VarNode):
def get_oprs_seq(
def get_oprs_seq(
outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
) -> List[Operator Node]:
outputs: List[_ VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_ OpNode]:
"""
"""
Gets oprs in some topological order for a dumped model.
Gets oprs in some topological order for a dumped model.
@@ -218,7 +222,9 @@ def get_oprs_seq(
return oprs_seq
return oprs_seq
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
def replace_vars(
dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
"""
"""
Replaces vars in the graph.
Replaces vars in the graph.
@@ -232,21 +238,19 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
repl_src_vec = []
repl_src_vec = []
repl_dst_vec = []
repl_dst_vec = []
for i in dst:
for i in dst:
assert isinstance(i, VarNode)
assert isinstance(i, _ VarNode)
dst_vec.append(i)
dst_vec.append(i)
for i, j in getattr(varmap, "items", lambda: varmap)():
for i, j in getattr(varmap, "items", lambda: varmap)():
assert isinstance(i, VarNode)
assert isinstance(j, VarNode)
assert isinstance(i, _ VarNode)
assert isinstance(j, _ VarNode)
repl_src_vec.append(i)
repl_src_vec.append(i)
repl_dst_vec.append(j)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
def replace_oprs(
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
"""
"""
Replaces operators in the graph.
Replaces operators in the graph.
@@ -260,65 +264,154 @@ def replace_oprs(
repl_src_vec = []
repl_src_vec = []
repl_dst_vec = []
repl_dst_vec = []
for i in dst:
for i in dst:
assert isinstance(i, VarNode)
assert isinstance(i, _ VarNode)
dst_vec.append(i)
dst_vec.append(i)
for i, j in getattr(oprmap, "items", lambda: oprmap)():
for i, j in getattr(oprmap, "items", lambda: oprmap)():
assert isinstance(i, Operator Node)
assert isinstance(j, Operator Node)
assert isinstance(i, _ OpNode)
assert isinstance(j, _ OpNode)
repl_src_vec.append(i)
repl_src_vec.append(i)
repl_dst_vec.append(j)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
"""
Gets VarNode list by names in the graph.
:param dst: target vars representing the graph.
:param names: name list for target VarNode.
:return: results found by names.
"""
output_names = names.copy()
all_vars = get_dep_vars(dst) + dst
# use dict to keep outputs order the same as names.
output_dict = {}
for i in all_vars:
if i.name in output_names:
output_dict[i.name] = i
output_names.remove(i.name)
assert len(output_names) == 0, "Can not find varnode {} in this model".format(
output_names
)
return [output_dict[i] for i in names]
def convert_inputs(
dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Replaces ``Host2DeviceCopy`` with :class:`~.InputNode` in the graph
to :meth:`~.InputNode.set_value` and run.
:param dst: target vars representing the graph.
:param inputs: indicates which inputs to be replaced. All
inputs(``Host2DeiceCopy``) will be replaced if not specified.
:return: new vars that correspond to ``dst`` with all inputs
replaced, and new inputs dict.
"""
if inputs is None:
inputs = get_dep_vars(dst, "Host2DeviceCopy")
input_dict = OrderedDict()
replace_dict = {}
for inp in inputs:
inp_node = G.InputNode(
device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
)
inp_node.name = inp.name
input_dict[inp.name] = inp_node
replace_dict[inp] = inp_node.outputs[0]
new_output_nodes = replace_vars(dst, replace_dict)
for old, new in zip(dst, new_output_nodes):
new.name = old.name
return new_output_nodes, input_dict
def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Wraps ``dst`` with :class:`~.OutputNode` in the graph to get outputs
with :meth:`~.OutputNode.get_value`.
:param dst: target vars representing the graph.
:return: new vars that correspond to ``dst`` with all inputs
replaced, and outputs dict.
"""
output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
new_output_nodes = [i.outputs[0] for i in output_dict.values()]
return new_output_nodes, output_dict
def embed_inputs(
dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Embeds ``data`` to the graph's inputs of ``dst``.
:param dst: target vars representing the graph.
:param data: data to be embeded.
:param inputs: indicates which inputs to be replaced. All
inputs(``Host2DeiceCopy``) will be replaced if not specified.
:return: new vars that correspond to ``dst`` with all inputs
replaced, and new inputs dict.
"""
if inputs is None:
inputs = get_dep_vars(dst, "Host2DeviceCopy")
assert len(data) == len(inputs)
input_dict = OrderedDict()
replace_dict = {}
for inp, d in zip(inputs, data):
new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
new_inp.name = inp.name
input_dict[inp.name] = new_inp
replace_dict[inp] = new_inp
new_output_nodes = replace_vars(dst, replace_dict)
for old, new in zip(dst, new_output_nodes):
new.name = old.name
return new_output_nodes, input_dict
class GraphInference:
class GraphInference:
"""
"""
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.
Loads a serialized computing graph as a GraphInference object which can be used
to execute the computing graph .
:param file: could be file object or filename.
:param file: could be file object or filename.
:param outputs: only compile the subgraph with outputs as its endpoints.
:param outputs: only compile the subgraph with outputs as its endpoints.
"""
"""
def __init__(self, file, outputs: Optional[List[str]] = None):
*_, output_nodes = G.load_graph(file)
def __init__(
self,
file,
outputs: List[str] = None,
profiling: bool = False,
optimize_for_inference: bool = False,
**kwargs
):
self._graph, _, output_nodes = G.load_graph(file)
if outputs is not None:
if outputs is not None:
output_name = outputs.copy()
all_vars = get_dep_vars(output_nodes) + output_nodes
new_outputs = {}
for i in all_vars:
if i.name in output_name:
new_outputs[i.name] = i
output_name.remove(i.name)
assert (
len(output_name) == 0
), "Can not find varnode {} in this model".format(output_name)
output_nodes = [new_outputs[i] for i in outputs]
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict()
replace_dict = {}
for idx, i in enumerate(inputs):
inp_node = G.InputNode(
device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
)
self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0]
new_output_nodes = replace_vars(output_nodes, replace_dict)
for old, new in zip(output_nodes, new_output_nodes):
new.name = old.name
self._out_dict = OrderedDict(
[(i.name, G.OutputNode(i)) for i in new_output_nodes]
)
new_out_list = [i.outputs[0] for i in self._out_dict.values()]
cg = new_out_list[0].graph
self._func = cg.compile(new_out_list)
output_nodes = find_vars_by_name(output_nodes, outputs)
self._origin_outputs = output_nodes
# replace inputs with `InputNode`
output_nodes, self._inp_dict = convert_inputs(output_nodes)
# replace outputs with `OutputNode`
output_nodes, self._oup_dict = convert_outputs(output_nodes)
self._func = self._graph.compile(output_nodes)
def run(
def run(
self,
*inp_args: numpy.ndarray,
inp_dict: Optional[Dict[str, numpy.ndarray]] = None
):
self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
) -> Dict[str, np.ndarray]:
"""
:param inp_args: list of input datas.
:param inp_dict: dict of named input datas.
:return: a dict {output_name: output_value}.
"""
assert len(inp_args) <= len(
assert len(inp_args) <= len(
self._inp_dict
self._inp_dict
), "This model expects {} inputs".format(len(self._inp_dict))
), "This model expects {} inputs".format(len(self._inp_dict))
@@ -335,8 +428,11 @@ class GraphInference:
)
)
for key in self._inp_dict:
for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._func.execute()
self._func.execute()
self._func.wait()
result = OrderedDict()
result = OrderedDict()
for key in self._out_dict:
result[key] = self._out _dict[key].get_value().numpy()
for key in self._oup _dict:
result[key] = self._oup _dict[key].get_value().numpy()
return result
return result