
fix(mge/utils): fix toposort to get definition order

GitOrigin-RevId: 47a26dd6dd
release-1.5
Megvii Engine Team huangxinda 4 years ago
commit 2d42455fa8
6 changed files with 170 additions and 27 deletions
  1. imperative/python/megengine/jit/tracing.py  (+4 -0)
  2. imperative/python/megengine/utils/comp_graph_tools.py  (+71 -7)
  3. imperative/python/megengine/utils/network.py  (+4 -3)
  4. imperative/python/megengine/utils/network_node.py  (+42 -16)
  5. imperative/python/src/graph_rt.cpp  (+8 -1)
  6. imperative/python/test/unit/utils/test_cgtools.py  (+41 -0)

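At its core, the fix replaces the plain set() that used to hold ready (zero-indegree) operators during topological sort with a min-heap keyed by (priority, id, name), so operators that become ready at the same time are popped in definition order instead of arbitrary set order. A minimal sketch of that idea, using toy stand-in objects rather than the real graph node types (an illustration only, not the library code):

import heapq
from types import SimpleNamespace

# Toy "operators": id follows definition order; every priority is equal here.
oprs = [SimpleNamespace(priority=0, id=i, name=n) for i, n in enumerate("abcd")]

heap = []
for op in reversed(oprs):  # push in an arbitrary (here: reversed) order
    heapq.heappush(heap, ((op.priority, op.id, op.name), op))

while heap:
    _, op = heapq.heappop(heap)
    print(op.name)  # prints a, b, c, d -- definition order, regardless of push order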
imperative/python/megengine/jit/tracing.py  (+4 -0)

@@ -893,6 +893,10 @@ class trace:
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
             file = open(file, permission)
+
+        if keep_opr_priority:
+            graph._set_priority_to_id(dest_vars)
+
         dump_content, dump_info = G.dump_graph(
             dest_vars,
             keep_var_name=keep_var_name,

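The hunk above only takes effect when keep_opr_priority=True is passed to trace.dump: operator priorities are then overwritten with their ids before serialization, so a later Network.load observes the original definition order. A minimal usage sketch, mirroring the new unit test at the bottom of this diff:

import io

import megengine
import megengine.functional as F
from megengine.jit import trace


@trace(symbolic=True, capture_as_const=True)
def func(x, y):
    return F.relu(x + y)


file = io.BytesIO()
func(megengine.tensor(1.0), megengine.tensor(2.0))  # run once so the graph is captured
func.dump(
    file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
)
file.seek(0)  # the buffer can now be handed to Network.load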

imperative/python/megengine/utils/comp_graph_tools.py  (+71 -7)

@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import collections
+import heapq
 from collections import OrderedDict
 from typing import Dict, List, Tuple, Union


@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str:
     return opr.type


+class _OprStableOrderHeapq:
+    """heap implementation for operator comparison in stable order"""
+
+    _list = None
+    _extra_priority = None
+    _used_id_name_pairs = None
+
+    def __init__(self, extra_priority):
+        assert isinstance(extra_priority, collections.Callable)
+        self._list = []
+        self._extra_priority = extra_priority
+        self._used_id_name_pairs = {}
+
+    def pop_min(self):
+        return heapq.heappop(self._list)[-1]
+
+    def add(self, opr):
+        # named as add to mimic set() interface
+
+        id_ = opr.id
+        name = opr.name
+
+        other = self._used_id_name_pairs.setdefault((id_, name), opr)
+        if other is not opr:
+            raise RuntimeError(
+                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
+            )
+
+        item = self._extra_priority(opr) + (id_, name, opr)
+        heapq.heappush(self._list, item)
+
+    def __bool__(self):
+        return bool(self._list)
+
+
 def graph_traversal(outputs: _VarNode):
     """
     Helper function to traverse the computing graph and return enough useful information.
@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode):


     var2oprs = collections.defaultdict(list)
     opr2receivers = collections.defaultdict(list)
-    queue = list(set(map(lambda x: x.owner, outputs)))
+    queue = []
+    [queue.append(o) for o in [x.owner for x in outputs] if o not in queue]
     visited = set(map(lambda x: x.id, queue))

     # iterate through whole comp_graph, fill in meta information
     indegree2opr = collections.defaultdict(set)
+    indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
     opr2indegree = {}

     idx = 0
@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode):


             indegree += 1
             opr2receivers[pre_opr.id].append(cur_opr.id)
-        indegree2opr[indegree].add(cur_opr.id)
+        opr = cur_opr if indegree == 0 else cur_opr.id
+        indegree2opr[indegree].add(opr)
         opr2indegree[cur_opr.id] = indegree

     return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
@@ -162,8 +199,8 @@ def get_oprs_seq(
         oprs_seq = []
         nr_remain = len(map_oprs)
         while indegree2opr[0]:
-            opr_id = indegree2opr[0].pop()
-            opr = map_oprs[opr_id]
+            opr = indegree2opr[0].pop_min()
+            opr_id = opr.id
             nr_remain -= 1
             if opr.type != "ImmutableTensor" or not prune_immtensor:
                 oprs_seq.append(opr)
@@ -173,7 +210,10 @@ def get_oprs_seq(
                 indegree2opr[indegree].remove(post_id)

                 indegree -= 1
-                indegree2opr[indegree].add(post_id)
+                if indegree == 0:
+                    indegree2opr[indegree].add(map_oprs[post_id])
+                else:
+                    indegree2opr[indegree].add(post_id)
                 opr2indegree[post_id] = indegree

         assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
@@ -213,10 +253,34 @@ def get_oprs_seq(
         # filter out all marked oprs
         return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

+    # adjust the order of oprs, keep param/data provider oprs close to the oprs which use them as inputs.
+    def reorder_oprs_seq(oprs):
+        rst = []
+        param_or_data_provider_oprs = []
+        other_oprs = []
+
+        for o in oprs:
+            if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
+                param_or_data_provider_oprs.append(o)
+            else:
+                other_oprs.append(o)
+
+        for o in other_oprs:
+            for inp in o.inputs:
+                if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
+                    if inp.owner in param_or_data_provider_oprs:
+                        rst.append(inp.owner)
+                        param_or_data_provider_oprs.remove(inp.owner)
+            rst.append(o)
+        rst = rst + param_or_data_provider_oprs
+        assert len(rst) == len(oprs)
+        return rst
+
     map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
         outputs
     )
     oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
+    oprs_seq = reorder_oprs_seq(oprs_seq)
     if prune_reshape is True:
         oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
     return oprs_seq

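The reorder_oprs_seq pass added above then pulls parameter/data-provider operators (ImmutableTensor, Host2DeviceCopy) down next to their first consumer instead of leaving them all at the front of the sequence. A toy illustration of that behaviour with stand-in objects (assumed names, not the real graph node types):

from types import SimpleNamespace


def reorder(oprs):  # same logic as the nested reorder_oprs_seq above
    providers = [o for o in oprs if o.type in ("ImmutableTensor", "Host2DeviceCopy")]
    rst = []
    for o in oprs:
        if o in providers:
            continue
        for inp in o.inputs:
            if inp.owner in providers:  # emit the provider right before its first use
                rst.append(inp.owner)
                providers.remove(inp.owner)
        rst.append(o)
    return rst + providers  # providers that are never consumed go at the end


a = SimpleNamespace(name="a", type="Host2DeviceCopy", inputs=[])
c = SimpleNamespace(name="c", type="ImmutableTensor", inputs=[])
neg = SimpleNamespace(name="neg", type="Elemwise", inputs=[SimpleNamespace(owner=a)])
mul = SimpleNamespace(
    name="mul", type="Elemwise", inputs=[SimpleNamespace(owner=neg), SimpleNamespace(owner=c)]
)

print([o.name for o in reorder([a, c, neg, mul])])  # ['a', 'neg', 'c', 'mul']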

imperative/python/megengine/utils/network.py  (+4 -3)

@@ -241,6 +241,7 @@ class Network:
         if optimize_for_inference:
             metadata.optimize_options = optimize_options

+        G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out])
         dump_content, _ = G.dump_graph(
             out,
             keep_var_name=keep_var_name,
@@ -353,7 +354,7 @@ class Network:
                 )
                 shp[0] = batchsize
                 i.shape = tuple(shp)
+        self._compile()
         assert prev_batchsize is not None, "no data provider found"
         assert not blacklist, "unused items in blacklist: {}".format(blacklist)


@@ -363,7 +364,6 @@ class Network:
         :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
         """
         if not all([var.owner for var in repl_dict.values()]):
-            print(repl_dict.values())
             self.add_dep_oprs(*list(repl_dict.values()))
         for var in self.all_vars:
             if var in repl_dict:
@@ -373,6 +373,7 @@ class Network:
                 owner.outputs[idx] = var
                 var.__dict__.update(repl_var.__dict__)
                 var.var = repl_var.var
+        self._compile()

     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
         """
@@ -384,11 +385,11 @@ class Network:
                 assert len(opr.outputs) == len(
                     repl_dict[opr].outputs
                 ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr]))
-                repl_dict[opr].outputs = opr.outputs
                 for ind, var in enumerate(opr.outputs):
                     var.owner = repl_dict[opr]
                     var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
                     var.var = repl_dict[opr].outputs[ind].var
+        self._compile()

     def get_opr_by_type(self, oprcls, unique=True):
         assert issubclass(oprcls, OpNode)


imperative/python/megengine/utils/network_node.py  (+42 -16)

@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     def dtype(self):
         return self.var.dtype if self.var else None

+    @property
+    def ndim(self):
+        return super().ndim
+
     def __bool__(self):
         return False


@@ -134,7 +138,18 @@ class OpNode(NetworkNode):
         self.outputs = []
         self.params = {}
         self._opr = None  # mgb opnode
-        self.id = id(self)
+
+    @property
+    def id(self):
+        if self._opr is not None:
+            return self._opr.id
+        return id(self)
+
+    @property
+    def priority(self):
+        if self._opr is not None:
+            return self._opr.priority
+        return 0

     @classmethod
     def load(cls, opr):
@@ -144,16 +159,21 @@ class OpNode(NetworkNode):
         obj._opr = opr
         return obj

-    def compile(self, graph=None):
-        op = self.opdef(**self.params)
-        args = [i.var for i in self.inputs]
-        outputs = rt.invoke_op(op, args)
-        assert len(outputs) == len(self.outputs)
-        self._opr = outputs[0].owner
-        for i in range(len(self.outputs)):
-            self.outputs[i].var = outputs[i]
-            self.outputs[i].var.name = self.outputs[i].name
-            assert self.outputs[i].owner is self
+    def compile(self):
+        if (
+            self._opr is None
+            or len(self._opr.inputs) != len(self.inputs)
+            or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
+        ):
+            op = self.opdef(**self.params)
+            args = [i.var for i in self.inputs]
+            outputs = rt.invoke_op(op, args)
+            assert len(outputs) == len(self.outputs)
+            self._opr = outputs[0].owner
+            for i in range(len(self.outputs)):
+                self.outputs[i].var = outputs[i]
+                self.outputs[i].var.name = self.outputs[i].name
+                assert self.outputs[i].owner is self

     def add_inp_var(self, x):
         self.inputs.append(x)
@@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode):
         return self

     def compile(self, graph):
-        outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
-        self._opr = outputs.owner
-        if len(self.outputs) == 0:
-            self.outputs.append(VarNode(owner_opr=self, name=self.name))
-        self.outputs[0].var = outputs
+        if (
+            self._opr is None
+            or self._opr.outputs[0].comp_node != self.device
+            or self._opr.outputs[0].shape != self.shape
+            or self._opr.outputs[0].dtype != self.dtype
+        ):
+            outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
+            self._opr = outputs.owner
+            if len(self.outputs) == 0:
+                self.outputs.append(VarNode(owner_opr=self, name=self.name))
+            self.outputs[0].var = outputs
         assert self.outputs[0].owner is self






imperative/python/src/graph_rt.cpp  (+8 -1)

@@ -192,7 +192,14 @@ void init_graph_rt(py::module m) {
         })
         .def("__repr__", [](cg::OperatorNodeBase* opr){
             return "Opr:" + opr->name();
-        });
+        })
+        .def_property("priority",
+            [](cg::OperatorNodeBase* opr) {
+                return opr->node_prop().attribute().priority;
+            },
+            [](cg::OperatorNodeBase* opr, int priority) {
+                opr->node_prop().attribute().priority = priority;
+            });

     py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
         .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())

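The def_property above is what exposes the megbrain-level operator priority to Python; the new OpNode.priority in network_node.py simply proxies it, falling back to 0 before the node has been compiled. A hedged sketch of reading it from a loaded graph (the model file name is illustrative):

from megengine.utils.network import Network

net = Network.load("model.mge")  # e.g. a model dumped with keep_opr_priority=True
for op in net.all_oprs:
    # OpNode.priority reads self._opr.priority (the new pybind property), or 0 if uncompiled
    print(op.name, op.type, op.priority)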

imperative/python/test/unit/utils/test_cgtools.py  (+41 -0)

@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph
 from megengine.core.tensor.megbrain_graph import apply_normal_varnode
 from megengine.core.tensor.utils import astensor1d
 from megengine.jit import trace
+from megengine.utils.network import Network


 def make_dev_tensor(value, dtype=None, device=None):
@@ -143,6 +144,46 @@ def test_get_opr_seq():
     assert len(seq_2) == 6


+def test_topological_sort():
+    @trace(symbolic=True, capture_as_const=True)
+    def func(x, y):
+        a = x + y
+        a1 = F.relu(a)
+        a2 = F.abs(a)
+        a3 = F.ceil(a) * 2
+        a4 = F.floor(a)
+        r = a1 - a2
+        r1 = a3 / a4
+        return r, r1
+
+    file = io.BytesIO()
+    func(megengine.tensor(1.0), megengine.tensor(2.0))
+    func.dump(
+        file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
+    )
+    file.seek(0)
+    g = Network.load(file)
+    oprseq1 = g.all_oprs
+    gt = [
+        "Host2DeviceCopy",
+        "Host2DeviceCopy",
+        "ADD",
+        "RELU",
+        "ABS",
+        "CEIL",
+        "ImmutableTensor",
+        "MUL",
+        "FLOOR",
+        "SUB",
+        "TRUE_DIV",
+    ]
+    for op, mode in zip(oprseq1, gt):
+        if op.type == "Elemwise":
+            assert op.params["mode"] == mode
+        else:
+            assert op.type == mode
+
+
 def test_graph_function():
     class Net(M.Module):
         def forward(self, a, b):

