GitOrigin-RevId: 614302552c
release-1.5
@@ -519,8 +519,7 @@ def _unwrap(x):
         return type(x)(map(_unwrap, x))
     if isinstance(x, VarNode):
         return x._node
-    else:
-        return x
+    return x
 
 
 def apply_normal_varnode(op: OpDef, *args: VarNode):
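
The hunk above flattens a redundant else branch in the recursive unwrapping helper: containers are mapped element-wise, wrapper nodes expose their underlying node, and anything else falls through. A minimal, self-contained sketch of the same pattern, using stand-in classes rather than MegEngine's actual VarNode:

    class _RawNode:
        pass

    class _Wrapper:
        # stand-in for a VarNode that wraps a lower-level graph node
        def __init__(self, node):
            self._node = node

    def _unwrap(x):
        # recurse into lists/tuples, preserving the container type
        if isinstance(x, (list, tuple)):
            return type(x)(map(_unwrap, x))
        # wrappers expose their underlying node; everything else passes
        # through, which is why the trailing else branch was redundant
        if isinstance(x, _Wrapper):
            return x._node
        return x

    raw = _RawNode()
    assert _unwrap([_Wrapper(raw), 3]) == [raw, 3]
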
@@ -12,14 +12,16 @@ import itertools
 import pickle
 import re
 from collections import OrderedDict
-from typing import Any, Dict, List, Sequence
+from typing import Any, Dict, List, Optional, Sequence
 
+from ..core import _imperative_rt
 from ..core._imperative_rt import ComputingGraph, SerializationMetadata
 from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
 from ..core.tensor import megbrain_graph as G
 from ..logger import get_logger
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
 from .network_node import (
+    ConstOpBase,
     Host2DeviceCopy,
     ImmutableTensor,
     NetworkNode,
@@ -37,8 +39,10 @@ class Network:
         self._orig_inputs = []
         self.output_vars = []  # output var of graph
         self._orig_outputs = []
-        self.all_oprs_map = OrderedDict()
-        self.all_vars_map = OrderedDict()
+        self.all_oprs_map = (
+            OrderedDict()
+        )  # _imperative_rt.graph.OperatorNode.id: OpNode
+        self.all_vars_map = OrderedDict()  # _imperative_rt.graph.VarNode.id: VarNode
         self.graph = ComputingGraph()
         self._metadata = None
@@ -101,7 +105,7 @@ class Network:
         self.all_oprs_map = {}
         self.all_vars_map = {}
         for opr in self.all_oprs:
-            if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)):
+            if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
                 opr.compile(self.graph)
             else:
                 opr.compile()
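
Because the isinstance check now keys on ConstOpBase, every constant-producing op (not only ImmutableTensor) receives the ComputingGraph handle when the network is recompiled, while ordinary ops still compile from their already-resolved inputs. A rough sketch of that dispatch with hypothetical stand-in classes, not MegEngine's real types:

    class FakeGraph:
        pass

    class SourceOp:
        # stands in for ConstOpBase / Host2DeviceCopy: needs the graph handle
        def compile(self, graph):
            self.graph = graph

    class ComputeOp:
        # stands in for any other OpNode: rebuilt from its input vars only
        def compile(self):
            self.compiled = True

    graph = FakeGraph()
    for opr in [SourceOp(), ComputeOp()]:
        if isinstance(opr, SourceOp):
            opr.compile(graph)
        else:
            opr.compile()
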
@@ -295,6 +299,9 @@ class Network:
     def add_dep_oprs(self, *vars):
         if len(vars) == 0:
             vars = self.output_vars
+        assert all(
+            isinstance(var, VarNode) for var in vars
+        ), "Only support add VarNode"
         q = list(vars)
         while len(q) > 0:
             cur = q.pop(0)
@@ -368,11 +375,14 @@ class Network:
         for var in self.all_vars:
             if var in repl_dict:
                 repl_var = repl_dict[var]
-                owner = repl_var.owner
-                idx = owner.outputs.index(repl_var)
-                owner.outputs[idx] = var
-                var.__dict__.update(repl_var.__dict__)
-                var.var = repl_var.var
+                if repl_var is var:
+                    continue
+                for opnode in var.users:
+                    assert var in opnode.inputs
+                    opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
+                    if opnode not in repl_var.users:
+                        repl_var.users.append(opnode)
+                var.users.clear()
         self._compile()
 
     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
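
The new body drops the old trick of mutating the replacement var's owner and copying __dict__, and instead rewires edges explicitly: every consumer recorded in var.users has its inputs swapped to repl_var, and both users lists are kept consistent. A self-contained sketch of that rewiring with plain stand-in classes, not MegEngine's VarNode/OpNode:

    class Var:
        def __init__(self, name):
            self.name = name
            self.users = []  # ops that consume this var

    class Op:
        def __init__(self, inputs):
            self.inputs = list(inputs)
            for v in inputs:
                v.users.append(self)

    def replace_vars(all_vars, repl_dict):
        for var in all_vars:
            repl_var = repl_dict.get(var)
            if repl_var is None or repl_var is var:
                continue
            for opnode in var.users:
                # point every consumer of `var` at `repl_var` instead
                opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
                if opnode not in repl_var.users:
                    repl_var.users.append(opnode)
            var.users.clear()

    a, b = Var("a"), Var("b")
    consumer = Op([a])
    replace_vars([a, b], {a: b})
    assert consumer.inputs == [b] and consumer in b.users and a.users == []
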
@@ -473,14 +483,20 @@ class Network:
     def all_oprs_dict(self):
         return self.opr_filter.as_dict()
 
-    # used for loading and building graph
-    def _add_opr(self, opr):
+    def _add_opr(self, opr) -> Optional[OpNode]:
+        """
+        Used for loading and building graph.
+        """
+        assert isinstance(opr, _imperative_rt.graph.OperatorNode)
         # TODO: use megbrain C++ RTTI to replace type string
         if opr.id not in self.all_oprs_map:
             opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
             self.all_oprs_map[opr.id] = opnode
             for var in opr.inputs:
-                opnode.add_inp_var(self._get_var(var))
+                varnode = self._get_var(var)
+                opnode.add_inp_var(varnode)
+                varnode.users.append(opnode)
             for var in opr.outputs:
                 opnode.add_out_var(self._get_var(var))
             return opnode
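
_add_opr now also records the reverse edge from each input var to the operator that consumes it, which is the bookkeeping replace_vars relies on, while the id-keyed maps keep loaded nodes de-duplicated. A condensed sketch of that idea with stand-in types; the raw ids here are plain integers, not _imperative_rt node ids:

    class Var:
        def __init__(self, vid):
            self.id = vid
            self.users = []  # reverse edges: ops reading this var

    vars_map = {}  # raw id -> Var, analogous to all_vars_map

    def get_var(raw_id):
        # de-duplicate vars through the map, like Network._get_var
        if raw_id not in vars_map:
            vars_map[raw_id] = Var(raw_id)
        return vars_map[raw_id]

    class Op:
        def __init__(self, oid, input_ids):
            self.id = oid
            self.inputs = []
            for raw in input_ids:
                var = get_var(raw)
                self.inputs.append(var)
                var.users.append(self)  # mirrors varnode.users.append(opnode)

    op = Op(0, [1, 2])
    assert all(op in v.users for v in op.inputs)
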
@@ -503,7 +519,10 @@ class Network:
             return None
 
     def _get_var(self, x):
-        # auto convert to VarNode of Network
+        """
+        Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
+        """
+        assert isinstance(x, _imperative_rt.graph.VarNode)
         if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
             self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
         return self.all_vars_map[x.id]
@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
 class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
     def __init__(self, var=None, *, owner_opr=None, name=None):
         SymbolVar.__init__(self, var)
+        self.users = []  # List[OpNode]
         self.owner = owner_opr
         self.name = name
         self.id = id(self)
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
     def compile(self, graph):
         if (
             self._opr is None
+            or self._opr.graph != graph
             or self._opr.outputs[0].comp_node != self.device
             or self._opr.outputs[0].shape != self.shape
             or self._opr.outputs[0].dtype != self.dtype
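
The extra condition discards a cached result when the op is compiled into a different ComputingGraph, not only when its device, shape, or dtype changed. The guard follows a plain memoize-and-invalidate shape, sketched here with a hypothetical placeholder op rather than the real Host2DeviceCopy:

    class CachedInput:
        def __init__(self, shape, dtype):
            self.shape, self.dtype = shape, dtype
            self._key = None  # (graph, shape, dtype) the cached output was built for

        def compile(self, graph):
            key = (graph, self.shape, self.dtype)
            if self._key != key:  # covers the new "graph changed" case too
                self._key = key
                self.placeholder = object()  # stands in for the real device var
            return self.placeholder

    op = CachedInput((2,), "float32")
    g1, g2 = object(), object()
    first = op.compile(g1)
    assert op.compile(g1) is first      # same graph and signature: cache hit
    assert op.compile(g2) is not first  # new graph: cache is rebuilt
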
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
         assert self.outputs[0].owner is self
 
 
-class ImmutableTensor(OpNode):
-    type = "ImmutableTensor"
+class ConstOpBase(OpNode):
+    type = "ConstOpBase"
 
     def __init__(self, data=None, name=None, device=None, graph=None):
+        assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
         super().__init__()
         self.name = name
         self.outputs = []
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
         return self._opr.outputs[0].dtype if self._opr else None
 
     def numpy(self):
-        return self._opr.outputs[0].value if self._opr else None
+        return self.outputs[0].numpy()
 
     def set_value(self, data, device=None):
         assert self.graph is not None
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
             data = data.astype(np.float32)
         elif data.dtype == np.int64:
             data = data.astype(np.int32)
-        varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
+        varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
             self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = varnode
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
         self.outputs[0].var.name = self.name
 
 
+class ImmutableTensor(ConstOpBase):
+    type = "ImmutableTensor"
+    rt_fun = rt.make_const
+
+
+class SharedDeviceTensor(ConstOpBase):
+    type = "SharedDeviceTensor"
+    rt_fun = rt.make_shared
+
+
 class ReadOnlyOpNode(OpNode):
     @classmethod
     def load(cls, opr):
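
After this refactor, ImmutableTensor and SharedDeviceTensor share all of their behaviour through ConstOpBase and differ only in the low-level constructor selected via the rt_fun class attribute (rt.make_const vs rt.make_shared). A stripped-down sketch of the pattern; the builder functions below are placeholders, not MegEngine's actual rt signatures:

    def make_const(graph, data):  # placeholder for rt.make_const
        return ("const", graph, data)

    def make_shared(graph, data):  # placeholder for rt.make_shared
        return ("shared", graph, data)

    class ConstBase:
        rt_fun = None  # each subclass selects its low-level builder

        def __init__(self, data):
            assert type(self) is not ConstBase, "ConstBase cannot be instantiated"
            self.data = data

        def build(self, graph):
            # shared logic lives in the base; only the constructor differs
            return type(self).rt_fun(graph, self.data)

    class Immutable(ConstBase):
        rt_fun = staticmethod(make_const)

    class SharedDevice(ConstBase):
        rt_fun = staticmethod(make_shared)

    assert Immutable(1).build("g")[0] == "const"
    assert SharedDevice(1).build("g")[0] == "shared"
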
@@ -130,6 +130,52 @@ def test_replace_opr():
     np.testing.assert_equal(out["o"], [0, 0])
 
 
+def test_splice_network():
+    x = F.ones((2,))
+    y = F.ones((2,))
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun1(a, b):
+        return (a + b) * 2
+
+    @trace(symbolic=True, capture_as_const=True)
+    def fun2(a):
+        return a * 2 - 1
+
+    model = io.BytesIO()
+    fun1(x, y)
+    fun2(x)
+    fun1.dump(
+        model,
+        arg_names=["net1_i0", "net1_i1"],
+        output_names=["net1_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net1 = Net.load(model)
+    model.seek(0)
+    fun2.dump(
+        model,
+        arg_names=["net2_i0"],
+        output_names=["net2_o0"],
+        optimize_for_inference=False,
+    )
+    model.seek(0)
+    net2 = Net.load(model)
+
+    net1.add_output(*net2.output_vars)
+    var = net1.var_filter.name("net1_i0").as_unique()
+    repl_var = net2.var_filter.name("net2_o0").as_unique()
+    net1.replace_vars({var: repl_var})
+    assert "net1_i0" not in [var.name for var in net1.all_vars]
+    assert "net2_i0" in [var.name for var in net1.all_vars]
+
+    model.seek(0)
+    net1.dump(model, keep_var_name=2, optimize_for_inference=False)
+    model.seek(0)
+    net = Net.load(model)
+
+    assert "net1_i0" not in [var.name for var in net.all_vars]
+    assert "net2_i0" in [var.name for var in net.all_vars]
+
+
 def test_modify_params():
     a = Tensor([1, 2])