@@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||||
from .serialization import load, save | from .serialization import load, save | ||||
from .tensor import Parameter, Tensor, tensor | from .tensor import Parameter, Tensor, tensor | ||||
from .version import __version__ | from .version import __version__ | ||||
from .utils import comp_graph_tools as cgtools | |||||
_set_fork_exec_path_for_timed_func( | _set_fork_exec_path_for_timed_func( | ||||
sys.executable, | sys.executable, | ||||
@@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode | |||||
from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
from ..core.tensor.raw_tensor import as_raw_tensor | from ..core.tensor.raw_tensor import as_raw_tensor | ||||
__all__ = [ | |||||
"get_dep_vars", | |||||
"get_owner_opr_inputs", | |||||
"get_owner_opr_type", | |||||
"get_opr_type", | |||||
"graph_traversal", | |||||
"get_oprs_seq", | |||||
"replace_vars", | |||||
"replace_oprs", | |||||
"set_priority_to_id", | |||||
"load_and_inference", | |||||
] | |||||
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: | def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: | ||||
""" | """ | ||||
@@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||||
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor | # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor | ||||
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph | # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph | ||||
def prune_reshape_oprs(outputs, oprs_seq, var2oprs): | def prune_reshape_oprs(outputs, oprs_seq, var2oprs): | ||||
def iterative_pruning(cur_opr, post_opr, marked_opr_ids): | |||||
def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited): | |||||
useless = True | useless = True | ||||
for oup in cur_opr.outputs: | for oup in cur_opr.outputs: | ||||
if "workspace" not in oup.name: | if "workspace" not in oup.name: | ||||
@@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||||
if useless: | if useless: | ||||
marked_opr_ids.append(cur_opr.id) | marked_opr_ids.append(cur_opr.id) | ||||
for inp in cur_opr.inputs: | |||||
iterative_pruning(inp.owner, cur_opr, marked_opr_ids) | |||||
for opr in set([var.owner for var in cur_opr.inputs]): | |||||
if (opr.id, cur_opr.id) not in visited: | |||||
visited.add((opr.id, cur_opr.id)) | |||||
iterative_pruning(opr, cur_opr, marked_opr_ids, visited) | |||||
reshape_vars = get_dep_vars(outputs, "Reshape") | reshape_vars = get_dep_vars(outputs, "Reshape") | ||||
reshape_oprs = [var.owner for var in reshape_vars] | reshape_oprs = [var.owner for var in reshape_vars] | ||||
marked_opr_ids = [] | marked_opr_ids = [] | ||||
visited = set() | |||||
for reshape_opr in reshape_oprs: | for reshape_opr in reshape_oprs: | ||||
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids) | |||||
iterative_pruning( | |||||
reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited | |||||
) | |||||
# filter out all marked oprs | # filter out all marked oprs | ||||
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | ||||
@@ -13,9 +13,10 @@ import pytest | |||||
import megengine | import megengine | ||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
from megengine import cgtools | |||||
import megengine.utils.comp_graph_tools as cgtools | |||||
from megengine.core.tensor import megbrain_graph as mgb_graph | from megengine.core.tensor import megbrain_graph as mgb_graph | ||||
from megengine.core.tensor.raw_tensor import as_raw_tensor | from megengine.core.tensor.raw_tensor import as_raw_tensor | ||||
from megengine.core.tensor.utils import astensor1d | |||||
from megengine.jit import trace | from megengine.jit import trace | ||||
@@ -98,3 +99,38 @@ def test_load_refcnt(): | |||||
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) | graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) | ||||
del graph | del graph | ||||
varnode.owner | varnode.owner | ||||
def test_get_opr_seq(): | |||||
class Net(M.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.data = megengine.tensor( | |||||
np.random.random((1, 1, 4, 4)), dtype=np.float32 | |||||
) | |||||
def forward(self, input): | |||||
A = input.shape[0] | |||||
shape = astensor1d((A, A), self.data, dtype="int32", device=input.device) | |||||
x = F.reshape(self.data, shape) | |||||
o = input + x | |||||
return o | |||||
net = Net() | |||||
input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32) | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def func(inp, *, net=None): | |||||
return net(inp) | |||||
func(input, net=net) | |||||
file = io.BytesIO() | |||||
func.dump(file, optimize_for_inference=False) | |||||
file.seek(0) | |||||
*_, outputs = mgb_graph.load_graph(file) | |||||
seq_1 = cgtools.get_oprs_seq(outputs, True) | |||||
assert len(seq_1) == 5 | |||||
seq_2 = cgtools.get_oprs_seq(outputs, False) | |||||
assert len(seq_2) == 6 |
@@ -14,7 +14,8 @@ import pytest | |||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import cgtools, tensor | |||||
import megengine.utils.comp_graph_tools as cgtools | |||||
from megengine import tensor | |||||
from megengine.core._trace_option import set_symbolic_shape | from megengine.core._trace_option import set_symbolic_shape | ||||
from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
from megengine.core.ops.builtin import Elemwise | from megengine.core.ops.builtin import Elemwise | ||||