GitOrigin-RevId: 47a26dd6dd
release-1.5
@@ -893,6 +893,10 @@ class trace: | |||||
if isinstance(file, str): | if isinstance(file, str): | ||||
permission = "wb" if append == False else "ab" | permission = "wb" if append == False else "ab" | ||||
file = open(file, permission) | file = open(file, permission) | ||||
if keep_opr_priority: | |||||
graph._set_priority_to_id(dest_vars) | |||||
dump_content, dump_info = G.dump_graph( | dump_content, dump_info = G.dump_graph( | ||||
dest_vars, | dest_vars, | ||||
keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
@@ -6,6 +6,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | import collections | ||||
import heapq | |||||
from collections import OrderedDict | from collections import OrderedDict | ||||
from typing import Dict, List, Tuple, Union | from typing import Dict, List, Tuple, Union | ||||
@@ -88,6 +89,41 @@ def get_opr_type(opr: _OpNode) -> str: | |||||
return opr.type | return opr.type | ||||
class _OprStableOrderHeapq: | |||||
"""heap implementation for operator comparison in stable order""" | |||||
_list = None | |||||
_extra_priority = None | |||||
_used_id_name_pairs = None | |||||
def __init__(self, extra_priority): | |||||
assert isinstance(extra_priority, collections.Callable) | |||||
self._list = [] | |||||
self._extra_priority = extra_priority | |||||
self._used_id_name_pairs = {} | |||||
def pop_min(self): | |||||
return heapq.heappop(self._list)[-1] | |||||
def add(self, opr): | |||||
# named as add to mimic set() interface | |||||
id_ = opr.id | |||||
name = opr.name | |||||
other = self._used_id_name_pairs.setdefault((id_, name), opr) | |||||
if other is not opr: | |||||
raise RuntimeError( | |||||
"duplicated (id, name) pair: opr0={} opr1={}".format(other, opr) | |||||
) | |||||
item = self._extra_priority(opr) + (id_, name, opr) | |||||
heapq.heappush(self._list, item) | |||||
def __bool__(self): | |||||
return bool(self._list) | |||||
def graph_traversal(outputs: _VarNode): | def graph_traversal(outputs: _VarNode): | ||||
""" | """ | ||||
Helper function to traverse the computing graph and return enough useful information. | Helper function to traverse the computing graph and return enough useful information. | ||||
@@ -110,12 +146,13 @@ def graph_traversal(outputs: _VarNode): | |||||
var2oprs = collections.defaultdict(list) | var2oprs = collections.defaultdict(list) | ||||
opr2receivers = collections.defaultdict(list) | opr2receivers = collections.defaultdict(list) | ||||
queue = list(set(map(lambda x: x.owner, outputs))) | |||||
queue = [] | |||||
[queue.append(o) for o in [x.owner for x in outputs] if o not in queue] | |||||
visited = set(map(lambda x: x.id, queue)) | visited = set(map(lambda x: x.id, queue)) | ||||
# iterate through whole comp_graph, fill in meta information | # iterate through whole comp_graph, fill in meta information | ||||
indegree2opr = collections.defaultdict(set) | indegree2opr = collections.defaultdict(set) | ||||
indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,)) | |||||
opr2indegree = {} | opr2indegree = {} | ||||
idx = 0 | idx = 0 | ||||
@@ -138,8 +175,8 @@ def graph_traversal(outputs: _VarNode): | |||||
indegree += 1 | indegree += 1 | ||||
opr2receivers[pre_opr.id].append(cur_opr.id) | opr2receivers[pre_opr.id].append(cur_opr.id) | ||||
indegree2opr[indegree].add(cur_opr.id) | |||||
opr = cur_opr if indegree == 0 else cur_opr.id | |||||
indegree2opr[indegree].add(opr) | |||||
opr2indegree[cur_opr.id] = indegree | opr2indegree[cur_opr.id] = indegree | ||||
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | ||||
@@ -162,8 +199,8 @@ def get_oprs_seq( | |||||
oprs_seq = [] | oprs_seq = [] | ||||
nr_remain = len(map_oprs) | nr_remain = len(map_oprs) | ||||
while indegree2opr[0]: | while indegree2opr[0]: | ||||
opr_id = indegree2opr[0].pop() | |||||
opr = map_oprs[opr_id] | |||||
opr = indegree2opr[0].pop_min() | |||||
opr_id = opr.id | |||||
nr_remain -= 1 | nr_remain -= 1 | ||||
if opr.type != "ImmutableTensor" or not prune_immtensor: | if opr.type != "ImmutableTensor" or not prune_immtensor: | ||||
oprs_seq.append(opr) | oprs_seq.append(opr) | ||||
@@ -173,7 +210,10 @@ def get_oprs_seq( | |||||
indegree2opr[indegree].remove(post_id) | indegree2opr[indegree].remove(post_id) | ||||
indegree -= 1 | indegree -= 1 | ||||
indegree2opr[indegree].add(post_id) | |||||
if indegree == 0: | |||||
indegree2opr[indegree].add(map_oprs[post_id]) | |||||
else: | |||||
indegree2opr[indegree].add(post_id) | |||||
opr2indegree[post_id] = indegree | opr2indegree[post_id] = indegree | ||||
assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format( | ||||
@@ -213,10 +253,34 @@ def get_oprs_seq( | |||||
# filter out all marked oprs | # filter out all marked oprs | ||||
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq)) | ||||
# adjust the order of oprs, let param/data privoder oprs close to the oprs which use them as inputs. | |||||
def reorder_oprs_seq(oprs): | |||||
rst = [] | |||||
param_or_data_provider_oprs = [] | |||||
other_oprs = [] | |||||
for o in oprs: | |||||
if o.type in ["ImmutableTensor", "Host2DeviceCopy"]: | |||||
param_or_data_provider_oprs.append(o) | |||||
else: | |||||
other_oprs.append(o) | |||||
for o in other_oprs: | |||||
for inp in o.inputs: | |||||
if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]: | |||||
if inp.owner in param_or_data_provider_oprs: | |||||
rst.append(inp.owner) | |||||
param_or_data_provider_oprs.remove(inp.owner) | |||||
rst.append(o) | |||||
rst = rst + param_or_data_provider_oprs | |||||
assert len(rst) == len(oprs) | |||||
return rst | |||||
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal( | ||||
outputs | outputs | ||||
) | ) | ||||
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree) | ||||
oprs_seq = reorder_oprs_seq(oprs_seq) | |||||
if prune_reshape is True: | if prune_reshape is True: | ||||
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy()) | ||||
return oprs_seq | return oprs_seq | ||||
@@ -241,6 +241,7 @@ class Network: | |||||
if optimize_for_inference: | if optimize_for_inference: | ||||
metadata.optimize_options = optimize_options | metadata.optimize_options = optimize_options | ||||
G.set_priority_to_id([o._node if isinstance(o, G.VarNode) else o for o in out]) | |||||
dump_content, _ = G.dump_graph( | dump_content, _ = G.dump_graph( | ||||
out, | out, | ||||
keep_var_name=keep_var_name, | keep_var_name=keep_var_name, | ||||
@@ -353,7 +354,7 @@ class Network: | |||||
) | ) | ||||
shp[0] = batchsize | shp[0] = batchsize | ||||
i.shape = tuple(shp) | i.shape = tuple(shp) | ||||
self._compile() | |||||
assert prev_batchsize is not None, "no data provider found" | assert prev_batchsize is not None, "no data provider found" | ||||
assert not blacklist, "unused items in blacklist: {}".format(blacklist) | assert not blacklist, "unused items in blacklist: {}".format(blacklist) | ||||
@@ -363,7 +364,6 @@ class Network: | |||||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | ||||
""" | """ | ||||
if not all([var.owner for var in repl_dict.values()]): | if not all([var.owner for var in repl_dict.values()]): | ||||
print(repl_dict.values()) | |||||
self.add_dep_oprs(*list(repl_dict.values())) | self.add_dep_oprs(*list(repl_dict.values())) | ||||
for var in self.all_vars: | for var in self.all_vars: | ||||
if var in repl_dict: | if var in repl_dict: | ||||
@@ -373,6 +373,7 @@ class Network: | |||||
owner.outputs[idx] = var | owner.outputs[idx] = var | ||||
var.__dict__.update(repl_var.__dict__) | var.__dict__.update(repl_var.__dict__) | ||||
var.var = repl_var.var | var.var = repl_var.var | ||||
self._compile() | |||||
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | ||||
""" | """ | ||||
@@ -384,11 +385,11 @@ class Network: | |||||
assert len(opr.outputs) == len( | assert len(opr.outputs) == len( | ||||
repl_dict[opr].outputs | repl_dict[opr].outputs | ||||
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | ), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | ||||
repl_dict[opr].outputs = opr.outputs | |||||
for ind, var in enumerate(opr.outputs): | for ind, var in enumerate(opr.outputs): | ||||
var.owner = repl_dict[opr] | var.owner = repl_dict[opr] | ||||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | ||||
var.var = repl_dict[opr].outputs[ind].var | var.var = repl_dict[opr].outputs[ind].var | ||||
self._compile() | |||||
def get_opr_by_type(self, oprcls, unique=True): | def get_opr_by_type(self, oprcls, unique=True): | ||||
assert issubclass(oprcls, OpNode) | assert issubclass(oprcls, OpNode) | ||||
@@ -90,6 +90,10 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||||
def dtype(self): | def dtype(self): | ||||
return self.var.dtype if self.var else None | return self.var.dtype if self.var else None | ||||
@property | |||||
def ndim(self): | |||||
return super().ndim | |||||
def __bool__(self): | def __bool__(self): | ||||
return False | return False | ||||
@@ -134,7 +138,18 @@ class OpNode(NetworkNode): | |||||
self.outputs = [] | self.outputs = [] | ||||
self.params = {} | self.params = {} | ||||
self._opr = None # mgb opnode | self._opr = None # mgb opnode | ||||
self.id = id(self) | |||||
@property | |||||
def id(self): | |||||
if self._opr is not None: | |||||
return self._opr.id | |||||
return id(self) | |||||
@property | |||||
def priority(self): | |||||
if self._opr is not None: | |||||
return self._opr.priority | |||||
return 0 | |||||
@classmethod | @classmethod | ||||
def load(cls, opr): | def load(cls, opr): | ||||
@@ -144,16 +159,21 @@ class OpNode(NetworkNode): | |||||
obj._opr = opr | obj._opr = opr | ||||
return obj | return obj | ||||
def compile(self, graph=None): | |||||
op = self.opdef(**self.params) | |||||
args = [i.var for i in self.inputs] | |||||
outputs = rt.invoke_op(op, args) | |||||
assert len(outputs) == len(self.outputs) | |||||
self._opr = outputs[0].owner | |||||
for i in range(len(self.outputs)): | |||||
self.outputs[i].var = outputs[i] | |||||
self.outputs[i].var.name = self.outputs[i].name | |||||
assert self.outputs[i].owner is self | |||||
def compile(self): | |||||
if ( | |||||
self._opr is None | |||||
or len(self._opr.inputs) != len(self.inputs) | |||||
or any([i != j.var for i, j in zip(self._opr.inputs, self.inputs)]) | |||||
): | |||||
op = self.opdef(**self.params) | |||||
args = [i.var for i in self.inputs] | |||||
outputs = rt.invoke_op(op, args) | |||||
assert len(outputs) == len(self.outputs) | |||||
self._opr = outputs[0].owner | |||||
for i in range(len(self.outputs)): | |||||
self.outputs[i].var = outputs[i] | |||||
self.outputs[i].var.name = self.outputs[i].name | |||||
assert self.outputs[i].owner is self | |||||
def add_inp_var(self, x): | def add_inp_var(self, x): | ||||
self.inputs.append(x) | self.inputs.append(x) | ||||
@@ -197,11 +217,17 @@ class Host2DeviceCopy(OpNode): | |||||
return self | return self | ||||
def compile(self, graph): | def compile(self, graph): | ||||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||||
self._opr = outputs.owner | |||||
if len(self.outputs) == 0: | |||||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
self.outputs[0].var = outputs | |||||
if ( | |||||
self._opr is None | |||||
or self._opr.outputs[0].comp_node != self.device | |||||
or self._opr.outputs[0].shape != self.shape | |||||
or self._opr.outputs[0].dtype != self.dtype | |||||
): | |||||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||||
self._opr = outputs.owner | |||||
if len(self.outputs) == 0: | |||||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||||
self.outputs[0].var = outputs | |||||
assert self.outputs[0].owner is self | assert self.outputs[0].owner is self | ||||
@@ -192,7 +192,14 @@ void init_graph_rt(py::module m) { | |||||
}) | }) | ||||
.def("__repr__", [](cg::OperatorNodeBase* opr){ | .def("__repr__", [](cg::OperatorNodeBase* opr){ | ||||
return "Opr:" + opr->name(); | return "Opr:" + opr->name(); | ||||
}); | |||||
}) | |||||
.def_property("priority", | |||||
[](cg::OperatorNodeBase* opr) { | |||||
return opr->node_prop().attribute().priority; | |||||
}, | |||||
[](cg::OperatorNodeBase* opr, int priority) { | |||||
opr->node_prop().attribute().priority = priority; | |||||
}); | |||||
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | ||||
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>()) | ||||
@@ -19,6 +19,7 @@ from megengine.core.tensor import megbrain_graph as mgb_graph | |||||
from megengine.core.tensor.megbrain_graph import apply_normal_varnode | from megengine.core.tensor.megbrain_graph import apply_normal_varnode | ||||
from megengine.core.tensor.utils import astensor1d | from megengine.core.tensor.utils import astensor1d | ||||
from megengine.jit import trace | from megengine.jit import trace | ||||
from megengine.utils.network import Network | |||||
def make_dev_tensor(value, dtype=None, device=None): | def make_dev_tensor(value, dtype=None, device=None): | ||||
@@ -143,6 +144,46 @@ def test_get_opr_seq(): | |||||
assert len(seq_2) == 6 | assert len(seq_2) == 6 | ||||
def test_topological_sort(): | |||||
@trace(symbolic=True, capture_as_const=True) | |||||
def func(x, y): | |||||
a = x + y | |||||
a1 = F.relu(a) | |||||
a2 = F.abs(a) | |||||
a3 = F.ceil(a) * 2 | |||||
a4 = F.floor(a) | |||||
r = a1 - a2 | |||||
r1 = a3 / a4 | |||||
return r, r1 | |||||
file = io.BytesIO() | |||||
func(megengine.tensor(1.0), megengine.tensor(2.0)) | |||||
func.dump( | |||||
file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True | |||||
) | |||||
file.seek(0) | |||||
g = Network.load(file) | |||||
oprseq1 = g.all_oprs | |||||
gt = [ | |||||
"Host2DeviceCopy", | |||||
"Host2DeviceCopy", | |||||
"ADD", | |||||
"RELU", | |||||
"ABS", | |||||
"CEIL", | |||||
"ImmutableTensor", | |||||
"MUL", | |||||
"FLOOR", | |||||
"SUB", | |||||
"TRUE_DIV", | |||||
] | |||||
for op, mode in zip(oprseq1, gt): | |||||
if op.type == "Elemwise": | |||||
assert op.params["mode"] == mode | |||||
else: | |||||
assert op.type == mode | |||||
def test_graph_function(): | def test_graph_function(): | ||||
class Net(M.Module): | class Net(M.Module): | ||||
def forward(self, a, b): | def forward(self, a, b): | ||||