From dedecf6922db60e69e51655ed82767bfaf44b2af Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 23 Jun 2021 18:12:47 +0800 Subject: [PATCH] fix(imperative/utils): fix logical error of replace var GitOrigin-RevId: 614302552cbeaa66cbc977ee81e5492b6023c1c4 --- .../python/megengine/core/tensor/megbrain_graph.py | 3 +- imperative/python/megengine/utils/network.py | 45 +++++++++++++++------ imperative/python/megengine/utils/network_node.py | 21 ++++++++-- imperative/python/test/unit/utils/test_network.py | 46 ++++++++++++++++++++++ 4 files changed, 96 insertions(+), 19 deletions(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 2cb81874..ea86c46e 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -519,8 +519,7 @@ def _unwrap(x): return type(x)(map(_unwrap, x)) if isinstance(x, VarNode): return x._node - else: - return x + return x def apply_normal_varnode(op: OpDef, *args: VarNode): diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 2cf62d32..91fc8ec1 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -12,14 +12,16 @@ import itertools import pickle import re from collections import OrderedDict -from typing import Any, Dict, List, Sequence +from typing import Any, Dict, List, Optional, Sequence +from ..core import _imperative_rt from ..core._imperative_rt import ComputingGraph, SerializationMetadata from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape from ..core.tensor import megbrain_graph as G from ..logger import get_logger from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .network_node import ( + ConstOpBase, Host2DeviceCopy, ImmutableTensor, NetworkNode, @@ -37,8 +39,10 @@ class Network: self._orig_inputs = [] 
self.output_vars = [] # output var of graph self._orig_outputs = [] - self.all_oprs_map = OrderedDict() - self.all_vars_map = OrderedDict() + self.all_oprs_map = OrderedDict() # _imperative_rt.graph.OperatorNode.id: OpNode + self.all_vars_map = ( + OrderedDict() + ) # _imperative_rt.graph.VarNode.id: VarNode self.graph = ComputingGraph() self._metadata = None @@ -101,7 +105,7 @@ class Network: self.all_oprs_map = {} self.all_vars_map = {} for opr in self.all_oprs: - if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): + if isinstance(opr, (ConstOpBase, Host2DeviceCopy)): opr.compile(self.graph) else: opr.compile() @@ -295,6 +299,9 @@ class Network: def add_dep_oprs(self, *vars): if len(vars) == 0: vars = self.output_vars + + assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode" + q = list(vars) while len(q) > 0: cur = q.pop(0) @@ -368,11 +375,14 @@ class Network: for var in self.all_vars: if var in repl_dict: repl_var = repl_dict[var] - owner = repl_var.owner - idx = owner.outputs.index(repl_var) - owner.outputs[idx] = var - var.__dict__.update(repl_var.__dict__) - var.var = repl_var.var + if repl_var is var: + continue + for opnode in var.users: + assert var in opnode.inputs + opnode.inputs = [repl_var if var is i else i for i in opnode.inputs] + if opnode not in repl_var.users: + repl_var.users.append(opnode) + var.users.clear() self._compile() def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): @@ -473,14 +483,20 @@ class Network: def all_oprs_dict(self): return self.opr_filter.as_dict() - # used for loading and building graph - def _add_opr(self, opr): + def _add_opr(self, opr) -> Optional[OpNode]: + """ + Used for loading and building graph. 
+ """ + assert isinstance(opr, _imperative_rt.graph.OperatorNode) + # TODO: use megbrain C++ RTTI to replace type string if opr.id not in self.all_oprs_map: opnode = str_to_mge_class(get_opr_type(opr)).load(opr) self.all_oprs_map[opr.id] = opnode for var in opr.inputs: - opnode.add_inp_var(self._get_var(var)) + varnode = self._get_var(var) + opnode.add_inp_var(varnode) + varnode.users.append(opnode) for var in opr.outputs: opnode.add_out_var(self._get_var(var)) return opnode @@ -503,7 +519,10 @@ class Network: return None def _get_var(self, x): - # auto convert to VarNode of Network + """ + Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`. + """ + assert isinstance(x, _imperative_rt.graph.VarNode) if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) return self.all_vars_map[x.id] diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index f94fa86d..a77d38ea 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): def __init__(self, var=None, *, owner_opr=None, name=None): SymbolVar.__init__(self, var) + self.users = [] # List[OpNode] self.owner = owner_opr self.name = name self.id = id(self) @@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode): def compile(self, graph): if ( self._opr is None + or self._opr.graph != graph or self._opr.outputs[0].comp_node != self.device or self._opr.outputs[0].shape != self.shape or self._opr.outputs[0].dtype != self.dtype @@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode): assert self.outputs[0].owner is self -class ImmutableTensor(OpNode): - type = "ImmutableTensor" +class ConstOpBase(OpNode): + type = "ConstOpBase" def __init__(self, data=None, 
name=None, device=None, graph=None): + assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated" super().__init__() self.name = name self.outputs = [] @@ -254,7 +257,7 @@ class ImmutableTensor(OpNode): return self._opr.outputs[0].dtype if self._opr else None def numpy(self): - return self._opr.outputs[0].value if self._opr else None + return self.outputs[0].numpy() def set_value(self, data, device=None): assert self.graph is not None @@ -266,7 +269,7 @@ class ImmutableTensor(OpNode): data = data.astype(np.float32) elif data.dtype == np.int64: data = data.astype(np.int32) - varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) + varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name) if len(self.outputs) == 0: self.outputs.append(VarNode(owner_opr=self, name=self.name)) self.outputs[0].var = varnode @@ -291,6 +294,16 @@ class ImmutableTensor(OpNode): self.outputs[0].var.name = self.name +class ImmutableTensor(ConstOpBase): + type = "ImmutableTensor" + rt_fun = rt.make_const + + +class SharedDeviceTensor(ConstOpBase): + type = "SharedDeviceTensor" + rt_fun = rt.make_shared + + class ReadOnlyOpNode(OpNode): @classmethod def load(cls, opr): diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index b7760872..7c55d914 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -130,6 +130,52 @@ def test_replace_opr(): np.testing.assert_equal(out["o"], [0, 0]) +def test_splice_network(): + x = F.ones((2,)) + y = F.ones((2,)) + + @trace(symbolic=True, capture_as_const=True) + def fun1(a, b): + return (a + b) * 2 + + @trace(symbolic=True, capture_as_const=True) + def fun2(a): + return a * 2 - 1 + + model = io.BytesIO() + fun1(x, y) + fun2(x) + fun1.dump( + model, + arg_names=["net1_i0", "net1_i1"], + output_names=["net1_o0"], + optimize_for_inference=False, + ) + model.seek(0) + net1 = Net.load(model) + 
model.seek(0) + fun2.dump( + model, + arg_names=["net2_i0"], + output_names=["net2_o0"], + optimize_for_inference=False, + ) + model.seek(0) + net2 = Net.load(model) + net1.add_output(*net2.output_vars) + var = net1.var_filter.name("net1_i0").as_unique() + repl_var = net2.var_filter.name("net2_o0").as_unique() + net1.replace_vars({var: repl_var}) + assert "net1_i0" not in [var.name for var in net1.all_vars] + assert "net2_i0" in [var.name for var in net1.all_vars] + model.seek(0) + net1.dump(model, keep_var_name=2, optimize_for_inference=False) + model.seek(0) + net = Net.load(model) + assert "net1_i0" not in [var.name for var in net.all_vars] + assert "net2_i0" in [var.name for var in net.all_vars] + + def test_modify_params(): a = Tensor([1, 2])