|
|
@@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode |
|
|
|
from ..core.tensor import megbrain_graph as G |
|
|
|
from ..core.tensor.raw_tensor import as_raw_tensor |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
"get_dep_vars", |
|
|
|
"get_owner_opr_inputs", |
|
|
|
"get_owner_opr_type", |
|
|
|
"get_opr_type", |
|
|
|
"graph_traversal", |
|
|
|
"get_oprs_seq", |
|
|
|
"replace_vars", |
|
|
|
"replace_oprs", |
|
|
|
"set_priority_to_id", |
|
|
|
"load_and_inference", |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: |
|
|
|
""" |
|
|
@@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo |
|
|
|
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor |
|
|
|
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph |
|
|
|
def prune_reshape_oprs(outputs, oprs_seq, var2oprs): |
|
|
|
def iterative_pruning(cur_opr, post_opr, marked_opr_ids): |
|
|
|
def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited): |
|
|
|
useless = True |
|
|
|
for oup in cur_opr.outputs: |
|
|
|
if "workspace" not in oup.name: |
|
|
@@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo |
|
|
|
if useless: |
|
|
|
marked_opr_ids.append(cur_opr.id) |
|
|
|
|
|
|
|
for inp in cur_opr.inputs: |
|
|
|
iterative_pruning(inp.owner, cur_opr, marked_opr_ids) |
|
|
|
for opr in set([var.owner for var in cur_opr.inputs]): |
|
|
|
if (opr.id, cur_opr.id) not in visited: |
|
|
|
visited.add((opr.id, cur_opr.id)) |
|
|
|
iterative_pruning(opr, cur_opr, marked_opr_ids, visited) |
|
|
|
|
|
|
|
reshape_vars = get_dep_vars(outputs, "Reshape") |
|
|
|
reshape_oprs = [var.owner for var in reshape_vars] |
|
|
|
|
|
|
|
marked_opr_ids = [] |
|
|
|
visited = set() |
|
|
|
for reshape_opr in reshape_oprs: |
|
|
|
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids) |
|
|
|
iterative_pruning( |
|
|
|
reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited |
|
|
|
) |
|
|
|
|
|
|
|
# filter out all marked oprs |
|
|
|
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) |
|
|
|