Browse Source

fix(mge/utils): fix bug of VarNode inplace operations

GitOrigin-RevId: fa9eec7079
release-1.4
Megvii Engine Team 4 years ago
parent
commit
2c2beac915
2 changed files with 94 additions and 30 deletions
  1. +9
    -13
      imperative/python/megengine/utils/network_node.py
  2. +85
    -17
      imperative/python/test/unit/core/test_tensor_wrapper.py

+ 9
- 13
imperative/python/megengine/utils/network_node.py View File

@@ -6,10 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import json
import sys
from typing import Callable, Sequence
from typing import Sequence

import numpy as np

@@ -19,10 +18,7 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.indexing import getitem as _getitem
from ..core.tensor.indexing import setitem as _setitem
from ..core.tensor.megbrain_graph import InputNode, OutputNode
from ..tensor import Tensor
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
@@ -110,18 +106,18 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()

def __getitem__(self, index):
return _getitem(self, index)
def __setitem__(self, index, value):
if index is not Ellipsis:
value = _setitem(self, index, value)
def _reset(self, other):
if not isinstance(other, VarNode):
assert self.graph, "VarNode _reset must have graph"
node = ImmutableTensor(other, graph=self.graph)
node.compile(self.graph)
other = node.outputs[0]
if self.owner is not None:
idx = self.owner.outputs.index(self)
self.owner.outputs[idx] = VarNode(
self.var, owner_opr=self.owner, name=self.var.name
)
self.var = value.var
self.var = other.var
self.owner = None

def set_owner_opr(self, owner_opr):


+ 85
- 17
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -9,38 +9,81 @@
import copy

import numpy as np
import pytest
from utils import make_tensor

from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Tensor
from megengine.utils.network import Network


def test_basic():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_basic(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * x
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * x_np)


def test_literal_arith():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_literal_arith(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32")
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = x * 2
y_np = y.numpy()
np.testing.assert_almost_equal(y_np, x_np * 2)


def test_matmul():
A = Tensor(np.random.rand(5, 7).astype("float32"))
B = Tensor(np.random.rand(7, 10).astype("float32"))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_matmul(is_varnode):
if is_varnode:
network = Network()
else:
network = None

A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)


def test_reduce():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_inplace_add(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.random.rand(10).astype("float32")
y_np = np.random.rand(10).astype("float32")
x = make_tensor(x_np, network)
y = make_tensor(y_np, network)
y += x
out_np = y.numpy()
np.testing.assert_almost_equal(out_np, x_np + y_np)


@pytest.mark.parametrize("is_varnode", [True, False])
def test_reduce(is_varnode):
if is_varnode:
network = Network()
else:
network = None

def test_x(x_np):
for m in ["sum", "prod", "min", "max", "mean"]:
x = Tensor(x_np)
x = make_tensor(x_np, network)
y = getattr(x, m)(axis=-1, keepdims=True)
np.testing.assert_almost_equal(y.numpy(), getattr(x_np, m)(-1), decimal=6)

@@ -50,16 +93,28 @@ def test_reduce():
test_x(np.array([True, False, True]))


def test_set_value():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_value(is_varnode):
if is_varnode:
network = Network()
else:
network = None

v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0)
param = make_tensor(v0, network)
v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)


def test_set_subtensor():
x = Tensor([1, 2, 3])
@pytest.mark.parametrize("is_varnode", [True, False])
def test_set_subtensor(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
x[[0, 2]] = [3, 2]
@@ -78,14 +133,27 @@ def test_computing_with_numpy_array():
np.testing.assert_equal(np.equal(xx, xx).numpy(), np.equal(x, x))


def test_transpose():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_transpose(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.random.rand(2, 5).astype("float32")
xx = Tensor(x)
xx = make_tensor(x, network)
np.testing.assert_almost_equal(xx.T.numpy(), x.T)


def test_as_type():
x = Tensor([1, 2, 3], dtype=np.float32)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_as_type(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_np = np.array([1, 2, 3], dtype=np.float32)
x = make_tensor(x_np, network)
y = x.astype(qint8(0.1))
np.testing.assert_almost_equal(get_scale(y.dtype), 0.1)
z = y.astype(qint8(0.2))


Loading…
Cancel
Save