GitOrigin-RevId: 47a26dd6dd
release-1.5
@@ -893,6 +893,10 @@ class trace: | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
file = open(file, permission) | |||
if keep_opr_priority: | |||
graph._set_priority_to_id(dest_vars) | |||
dump_content, dump_info = G.dump_graph( | |||
dest_vars, | |||
keep_var_name=keep_var_name, | |||
@@ -6,6 +6,7 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections | |||
import heapq | |||
from collections import OrderedDict | |||
from typing import Dict, List, Tuple, Union | |||
@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: | |||
return opr.type | |||
class _OprStableOrderHeapq: | |||
"""heap implementation for operator comparison in stable order""" | |||
_list = None | |||
_extra_priority = None | |||
_used_id_name_pairs = None | |||
def __init__(self, extra_priority): | |||
assert isinstance(extra_priority, collections.Callable) | |||
self._list = [] | |||
self._extra_priority = extra_priority | |||
self._used_id_name_pairs = {} | |||
def pop_min(self): | |||
return heapq.heappop(self._list)[-1] | |||
def add(self, opr): | |||
# named as add to mimic set() interface | |||
id_ = opr.id | |||
name = opr.name | |||
other = self._used_id_name_pairs.setdefault((id_, name), opr) | |||
if other is not opr: | |||
raise RuntimeError( | |||
"duplicated (id, name) pair: opr0={} opr1={}".format(other, opr) | |||
) | |||
item = self._extra_priority(opr) + (id_, name, opr) | |||
heapq.heappush(self._list, item) | |||
def __bool__(self): | |||
return bool(self._list) | |||
def graph_traversal(outputs: _VarNode): | |||
""" | |||
Helper function to traverse the computing graph and return enough useful information. | |||
@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): | |||
var2oprs = collections.defaultdict(list) | |||
opr2receivers = collections.defaultdict(list) | |||
queue = list(set(map(lambda x: x.owner, outputs))) | |||
queue = [] | |||
[queue.append(o) for o in [x.owner for x in outputs] if o not in queue] | |||
visited = set(map(lambda x: x.id, queue)) | |||
# iterate through whole comp_graph, fill in meta information | |||
indegree2opr = collections.defaultdict(set) | |||
indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,)) | |||
opr2indegree = {} | |||
idx = 0 | |||
@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): | |||
indegree += 1 | |||
opr2receivers[pre_opr.id].append(cur_opr.id) | |||
indegree2opr[indegree].add(cur_opr.id) | |||
opr = cur_opr if indegree == 0 else cur_opr.id | |||
indegree2opr[indegree].add(opr) | |||
opr2indegree[cur_opr.id] = indegree | |||
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | |||
@@ -162,8 +199,8 @@ def get_oprs_seq( | |||
oprs_seq = [] | |||
nr_remain = len(map_oprs) | |||
while indegree2opr[0]: | |||
opr_id = indegree2opr[0].pop() | |||
opr = map_oprs[opr_id] | |||
opr = indegree2opr[0].pop_min() | |||
opr_id = opr.id | |||
nr_remain -= 1 | |||
if opr.type != "ImmutableTensor" or not prune_immtensor: | |||
oprs_seq.append(opr) | |||
@@ -173,7 +210,10 @@ def get_oprs_seq( | |||
indegree2opr[indegree].remove(post_id) | |||
indegree -= 1 | |||
indegree2opr[indegree].add(post_id) | |||
if indegree == 0: | |||
indegree2opr[indegree].add(map_oprs[post_id]) | |||
else: | |||
indegree2opr[indegree].add(post_id) | |||
opr2indegree[post_id] = indegree | |||
assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | |||
@@ -213,10 +253,34 @@ def get_oprs_seq( | |||
# filter out all marked oprs | |||
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | |||
# adjust the order of oprs, let param/data provider oprs sit close to the
# oprs which use them as inputs.
def reorder_oprs_seq(oprs):
    """Reorder *oprs* so each ImmutableTensor / Host2DeviceCopy opr is
    emitted immediately before its first consumer.

    Providers that are never consumed keep their original relative order
    at the end; the relative order of all non-provider oprs is preserved.
    """
    rst = []
    providers = []  # param/data provider oprs, in original order
    pending = {}    # id(opr) -> opr: providers not yet emitted (O(1) lookup)
    others = []
    for o in oprs:
        if o.type in ("ImmutableTensor", "Host2DeviceCopy"):
            providers.append(o)
            pending[id(o)] = o
        else:
            others.append(o)
    for o in others:
        for inp in o.inputs:
            owner = inp.owner
            # emit a provider right before its first consumer; the dict
            # replaces the original O(n) list membership test + remove()
            if id(owner) in pending:
                rst.append(owner)
                del pending[id(owner)]
        rst.append(o)
    # unconsumed providers go last, keeping their original order
    rst.extend(o for o in providers if id(o) in pending)
    assert len(rst) == len(oprs)
    return rst
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | |||
outputs | |||
) | |||
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | |||
oprs_seq = reorder_oprs_seq(oprs_seq) | |||
if prune_reshape is True: | |||
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | |||
return oprs_seq | |||
@@ -241,6 +241,7 @@ class Network: | |||
if optimize_for_inference: | |||
metadata.optimize_options = optimize_options | |||
G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out]) | |||
dump_content, _ = G.dump_graph( | |||
out, | |||
keep_var_name=keep_var_name, | |||
@@ -353,7 +354,7 @@ class Network: | |||
) | |||
shp[0] = batchsize | |||
i.shape = tuple(shp) | |||
self._compile() | |||
assert prev_batchsize is not None, "no data provider found" | |||
assert not blacklist, "unused items in blacklist: {}".format(blacklist) | |||
@@ -363,7 +364,6 @@ class Network: | |||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
""" | |||
if not all([var.owner for var in repl_dict.values()]): | |||
print(repl_dict.values()) | |||
self.add_dep_oprs(*list(repl_dict.values())) | |||
for var in self.all_vars: | |||
if var in repl_dict: | |||
@@ -373,6 +373,7 @@ class Network: | |||
owner.outputs[idx] = var | |||
var.__dict__.update(repl_var.__dict__) | |||
var.var = repl_var.var | |||
self._compile() | |||
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
""" | |||
@@ -384,11 +385,11 @@ class Network: | |||
assert len(opr.outputs) == len( | |||
repl_dict[opr].outputs | |||
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | |||
repl_dict[opr].outputs = opr.outputs | |||
for ind, var in enumerate(opr.outputs): | |||
var.owner = repl_dict[opr] | |||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
var.var = repl_dict[opr].outputs[ind].var | |||
self._compile() | |||
def get_opr_by_type(self, oprcls, unique=True): | |||
assert issubclass(oprcls, OpNode) | |||
@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
def dtype(self):
    # Dtype of the backing runtime var; None while the var is unbound.
    return self.var.dtype if self.var else None
@property
def ndim(self):
    # Re-expose the inherited ndim as an explicit read-only property.
    # NOTE(review): presumably needed so attribute lookup resolves through
    # the property protocol despite the metaclass/mixin — confirm intent.
    return super().ndim
def __bool__(self):
    # Always falsy. NOTE(review): presumably forbids implicit
    # truth-testing of symbolic vars — confirm against callers.
    return False
@@ -134,7 +138,18 @@ class OpNode(NetworkNode): | |||
self.outputs = [] | |||
self.params = {} | |||
self._opr = None # mgb opnode | |||
self.id = id(self) | |||
@property
def id(self):
    """Stable identifier for this node: the underlying mgb operator's id
    once compiled, otherwise this Python object's identity."""
    if self._opr is None:
        return id(self)
    return self._opr.id
@property
def priority(self):
    """Execution priority: mirrors the compiled mgb operator's priority,
    defaulting to 0 before compilation."""
    if self._opr is None:
        return 0
    return self._opr.priority
@classmethod | |||
def load(cls, opr): | |||
@@ -144,16 +159,21 @@ class OpNode(NetworkNode): | |||
obj._opr = opr | |||
return obj | |||
def compile(self, graph=None): | |||
op = self.opdef(**self.params) | |||
args = [i.var for i in self.inputs] | |||
outputs = rt.invoke_op(op, args) | |||
assert len(outputs) == len(self.outputs) | |||
self._opr = outputs[0].owner | |||
for i in range(len(self.outputs)): | |||
self.outputs[i].var = outputs[i] | |||
self.outputs[i].var.name = self.outputs[i].name | |||
assert self.outputs[i].owner is self | |||
def compile(self):
    """(Re)build the underlying mgb operator for this node.

    Rebuilding is skipped when the cached ``self._opr`` is still valid,
    i.e. it exists, has the same number of inputs, and every runtime
    input var matches this node's inputs. Otherwise the op is re-created
    via ``rt.invoke_op`` and each output VarNode is re-bound to the
    freshly produced var.
    """
    if (
        self._opr is None
        or len(self._opr.inputs) != len(self.inputs)
        or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)])
    ):
        op = self.opdef(**self.params)
        args = [i.var for i in self.inputs]
        outputs = rt.invoke_op(op, args)
        assert len(outputs) == len(self.outputs)
        self._opr = outputs[0].owner
        for i in range(len(self.outputs)):
            # re-bind each output VarNode to the new runtime var and
            # restore its recorded name
            self.outputs[i].var = outputs[i]
            self.outputs[i].var.name = self.outputs[i].name
            assert self.outputs[i].owner is self
