Browse Source

fix(imperative/utils): fix logical error of replace var

GitOrigin-RevId: 614302552c
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
dedecf6922
4 changed files with 96 additions and 19 deletions
  1. +1
    -2
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +32
    -13
      imperative/python/megengine/utils/network.py
  3. +17
    -4
      imperative/python/megengine/utils/network_node.py
  4. +46
    -0
      imperative/python/test/unit/utils/test_network.py

+ 1
- 2
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -519,8 +519,7 @@ def _unwrap(x):
return type(x)(map(_unwrap, x))
if isinstance(x, VarNode):
return x._node
else:
return x
return x


def apply_normal_varnode(op: OpDef, *args: VarNode):


+ 32
- 13
imperative/python/megengine/utils/network.py View File

@@ -12,14 +12,16 @@ import itertools
import pickle
import re
from collections import OrderedDict
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, List, Optional, Sequence

from ..core import _imperative_rt
from ..core._imperative_rt import ComputingGraph, SerializationMetadata
from ..core._trace_option import set_symbolic_shape as _set_symbolic_shape
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import (
ConstOpBase,
Host2DeviceCopy,
ImmutableTensor,
NetworkNode,
@@ -37,8 +39,10 @@ class Network:
self._orig_inputs = []
self.output_vars = [] # output var of graph
self._orig_outputs = []
self.all_oprs_map = OrderedDict()
self.all_vars_map = OrderedDict()
self.all_oprs_map = OrderedDict() # _imperative_rt.graph.VarNode.id: VarNode
self.all_vars_map = (
OrderedDict()
) # _imperative_rt.graph.OperatorNode.id: OpNode
self.graph = ComputingGraph()
self._metadata = None

@@ -101,7 +105,7 @@ class Network:
self.all_oprs_map = {}
self.all_vars_map = {}
for opr in self.all_oprs:
if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)):
if isinstance(opr, (ConstOpBase, Host2DeviceCopy)):
opr.compile(self.graph)
else:
opr.compile()
@@ -295,6 +299,9 @@ class Network:
def add_dep_oprs(self, *vars):
if len(vars) == 0:
vars = self.output_vars

assert all(isinstance(var, VarNode) for var in vars), "Only support add VarNode"

q = list(vars)
while len(q) > 0:
cur = q.pop(0)
@@ -368,11 +375,14 @@ class Network:
for var in self.all_vars:
if var in repl_dict:
repl_var = repl_dict[var]
owner = repl_var.owner
idx = owner.outputs.index(repl_var)
owner.outputs[idx] = var
var.__dict__.update(repl_var.__dict__)
var.var = repl_var.var
if repl_var is var:
continue
for opnode in var.users:
assert var in opnode.inputs
opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
if opnode not in repl_var.users:
repl_var.users.append(opnode)
var.users.clear()
self._compile()

def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
@@ -473,14 +483,20 @@ class Network:
def all_oprs_dict(self):
return self.opr_filter.as_dict()

# used for loading and building graph
def _add_opr(self, opr):
def _add_opr(self, opr) -> Optional[OpNode]:
"""
Used for loading and building graph.
"""
assert isinstance(opr, _imperative_rt.graph.OperatorNode)

# TODO: use megbrain C++ RTTI to replace type string
if opr.id not in self.all_oprs_map:
opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
self.all_oprs_map[opr.id] = opnode
for var in opr.inputs:
opnode.add_inp_var(self._get_var(var))
varnode = self._get_var(var)
opnode.add_inp_var(varnode)
varnode.users.append(opnode)
for var in opr.outputs:
opnode.add_out_var(self._get_var(var))
return opnode
@@ -503,7 +519,10 @@ class Network:
return None

def _get_var(self, x):
# auto convert to VarNode of Network
"""
Convert :class:`~._imperative_rt.graph.VarNode` to :class:`~.VarNode`.
"""
assert isinstance(x, _imperative_rt.graph.VarNode)
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
return self.all_vars_map[x.id]


