Browse Source

fix(imperative/utils): fix logical error in Network.replace_vars

GitOrigin-RevId: 614302552c
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
dedecf6922
4 changed files with 96 additions and 19 deletions
  1. +1
    -2
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +32
    -13
      imperative/python/megengine/utils/network.py
  3. +17
    -4
      imperative/python/megengine/utils/network_node.py
  4. +46
    -0
      imperative/python/test/unit/utils/test_network.py

+ 1
- 2
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -519,8 +519,7 @@ def _unwrap(x):
return type(x)(map(_unwrap, x)) return type(x)(map(_unwrap, x))
if isinstance(x, VarNode): if isinstance(x, VarNode):
return x._node return x._node
else:
return x
return x




def apply_normal_varnode(op: OpDef, *args: VarNode): def apply_normal_varnode(op: OpDef, *args: VarNode):


+ 32
- 13
imperative/python/megengine/utils/network.py View File

@@ -12,14 +12,16 @@ import itertools
import pickle import pickle
import re import re
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Optional, Sequence


from ..core import _imperative_rt
from ..core._imperative_rt import ComputingGraph, SerializationMetadata from ..core._imperative_rt import ComputingGraph, SerializationMetadata
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
from ..core.tensor import megbrain_graph as G from ..core.tensor import megbrain_graph as G
from ..logger import get_logger from ..logger import get_logger
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import ( from .network_node import (
ConstOpBase,
Host2DeviceCopy, Host2DeviceCopy,
ImmutableTensor, ImmutableTensor,
NetworkNode, NetworkNode,
@@ -37,8 +39,10 @@ class Network:
self._orig_inputs = [] self._orig_inputs = []
self.output_vars = [] # output var of graph self.output_vars = [] # output var of graph
self._orig_outputs = [] self._orig_outputs = []
self.all_oprs_map = OrderedDict()
self.all_vars_map = OrderedDict()
self.all_oprs_map = OrderedDict() # _imperative_rt.graph.VarNode.id: VarNode
self.all_vars_map = (
OrderedDict()
) # _imperative_rt.graph.OperatorNode.id: OpNode
self.graph = ComputingGraph() self.graph = ComputingGraph()
self._metadata = None self._metadata = None


@@ -101,7 +105,7 @@ class Network:
self.all_oprs_map = {} self.all_oprs_map = {}
self.all_vars_map = {} self.all_vars_map = {}
for opr in self.all_oprs: for opr in self.all_oprs:
if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)):
if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
opr.compile(self.graph) opr.compile(self.graph)
else: else:
opr.compile() opr.compile()
@@ -295,6 +299,9 @@ class Network:
def add_dep_oprs(self, *vars): def add_dep_oprs(self, *vars):
if len(vars) == 0: if len(vars) == 0:
vars = self.output_vars vars = self.output_vars

assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"

q = list(vars) q = list(vars)
while len(q) > 0: while len(q) > 0:
cur = q.pop(0) cur = q.pop(0)
@@ -368,11 +375,14 @@ class Network:
for var in self.all_vars: for var in self.all_vars:
if var in repl_dict: if var in repl_dict:
repl_var = repl_dict[var] repl_var = repl_dict[var]
owner = repl_var.owner
idx = owner.outputs.index(repl_var)
owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var
if repl_var is var:
continue
for opnode in var.users:
assert var in opnode.inputs
opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
if opnode not in repl_var.users:
repl_var.users.append(opnode)
var.users.clear()
self._compile() self._compile()


def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
@@ -473,14 +483,20 @@ class Network:
def all_oprs_dict(self): def all_oprs_dict(self):
return self.opr_filter.as_dict() return self.opr_filter.as_dict()


# used for loading and building graph
def _add_opr(self, opr):
def _add_opr(self, opr) -> Optional[OpNode]:
"""
Used for loading and building graph.
"""
assert isinstance(opr, _imperative_rt.graph.OperatorNode)

# TODO: use megbrain C++ RTTI to replace type string # TODO: use megbrain C++ RTTI to replace type string
if opr.id not in self.all_oprs_map: if opr.id not in self.all_oprs_map:
opnode = str_to_mge_class(get_opr_type(opr)).load(opr) opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
self.all_oprs_map[opr.id] = opnode self.all_oprs_map[opr.id] = opnode
for var in opr.inputs: for var in opr.inputs:
opnode.add_inp_var(self._get_var(var))
varnode = self._get_var(var)
opnode.add_inp_var(varnode)
varnode.users.append(opnode)
for var in opr.outputs: for var in opr.outputs:
opnode.add_out_var(self._get_var(var)) opnode.add_out_var(self._get_var(var))
return opnode return opnode
@@ -503,7 +519,10 @@ class Network:
return None return None


def _get_var(self, x): def _get_var(self, x):
# auto convert to VarNode of Network
"""
Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
"""
assert isinstance(x, _imperative_rt.graph.VarNode)
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
return self.all_vars_map[x.id] return self.all_vars_map[x.id]


+ 17
- 4
imperative/python/megengine/utils/network_node.py View File