def add_inp_var(self, x): | |||
self.inputs.append(x) | |||
@@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode): | |||
return self | |||
def compile(self, graph): | |||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
self._opr = outputs.owner | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
self.outputs[0].var = outputs | |||
if ( | |||
self._opr is None | |||
or self._opr.outputs[0].comp_node != self.device | |||
or self._opr.outputs[0].shape != self.shape | |||
or self._opr.outputs[0].dtype != self.dtype | |||
): | |||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
self._opr = outputs.owner | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
self.outputs[0].var = outputs | |||
assert self.outputs[0].owner is self | |||
@@ -192,7 +192,14 @@ void init_graph_rt(py::module m) { | |||
}) | |||
.def("__repr__", [](cg::OperatorNodeBase* opr){ | |||
return "Opr:" + opr->name(); | |||
}); | |||
}) | |||
.def_property("priority", | |||
[](cg::OperatorNodeBase* opr) { | |||
return opr->node_prop().attribute().priority; | |||
}, | |||
[](cg::OperatorNodeBase* opr, int priority) { | |||
opr->node_prop().attribute().priority = priority; | |||
}); | |||
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | |||
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | |||
@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph | |||
from megengine.core.tensor.megbrain_graph import apply_normal_varnode | |||
from megengine.core.tensor.utils import astensor1d | |||
from megengine.jit import trace | |||
from megengine.utils.network import Network | |||
def make_dev_tensor(value, dtype=None, device=None): | |||
@@ -143,6 +144,46 @@ def test_get_opr_seq(): | |||
assert len(seq_2) == 6 | |||
def test_topological_sort():
    # Trace a small graph with several valid topological orders, dump it
    # with operator names/priorities preserved, and check that reloading
    # through Network yields one specific, stable operator sequence.
    @trace(symbolic=True, capture_as_const=True)
    def func(x, y):
        a = x + y
        a1 = F.relu(a)
        a2 = F.abs(a)
        a3 = F.ceil(a) * 2
        a4 = F.floor(a)
        r = a1 - a2
        r1 = a3 / a4
        return r, r1

    file = io.BytesIO()
    func(megengine.tensor(1.0), megengine.tensor(2.0))
    # keep_opr_priority=True preserves the original execution priorities
    # in the dump, which makes the reloaded order deterministic.
    func.dump(
        file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
    )
    file.seek(0)
    g = Network.load(file)
    oprseq1 = g.all_oprs
    # Expected sequence: Elemwise oprs are identified by their mode name,
    # all other oprs by their type string.
    gt = [
        "Host2DeviceCopy",
        "Host2DeviceCopy",
        "ADD",
        "RELU",
        "ABS",
        "CEIL",
        "ImmutableTensor",
        "MUL",
        "FLOOR",
        "SUB",
        "TRUE_DIV",
    ]
    for op, mode in zip(oprseq1, gt):
        if op.type == "Elemwise":
            assert op.params["mode"] == mode
        else:
            assert op.type == mode
def test_graph_function(): | |||
class Net(M.Module): | |||
def forward(self, a, b): | |||