Browse Source

fix(mge/utils): fix bug of VarNode inplace operations

GitOrigin-RevId: fa9eec7079
release-1.4
Megvii Engine Team 4 years ago
parent
commit
2c2beac915
2 changed files with 94 additions and 30 deletions
  1. +9
    -13
      imperative/python/megengine/utils/network_node.py
  2. +85
    -17
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 9
- 13
imperative/python/megengine/utils/network_node.py View File

@@ -6,10 +6,9 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json import json
import sys import sys
from typing import Callable, Sequence
from typing import Sequence


import numpy as np import numpy as np


@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device from ..core._wrap import Device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars from .comp_graph_tools import replace_vars
from .module_stats import ( from .module_stats import (
preprocess_receptive_field, preprocess_receptive_field,
@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
self.graph.compile(o.outputs).execute() self.graph.compile(o.outputs).execute()
return o.get_value().numpy() return o.get_value().numpy()


def __getitem__(self, index):
return _getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
def _reset(self, other):
if not isinstance(other, VarNode):
assert self.graph, "VarNode _reset must have graph"
node = ImmutableTensor(other, graph=self.graph)
node.compile(self.graph)
other = node.outputs[0]
if self.owner is not None: if self.owner is not None:
idx = self.owner.outputs.index(self) idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode( self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name self.var, owner_opr=self.owner, name=self.var.name
) )
self.var = value.var
self.var = other.var
self.owner = None self.owner = None


def set_owner_opr(self, owner_opr): def set_owner_opr(self, owner_opr):


+ 85
- 17
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -9,38 +9,81 @@
import copy import copy


import numpy as np import numpy as np
import pytest
from utils import make_tensor


from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor from megengine.tensor import Tensor
from megengine.utils.network import Network




def test_basic():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_basic(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * x y = x * x
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np) np.testing.assert_almost_equal(y_np, x_np * x_np)




def test_literal_arith():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_literal_arith(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * 2 y = x * 2
y_np = y.numpy() y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2) np.testing.assert_almost_equal(y_np, x_np * 2)




def test_matmul():
A = Tensor(np.random.rand(5, 7).astype("float32"))
B = Tensor(np.random.rand(7, 10).astype("float32"))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
if is_varnode:
network = Network()
else:
network = None

A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)




def test_reduce():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_inplace_add(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10).astype("float32")
x = make_tensor(x_np, network)
y = make_tensor(y_np, network)
y += x
out_np = y.numpy()
np.testing.assert_almost_equal(out_np, x_np + y_np)


@pytest.mark.parametrize("is_varnode", [True, False])
def test_reduce(is_varnode):
if is_varnode:
network = Network()
else:
network = None

def test_x(x_np): def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]: for m in ["sum", "prod", "min", "max", "mean"]:
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = getattr(x, m)(axis=-1, keepdims=True) y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6) np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)


@@ -50,16 +93,28 @@ def test_reduce():
test_x(np.array([True, False, True])) test_x(np.array([True, False, True]))




def test_set_value():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_value(is_varnode):
if is_varnode:
network = Network()
else:
network = None

v0 = np.random.random((2, 3)).astype(np.float32) v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0)
param = make_tensor(v0, network)
v1 = np.random.random((2, 3)).astype(np.float32) v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1 param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6) np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)




def test_set_subtensor():
x = Tensor([1, 2, 3])
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_subtensor(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1] x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2] x[[0, 2]] = [3, 2]
@@ -78,14 +133,27 @@ def test_computing_with_numpy_array():
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x)) np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))




def test_transpose():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_transpose(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.random.rand(2, 5).astype("float32") x = np.random.rand(2, 5).astype("float32")
xx = Tensor(x)
xx = make_tensor(x, network)
np.testing.assert_almost_equal(xx.T.numpy(), x.T) np.testing.assert_almost_equal(xx.T.numpy(), x.T)




def test_as_type():
x = Tensor([1, 2, 3], dtype=np.float32)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_as_type(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.array([1, 2, 3], dtype=np.float32)
x = make_tensor(x_np, network)
y = x.astype(qint8(0.1)) y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1) np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2)) z = y.astype(qint8(0.2))


Loading…
Cancel
Save