@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def __init__(self, var=None, *, owner_opr=None, name=None): def __init__(self, var=None, *, owner_opr=None, name=None):
SymbolVar.__init__(self, var) SymbolVar.__init__(self, var)
self.users = [] # List[OpNode]
self.owner = owner_opr self.owner = owner_opr
self.name = name self.name = name
self.id = id(self) self.id = id(self)
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
def compile(self, graph): def compile(self, graph):
if ( if (
self._opr is None self._opr is None
or self._opr.graph != graph
or self._opr.outputs[0].comp_node != self.device or self._opr.outputs[0].comp_node != self.device
or self._opr.outputs[0].shape != self.shape or self._opr.outputs[0].shape != self.shape
or self._opr.outputs[0].dtype != self.dtype or self._opr.outputs[0].dtype != self.dtype
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
assert self.outputs[0].owner is self assert self.outputs[0].owner is self




class ImmutableTensor(OpNode):
type = "ImmutableTensor"
class ConstOpBase(OpNode):
type = "ConstOpBase"


def __init__(self, data=None, name=None, device=None, graph=None): def __init__(self, data=None, name=None, device=None, graph=None):
assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
super().__init__() super().__init__()
self.name = name self.name = name
self.outputs = [] self.outputs = []
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
return self._opr.outputs[0].dtype if self._opr else None return self._opr.outputs[0].dtype if self._opr else None


def numpy(self): def numpy(self):
return self._opr.outputs[0].value if self._opr else None
return self.outputs[0].numpy()


def set_value(self, data, device=None): def set_value(self, data, device=None):
assert self.graph is not None assert self.graph is not None
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
data = data.astype(np.float32) data = data.astype(np.float32)
elif data.dtype == np.int64: elif data.dtype == np.int64:
data = data.astype(np.int32) data = data.astype(np.int32)
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0: if len(self.outputs) == 0:
self.outputs.append(VarNode(owner_opr=self, name=self.name)) self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = varnode self.outputs[0].var = varnode
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
self.outputs[0].var.name = self.name self.outputs[0].var.name = self.name




class ImmutableTensor(ConstOpBase):
type = "ImmutableTensor"
rt_fun = rt.make_const


class SharedDeviceTensor(ConstOpBase):
type = "SharedDeviceTensor"
rt_fun = rt.make_shared


class ReadOnlyOpNode(OpNode): class ReadOnlyOpNode(OpNode):
@classmethod @classmethod
def load(cls, opr): def load(cls, opr):


+ 46
- 0
imperative/python/test/unit/utils/test_network.py View File

@@ -130,6 +130,52 @@ def test_replace_opr():
np.testing.assert_equal(out["o"], [0, 0]) np.testing.assert_equal(out["o"], [0, 0])




def test_splice_network():
    """Splice two traced-and-dumped networks and check var replacement.

    Builds two traced functions, dumps each to a serialized model, reloads
    them as ``Net`` objects, then feeds net2's output in place of net1's
    input ``net1_i0`` via ``replace_vars``. Verifies — both on the live
    graph and after a dump/reload round trip — that the replaced input var
    is gone and net2's own input var is now part of the combined graph.
    """
    x = F.ones((2,))
    y = F.ones((2,))

    @trace(symbolic=True, capture_as_const=True)
    def fun1(a, b):
        return (a + b) * 2

    @trace(symbolic=True, capture_as_const=True)
    def fun2(a):
        return a * 2 - 1

    model = io.BytesIO()
    # Traced functions must be executed once before they can be dumped.
    fun1(x, y)
    fun2(x)
    fun1.dump(
        model,
        arg_names=["net1_i0", "net1_i1"],
        output_names=["net1_o0"],
        optimize_for_inference=False,
    )
    model.seek(0)
    net1 = Net.load(model)
    # Reuse the same buffer for the second model; seek(0) before each
    # dump/load so reads and writes start at the beginning.
    model.seek(0)
    fun2.dump(
        model,
        arg_names=["net2_i0"],
        output_names=["net2_o0"],
        optimize_for_inference=False,
    )
    model.seek(0)
    net2 = Net.load(model)
    # Pull net2's graph into net1 by registering its output vars, then
    # splice: net1's input "net1_i0" is replaced by net2's output.
    net1.add_output(*net2.output_vars)
    var = net1.var_filter.name("net1_i0").as_unique()
    repl_var = net2.var_filter.name("net2_o0").as_unique()
    net1.replace_vars({var: repl_var})
    # The replaced input must disappear; net2's input must now be reachable.
    assert "net1_i0" not in [var.name for var in net1.all_vars]
    assert "net2_i0" in [var.name for var in net1.all_vars]
    model.seek(0)
    # keep_var_name=2 preserves user-assigned var names through the dump,
    # so the same assertions must hold after a serialize/deserialize cycle.
    net1.dump(model, keep_var_name=2, optimize_for_inference=False)
    model.seek(0)
    net = Net.load(model)
    assert "net1_i0" not in [var.name for var in net.all_vars]
    assert "net2_i0" in [var.name for var in net.all_vars]


def test_modify_params(): def test_modify_params():


a = Tensor([1, 2]) a = Tensor([1, 2])


Loading…
Cancel
Save