|
|
@@ -17,8 +17,8 @@ from ..core.tensor.raw_tensor import as_raw_tensor |
|
|
|
|
|
|
|
|
|
|
|
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: |
|
|
|
"""return :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var`` |
|
|
|
depands on. If ``var_type`` is None, return all types. |
|
|
|
"""Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var`` |
|
|
|
depands on. If ``var_type`` is None, returns all types. |
|
|
|
""" |
|
|
|
outputs = [] |
|
|
|
memo = set() |
|
|
@@ -46,14 +46,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: |
|
|
|
|
|
|
|
|
|
|
|
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]: |
|
|
|
"""get the inputs of owner opr of a variable |
|
|
|
"""Gets the inputs of owner opr of a variable. |
|
|
|
""" |
|
|
|
assert isinstance(var, VarNode) |
|
|
|
return var.owner.inputs |
|
|
|
|
|
|
|
|
|
|
|
def get_owner_opr_type(var: VarNode) -> str: |
|
|
|
"""get the type of owner opr of a variable |
|
|
|
"""Gets the type of owner opr of a variable. |
|
|
|
|
|
|
|
""" |
|
|
|
assert isinstance(var, VarNode) |
|
|
@@ -61,16 +61,16 @@ def get_owner_opr_type(var: VarNode) -> str: |
|
|
|
|
|
|
|
|
|
|
|
def get_opr_type(opr: OperatorNode) -> str: |
|
|
|
"""get the type of a opr |
|
|
|
"""Gets the type of an opr. |
|
|
|
""" |
|
|
|
assert isinstance(opr, OperatorNode) |
|
|
|
return opr.type |
|
|
|
|
|
|
|
|
|
|
|
def graph_traversal(outputs: VarNode): |
|
|
|
"""helper function to traverse the computing graph and return enough useful information |
|
|
|
"""Helper function to traverse the computing graph and return enough useful information. |
|
|
|
|
|
|
|
:param outputs: model outputs |
|
|
|
:param outputs: model outputs. |
|
|
|
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree) |
|
|
|
WHERE |
|
|
|
map_oprs is dict from opr_id to actual opr |
|
|
@@ -124,11 +124,11 @@ def graph_traversal(outputs: VarNode): |
|
|
|
|
|
|
|
|
|
|
|
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]: |
|
|
|
"""get oprs in some topological order for a dumped model |
|
|
|
"""Gets oprs in some topological order for a dumped model. |
|
|
|
|
|
|
|
:param outputs: model outputs |
|
|
|
:param prune_reshape: whether to prune the operators useless during inference |
|
|
|
:return: opr list with some correct execution order |
|
|
|
:param outputs: model outputs. |
|
|
|
:param prune_reshape: whether to prune the useless operators during inference. |
|
|
|
:return: opr list with some correct execution order. |
|
|
|
""" |
|
|
|
|
|
|
|
def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree): |
|
|
@@ -194,13 +194,13 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo |
|
|
|
|
|
|
|
|
|
|
|
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]: |
|
|
|
"""replace vars in the graph |
|
|
|
"""Replaces vars in the graph. |
|
|
|
|
|
|
|
:param dst: target vars representing the graph |
|
|
|
:param varmap: the map that specifies how to replace the vars |
|
|
|
:param dst: target vars representing the graph. |
|
|
|
:param varmap: the map that specifies how to replace the vars. |
|
|
|
|
|
|
|
:return: new vars that correspond to ``dst`` with all the dependencies |
|
|
|
replaced |
|
|
|
replaced. |
|
|
|
""" |
|
|
|
dst_vec = [] |
|
|
|
repl_src_vec = [] |
|
|
@@ -221,13 +221,13 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]: |
|
|
|
def replace_oprs( |
|
|
|
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode] |
|
|
|
) -> List[VarNode]: |
|
|
|
"""Replace operators in the graph. |
|
|
|
"""Replaces operators in the graph. |
|
|
|
|
|
|
|
:param dst: target vars representing the graph |
|
|
|
:param oprmap: the map that specifies how to replace the operators |
|
|
|
:param dst: target vars representing the graph. |
|
|
|
:param oprmap: the map that specifies how to replace the operators. |
|
|
|
|
|
|
|
:return: new vars that correspond to ``dst`` with all the dependencies |
|
|
|
replaced |
|
|
|
replaced. |
|
|
|
""" |
|
|
|
dst_vec = [] |
|
|
|
repl_src_vec = [] |
|
|
@@ -246,9 +246,9 @@ def replace_oprs( |
|
|
|
|
|
|
|
|
|
|
|
def set_priority_to_id(dest_vars): |
|
|
|
"""For all oprs in the subgraph constructed by dest_vars |
|
|
|
set its priority to id if its original priority is zero |
|
|
|
:param dest_vars: target vars representing the graph |
|
|
|
"""For all oprs in the subgraph constructed by dest_vars, |
|
|
|
sets its priority to id if its original priority is zero. |
|
|
|
:param dest_vars: target vars representing the graph. |
|
|
|
""" |
|
|
|
dest_vec = [] |
|
|
|
for i in dest_vars: |
|
|
@@ -258,11 +258,11 @@ def set_priority_to_id(dest_vars): |
|
|
|
|
|
|
|
|
|
|
|
def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.ndarray]: |
|
|
|
"""Load a serialized computing graph and run inference with input data. |
|
|
|
"""Loads a serialized computing graph and run inference with input data. |
|
|
|
|
|
|
|
:param file: Path or Handle of the input file. |
|
|
|
:param inp_data_list: List of input data. |
|
|
|
:return: List of inference results. |
|
|
|
:param file: path or handle of the input file. |
|
|
|
:param inp_data_list: list of input data. |
|
|
|
:return: list of inference results. |
|
|
|
|
|
|
|
""" |
|
|
|
*_, out_list = G.load_graph(file) |
|
|
|