GitOrigin-RevId: ecc47edab8
release-1.5
@@ -18,7 +18,6 @@ from ..core._trace_option import use_symbolic_shape | |||
from ..core._wrap import Device | |||
from ..core.ops import builtin | |||
from ..core.tensor.array_method import ArrayMethodMixin | |||
from ..core.tensor.megbrain_graph import OutputNode | |||
from .comp_graph_tools import replace_vars | |||
from .module_stats import ( | |||
preprocess_receptive_field, | |||
@@ -106,9 +105,7 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
return id(self) | |||
def numpy(self): | |||
o = OutputNode(self.var) | |||
self.graph.compile(o.outputs).execute() | |||
return o.get_value().numpy() | |||
return super().numpy() | |||
def _reset(self, other): | |||
if not isinstance(other, VarNode): | |||
@@ -141,15 +138,13 @@ class OpNode(NetworkNode): | |||
@property | |||
def id(self): | |||
if self._opr is not None: | |||
return self._opr.id | |||
return id(self) | |||
@property | |||
def priority(self): | |||
if self._opr is not None: | |||
return self._opr.priority | |||
return 0 | |||
return (self._opr.priority, self._opr.id) | |||
return (0, 0) | |||
@classmethod | |||
def load(cls, opr): | |||
@@ -5,6 +5,7 @@ import numpy as np | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import tensor | |||
from megengine.core.tensor.megbrain_graph import OutputNode | |||
from megengine.jit import trace | |||
from megengine.utils.network_node import VarNode | |||
@@ -12,8 +13,10 @@ from megengine.utils.network_node import VarNode | |||
def _default_compare_fn(x, y): | |||
if isinstance(x, np.ndarray): | |||
np.testing.assert_allclose(x, y, rtol=1e-6) | |||
else: | |||
elif isinstance(x, tensor): | |||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
else: | |||
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6) | |||
def make_tensor(x, network=None, device=None): | |||
@@ -25,6 +28,15 @@ def make_tensor(x, network=None, device=None): | |||
return tensor(x, device=device) | |||
def get_var_value(x): | |||
try: | |||
o = OutputNode(x.var) | |||
o.graph.compile(o.outputs).execute() | |||
return o.get_value().numpy() | |||
except RuntimeError: | |||
raise ValueError("value invalid!") | |||
def opr_test( | |||
cases, | |||
func, | |||
@@ -10,7 +10,7 @@ import copy | |||
import numpy as np | |||
import pytest | |||
from utils import make_tensor | |||
from utils import get_var_value, make_tensor | |||
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||
from megengine.tensor import Parameter, Tensor | |||
@@ -55,7 +55,12 @@ def test_matmul(is_varnode): | |||
A = make_tensor(np.random.rand(5, 7).astype("float32"), network) | |||
B = make_tensor(np.random.rand(7, 10).astype("float32"), network) | |||
C = A @ B | |||
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | |||
if is_varnode: | |||
np.testing.assert_almost_equal( | |||
get_var_value(C), get_var_value(A) @ get_var_value(B), decimal=6 | |||
) | |||
else: | |||
np.testing.assert_almost_equal(C.numpy(), A.numpy() @ B.numpy(), decimal=6) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
@@ -116,11 +121,17 @@ def test_set_subtensor(is_varnode): | |||
x = make_tensor([1, 2, 3], network) | |||
x[:] = [1, 1, 1] | |||
np.testing.assert_almost_equal(x.numpy(), [1, 1, 1], decimal=6) | |||
np.testing.assert_almost_equal( | |||
get_var_value(x) if is_varnode else x.numpy(), [1, 1, 1], decimal=6 | |||
) | |||
x[[0, 2]] = [3, 2] | |||
np.testing.assert_almost_equal(x.numpy(), [3, 1, 2], decimal=6) | |||
np.testing.assert_almost_equal( | |||
get_var_value(x) if is_varnode else x.numpy(), [3, 1, 2], decimal=6 | |||
) | |||
x[1:3] = [4, 5] | |||
np.testing.assert_almost_equal(x.numpy(), [3, 4, 5], decimal=6) | |||
np.testing.assert_almost_equal( | |||
get_var_value(x) if is_varnode else x.numpy(), [3, 4, 5], decimal=6 | |||
) | |||
def test_computing_with_numpy_array(): | |||
@@ -11,7 +11,7 @@ import platform | |||
import numpy as np | |||
import pytest | |||
from utils import make_tensor, opr_test | |||
from utils import get_var_value, make_tensor, opr_test | |||
import megengine.functional as F | |||
from megengine import tensor | |||
@@ -75,8 +75,12 @@ def test_condtake(is_varnode): | |||
xx = make_tensor(x, network) | |||
yy = make_tensor(y, network) | |||
val, idx = F.cond_take(yy, xx) | |||
np.testing.assert_equal(val.numpy(), x[y]) | |||
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | |||
if is_varnode: | |||
np.testing.assert_equal(get_var_value(val), x[y]) | |||
np.testing.assert_equal(get_var_value(idx), np.where(y.reshape(-1))[0]) | |||
else: | |||
np.testing.assert_equal(val.numpy(), x[y]) | |||
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||