Browse Source

feat(mge/cgtools): add network visitor interface with optional pruning

GitOrigin-RevId: cfa69e3e83
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
90107b6dda
3 changed files with 187 additions and 0 deletions
  1. +139
    -0
      python_module/megengine/_internal/comp_graph_tools.py
  2. +4
    -0
      python_module/src/swig/comp_graph_tools.i
  3. +44
    -0
      python_module/test/unit/jit/test_jit.py

+ 139
- 0
python_module/megengine/_internal/comp_graph_tools.py View File

@@ -67,6 +67,145 @@ def get_type(var):
return _mgb._get_owner_opr_type(var)


def get_opr_type(opr):
    """get the type of an opr

    :param opr: operator whose type name is queried
    :type opr: :class:`.Operator`
    :rtype: ``str``
    """
    assert isinstance(opr, _mgb.Operator)
    return _mgb._get_opr_type(opr)


def graph_traversal(outputs):
    """helper function to traverse the computing graph and return enough useful information

    :param outputs: model outputs
    :type outputs: :class:`.SymbolVar`
    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
        WHERE
        map_oprs is dict from opr_id to actual opr
        map_vars is dict from var_id to actual var
        var2oprs is dict from var to dest oprs along with index
        opr2receivers is dict from current opr to next opr
        indegree2opr is dict from in_degree to opr in computing graph
        opr2indegree is dict from opr in computing graph to in_degree

        (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)

    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    # BFS backwards from the output oprs; `visited` holds opr ids already queued
    queue = list(map(lambda x: x.owner_opr, outputs))
    visited = set(map(lambda x: x.id, queue))

    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    opr2indegree = {}

    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr

        idx += 1

        # in-degree of cur_opr = number of its input vars (counted with
        # multiplicity, matching the receiver edges recorded below)
        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            # record the consuming opr together with the input slot index
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner_opr

            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)

        indegree2opr[indegree].add(cur_opr.id)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree


def get_oprs_seq(outputs, prune_reshape=False):
    """get oprs in some topological order for a dumped model

    :param outputs: model outputs
    :param prune_reshape: whether to prune the operators useless during inference
        (i.e. the sub-graphs that only compute Reshape's target-shape input)
    :return: opr list with some correct execution order
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        # (Kahn's algorithm: repeatedly pop an opr whose in-degree is 0)
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1

            # skip const value generation operator
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)

            # popped opr no longer feeds its receivers: decrease their in-degree
            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can prune
    # some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
            # cur_opr is useless iff, after removing the edge to post_opr, none
            # of its (non-workspace) outputs has any remaining reader
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                # a pruned opr no longer reads its inputs, so its producers
                # may become useless too: recurse upwards
                for inp in cur_opr.inputs:
                    iterative_pruning(inp.owner_opr, cur_opr, marked_opr_ids)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner_opr for var in reshape_vars]

        marked_opr_ids = []
        for reshape_opr in reshape_oprs:
            # inputs[1] is the dest_shape input of the Reshape opr
            iterative_pruning(
                reshape_opr.inputs[1].owner_opr, reshape_opr, marked_opr_ids
            )

        # filter out all marked oprs
        return [opr for opr in oprs_seq if opr.id not in marked_opr_ids]

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape:
        # copy the inner lists too: pruning mutates them in place, and a plain
        # shallow .copy() would still share them with the traversal result
        oprs_seq = prune_reshape_oprs(
            outputs, oprs_seq, {k: list(v) for k, v in var2oprs.items()}
        )
    return oprs_seq


def replace_vars(dst, varmap):
"""replace vars in the graph



+ 4
- 0
python_module/src/swig/comp_graph_tools.i View File

@@ -10,6 +10,10 @@
return var.node()->owner_opr()->dyn_typeinfo()->name;
}

// Return the dynamic (runtime) type name of the underlying operator node,
// e.g. "ImmutableTensor" or "Reshape"; exposed to Python via SWIG.
std::string _get_opr_type(Operator opr) {
return opr.node()->dyn_typeinfo()->name;
}

SymbolVarArray _replace_vars(const SymbolVarArray& repl_src,
const SymbolVarArray& repl_dst,
const SymbolVarArray& vars) {


+ 44
- 0
python_module/test/unit/jit/test_jit.py View File

@@ -15,6 +15,7 @@ import pytest

import megengine as mge
import megengine._internal as mgb
import megengine.functional as F
import megengine.module as M
from megengine import functional as F
from megengine import jit, tensor
@@ -148,6 +149,49 @@ def test_dump_volatile():
assert mgb.cgtools.get_type(mgb.cgtools.get_inputs(out)[1]) == "SharedDeviceTensor"


def test_graph_traversal():
    """Dump a tiny conv net and check graph_traversal's var -> (opr, slot) map."""
    conv = M.Conv2d(3, 4, 3, 1, 1, groups=1, bias=False)
    conv.eval()

    @jit.trace(symbolic=True)
    def forward(data):
        return conv(data)

    inp = np.random.random([1, 3, 224, 224]).astype(np.float32)
    forward.trace(inp)

    with mkstemp() as path:
        forward.dump(path)
        *_, outputs = mgb.load_comp_graph_from_file(path)

    # only the var map and the var -> consumer-oprs map are of interest here
    _, map_vars, var2oprs, *_ = mgb.cgtools.graph_traversal(outputs)
    first_var = map_vars[1]
    _, slot = var2oprs[first_var.id][0]

    # the var feeds its consumer at input slot 0
    assert slot == 0


def test_network_visitor():
    """Check that prune_reshape drops exactly the shape-computing sub-graph."""

    @jit.trace(symbolic=True)
    def reshape_fn(x):
        # this line will produce shape_of, subtensor and concat op
        # after pruning, they will be deleted
        target_shape = (x.shape[0], -1)

        return x.reshape(*target_shape)

    reshape_fn.trace(tensor(np.random.random([2, 3, 4, 5]).astype(np.float32)))

    with mkstemp() as path:
        reshape_fn.dump(path)
        *_, outputs = mgb.load_comp_graph_from_file(path)

    full_seq = mgb.cgtools.get_oprs_seq(outputs)
    pruned_seq = mgb.cgtools.get_oprs_seq(outputs, prune_reshape=True)

    # the three shape-computing oprs are gone after pruning
    assert len(full_seq) == len(pruned_seq) + 3


def test_shape_tracing():
for symbolic in [False, True]:



Loading…
Cancel
Save