+ 17
- 4
imperative/python/megengine/utils/network_node.py View File

@@ -37,6 +37,7 @@ class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
def __init__(self, var=None, *, owner_opr=None, name=None):
SymbolVar.__init__(self, var)
self.users = [] # List[OpNode]
self.owner = owner_opr
self.name = name
self.id = id(self)
@@ -214,6 +215,7 @@ class Host2DeviceCopy(OpNode):
def compile(self, graph):
if (
self._opr is None
or self._opr.graph != graph
or self._opr.outputs[0].comp_node != self.device
or self._opr.outputs[0].shape != self.shape
or self._opr.outputs[0].dtype != self.dtype
@@ -226,10 +228,11 @@ class Host2DeviceCopy(OpNode):
assert self.outputs[0].owner is self


class ImmutableTensor(OpNode):
type = "ImmutableTensor"
class ConstOpBase(OpNode):
type = "ConstOpBase"

def __init__(self, data=None, name=None, device=None, graph=None):
assert type(self) is not ConstOpBase, "ConstOpBase cannot be instantiated"
super().__init__()
self.name = name
self.outputs = []
@@ -254,7 +257,7 @@ class ImmutableTensor(OpNode):
return self._opr.outputs[0].dtype if self._opr else None

def numpy(self):
return self._opr.outputs[0].value if self._opr else None
return self.outputs[0].numpy()

def set_value(self, data, device=None):
assert self.graph is not None
@@ -266,7 +269,7 @@ class ImmutableTensor(OpNode):
data = data.astype(np.float32)
elif data.dtype == np.int64:
data = data.astype(np.int32)
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name)
if len(self.outputs) == 0:
self.outputs.append(VarNode(owner_opr=self, name=self.name))
self.outputs[0].var = varnode
@@ -291,6 +294,16 @@ class ImmutableTensor(OpNode):
self.outputs[0].var.name = self.name


class ImmutableTensor(ConstOpBase):
type = "ImmutableTensor"
rt_fun = rt.make_const


class SharedDeviceTensor(ConstOpBase):
type = "SharedDeviceTensor"
rt_fun = rt.make_shared


class ReadOnlyOpNode(OpNode):
@classmethod
def load(cls, opr):


+ 46
- 0
imperative/python/test/unit/utils/test_network.py View File

@@ -130,6 +130,52 @@ def test_replace_opr():
np.testing.assert_equal(out["o"], [0, 0])


def test_splice_network():
x = F.ones((2,))
y = F.ones((2,))

@trace(symbolic=True, capture_as_const=True)
def fun1(a, b):
return (a + b) * 2

@trace(symbolic=True, capture_as_const=True)
def fun2(a):
return a * 2 - 1

model = io.BytesIO()
fun1(x, y)
fun2(x)
fun1.dump(
model,
arg_names=["net1_i0", "net1_i1"],
output_names=["net1_o0"],
optimize_for_inference=False,
)
model.seek(0)
net1 = Net.load(model)
model.seek(0)
fun2.dump(
model,
arg_names=["net2_i0"],
output_names=["net2_o0"],
optimize_for_inference=False,
)
model.seek(0)
net2 = Net.load(model)
net1.add_output(*net2.output_vars)
var = net1.var_filter.name("net1_i0").as_unique()
repl_var = net2.var_filter.name("net2_o0").as_unique()
net1.replace_vars({var: repl_var})
assert "net1_i0" not in [var.name for var in net1.all_vars]
assert "net2_i0" in [var.name for var in net1.all_vars]
model.seek(0)
net1.dump(model, keep_var_name=2, optimize_for_inference=False)
model.seek(0)
net = Net.load(model)
assert "net1_i0" not in [var.name for var in net.all_vars]
assert "net2_i0" in [var.name for var in net.all_vars]


def test_modify_params():

a = Tensor([1, 2])


Loading…
Cancel
Save