Browse Source

fix(mge/utils): use static infer manager to get value of network.varnode

GitOrigin-RevId: ecc47edab8
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
7225b0f09f
4 changed files with 39 additions and 17 deletions
  1. +3
    -8
      imperative/python/megengine/utils/network_node.py
  2. +13
    -1
      imperative/python/test/helpers/utils.py
  3. +16
    -5
      imperative/python/test/unit/core/test_tensor_wrapper.py
  4. +7
    -3
      imperative/python/test/unit/functional/test_tensor.py

+ 3
- 8
imperative/python/megengine/utils/network_node.py View File

@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape
from ..core._wrap import Device
from ..core.ops import builtin
from ..core.tensor.array_method import ArrayMethodMixin
from ..core.tensor.megbrain_graph import OutputNode
from .comp_graph_tools import replace_vars
from .module_stats import (
preprocess_receptive_field,
@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
return id(self)

def numpy(self):
o = OutputNode(self.var)
self.graph.compile(o.outputs).execute()
return o.get_value().numpy()
return super().numpy()

def _reset(self, other):
if not isinstance(other, VarNode):
@@ -141,15 +138,13 @@ class OpNode(NetworkNode):

@property
def id(self):
if self._opr is not None:
return self._opr.id
return id(self)

@property
def priority(self):
if self._opr is not None:
return self._opr.priority
return 0
return (self._opr.priority, self._opr.id)
return (0, 0)

@classmethod
def load(cls, opr):


+ 13
- 1
imperative/python/test/helpers/utils.py View File

@@ -5,6 +5,7 @@ import numpy as np
import megengine.core.tensor.megbrain_graph as G
import megengine.utils.comp_graph_tools as cgtools
from megengine import tensor
from megengine.core.tensor.megbrain_graph import OutputNode
from megengine.jit import trace
from megengine.utils.network_node import VarNode

@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
else:
elif isinstance(x, tensor):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
else:
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)


def make_tensor(x, network=None, device=None):
@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None):
return tensor(x, device=device)


def get_var_value(x):
try:
o = OutputNode(x.var)
o.graph.compile(o.outputs).execute()
return o.get_value().numpy()
except RuntimeError:
raise ValueError("value invalid!")


def opr_test(
cases,
func,


+ 16
- 5
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -10,7 +10,7 @@ import copy

import numpy as np
import pytest
from utils import make_tensor
from utils import get_var_value, make_tensor

from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
from megengine.tensor import Parameter, Tensor
@@ -55,7 +55,12 @@ def test_matmul(is_varnode):
A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
C = A @ B
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)
if is_varnode:
np.testing.assert_almost_equal(
get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6
)
else:
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6)


@pytest.mark.parametrize("is_varnode", [True, False])
@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode):

x = make_tensor([1, 2, 3], network)
x[:] = [1, 1, 1]
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6
)
x[[0, 2]] = [3, 2]
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6
)
x[1:3] = [4, 5]
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6)
np.testing.assert_almost_equal(
get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6
)


def test_computing_with_numpy_array():


+ 7
- 3
imperative/python/test/unit/functional/test_tensor.py View File

@@ -11,7 +11,7 @@ import platform

import numpy as np
import pytest
from utils import make_tensor, opr_test
from utils import get_var_value, make_tensor, opr_test

import megengine.functional as F
from megengine import tensor
@@ -75,8 +75,12 @@ def test_condtake(is_varnode):
xx = make_tensor(x, network)
yy = make_tensor(y, network)
val, idx = F.cond_take(yy, xx)
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
if is_varnode:
np.testing.assert_equal(get_var_value(val), x[y])
np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0])
else:
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])


@pytest.mark.parametrize("is_varnode", [True, False])


Loading…
Cancel
Save