Browse Source

fix(mge/utils): fix get_oprs_seq of cgtools

GitOrigin-RevId: 366a56f4d5
release-1.1
Megvii Engine Team 4 years ago
parent
commit
8118a5946d
4 changed files with 61 additions and 7 deletions
  1. +0
    -1
      imperative/python/megengine/__init__.py
  2. +22
    -4
      imperative/python/megengine/utils/comp_graph_tools.py
  3. +37
    -1
      imperative/python/test/unit/test_cgtools.py
  4. +2
    -1
      imperative/python/test/unit/test_tracing.py

+ 0
- 1
imperative/python/megengine/__init__.py View File

@@ -78,7 +78,6 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .version import __version__ from .version import __version__
from .utils import comp_graph_tools as cgtools


_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
sys.executable, sys.executable,


+ 22
- 4
imperative/python/megengine/utils/comp_graph_tools.py View File

@@ -15,6 +15,19 @@ from ..core._imperative_rt import OperatorNode, VarNode
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..core.tensor.raw_tensor import as_raw_tensor from ..core.tensor.raw_tensor import as_raw_tensor


__all__ = [
"get_dep_vars",
"get_owner_opr_inputs",
"get_owner_opr_type",
"get_opr_type",
"graph_traversal",
"get_oprs_seq",
"replace_vars",
"replace_oprs",
"set_priority_to_id",
"load_and_inference",
]



def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]: def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
""" """
@@ -166,7 +179,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def prune_reshape_oprs(outputs, oprs_seq, var2oprs): def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
useless = True useless = True
for oup in cur_opr.outputs: for oup in cur_opr.outputs:
if "workspace" not in oup.name: if "workspace" not in oup.name:
@@ -177,15 +190,20 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo
if useless: if useless:
marked_opr_ids.append(cur_opr.id) marked_opr_ids.append(cur_opr.id)


for inp in cur_opr.inputs:
iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
for opr in set([var.owner for var in cur_opr.inputs]):
if (opr.id, cur_opr.id) not in visited:
visited.add((opr.id, cur_opr.id))
iterative_pruning(opr, cur_opr, marked_opr_ids, visited)


reshape_vars = get_dep_vars(outputs, "Reshape") reshape_vars = get_dep_vars(outputs, "Reshape")
reshape_oprs = [var.owner for var in reshape_vars] reshape_oprs = [var.owner for var in reshape_vars]


marked_opr_ids = [] marked_opr_ids = []
visited = set()
for reshape_opr in reshape_oprs: for reshape_opr in reshape_oprs:
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
iterative_pruning(
reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
)


# filter out all marked oprs # filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))


+ 37
- 1
imperative/python/test/unit/test_cgtools.py View File

@@ -13,9 +13,10 @@ import pytest
import megengine import megengine
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import cgtools
import megengine.utils.comp_graph_tools as cgtools
from megengine.core.tensor import megbrain_graph as mgb_graph from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace from megengine.jit import trace




@@ -98,3 +99,38 @@ def test_load_refcnt():
graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf)) graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
del graph del graph
varnode.owner varnode.owner


def test_get_opr_seq():
class Net(M.Module):
def __init__(self):
super().__init__()
self.data = megengine.tensor(
np.random.random((1, 1, 4, 4)), dtype=np.float32
)

def forward(self, input):
A = input.shape[0]
shape = astensor1d((A, A), self.data, dtype="int32", device=input.device)
x = F.reshape(self.data, shape)
o = input + x
return o

net = Net()
input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)

@trace(symbolic=True, capture_as_const=True)
def func(inp, *, net=None):
return net(inp)

func(input, net=net)
file = io.BytesIO()
func.dump(file, optimize_for_inference=False)
file.seek(0)
*_, outputs = mgb_graph.load_graph(file)

seq_1 = cgtools.get_oprs_seq(outputs, True)
assert len(seq_1) == 5

seq_2 = cgtools.get_oprs_seq(outputs, False)
assert len(seq_2) == 6

+ 2
- 1
imperative/python/test/unit/test_tracing.py View File

@@ -14,7 +14,8 @@ import pytest


import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F import megengine.functional as F
from megengine import cgtools, tensor
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.core.ops import builtin as ops from megengine.core.ops import builtin as ops
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise


Loading…
Cancel
Save