Browse Source

feat(mge/tensor): support non-Tensor value in `_reset` and remove depreciated tests

GitOrigin-RevId: faf6c78aa8
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
e6b06914f7
4 changed files with 19 additions and 101 deletions
  1. +9
    -10
      imperative/python/megengine/core/tensor/array_method.py
  2. +2
    -2
      imperative/python/megengine/tensor.py
  3. +8
    -0
      imperative/python/test/unit/core/test_tensor_wrapper.py
  4. +0
    -89
      imperative/python/test/unit/module/test_module_tensor.py

+ 9
- 10
imperative/python/megengine/core/tensor/array_method.py View File

@@ -16,7 +16,6 @@ from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
from . import utils
from .indexing import getitem as _getitem
from .indexing import setitem as _setitem
@@ -373,7 +372,7 @@ class ArrayMethodMixin(abc.ABC):
@property
def size(self):
r"""
Returns the size of the self :class:`~.Tensor`.
Returns the size of the self :class:`~.Tensor`.
The returned value is a subclass of :class:`tuple`.
"""
shape = self.shape
@@ -390,7 +389,7 @@ class ArrayMethodMixin(abc.ABC):

def item(self, *args):
r"""
Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
Returns the value of this :class:`~.Tensor` as a standard Python :class:`numbers.Number`.
This only works for tensors with one element. For other cases, see :meth:`~.tolist`.
"""
if not args:
@@ -401,8 +400,8 @@ class ArrayMethodMixin(abc.ABC):

def tolist(self):
r"""
Returns the tensor as a (nested) list.
For scalars, a standard Python number is returned, just like with :meth:`~.item`.
Returns the tensor as a (nested) list.
For scalars, a standard Python number is returned, just like with :meth:`~.item`.
Tensors are automatically moved to the CPU first if necessary.

This operation is not differentiable.
@@ -450,7 +449,7 @@ class ArrayMethodMixin(abc.ABC):
def sum(self, axis=None, keepdims: bool = False):
r"""
Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
@@ -483,7 +482,7 @@ class ArrayMethodMixin(abc.ABC):
def prod(self, axis=None, keepdims: bool = False):
r"""
Returns the product of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
@@ -516,7 +515,7 @@ class ArrayMethodMixin(abc.ABC):
def min(self, axis=None, keepdims: bool = False):
r"""
Returns the min value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
@@ -549,7 +548,7 @@ class ArrayMethodMixin(abc.ABC):
def max(self, axis=None, keepdims: bool = False):
r"""
Returns the max value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.
@@ -582,7 +581,7 @@ class ArrayMethodMixin(abc.ABC):
def mean(self, axis=None, keepdims: bool = False):
r"""
Returns the mean value of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor,
except in the dimension(s) ``axis`` where it is of size 1.


+ 2
- 2
imperative/python/megengine/tensor.py View File

@@ -119,6 +119,8 @@ class Tensor(_Tensor, ArrayMethodMixin):
return super().detach()

def _reset(self, other):
if not isinstance(other, _Tensor):
other = Tensor(other, dtype=self.dtype, device=self.device)
super()._reset(other)

def __repr__(self):
@@ -141,8 +143,6 @@ class Tensor(_Tensor, ArrayMethodMixin):

@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
def set_value(self, value):
if not isinstance(value, _Tensor):
value = Tensor(value, dtype=self.dtype, device=self.device)
self._reset(value)

@deprecated(version="1.0", reason="use *= 0 instead")


+ 8
- 0
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -50,6 +50,14 @@ def test_reduce():
test_x(np.array([True, False, True]))


def test_set_value():
v0 = np.random.random((2, 3)).astype(np.float32)
param = Tensor(v0)
v1 = np.random.random((2, 3)).astype(np.float32)
param[...] = v1
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)


def test_set_subtensor():
x = Tensor([1, 2, 3])
x[:] = [1, 1, 1]


+ 0
- 89
imperative/python/test/unit/module/test_module_tensor.py View File

@@ -1,89 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor
from megengine.module import Conv2d


# TODO: delete this test after deleting set_value
def test_set_value():
v0 = np.random.random((2, 3)).astype(np.float32)
param = Parameter(v0)
v1 = np.random.random((2, 3)).astype(np.float32)
param.set_value(v1)
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)
v2 = np.random.random((3, 3)).astype(np.float32)
# TODO: add this
# with pytest.raises(ValueError):
# param.set_value(v2)
np.testing.assert_allclose(param.numpy(), v1, atol=5e-6)


@pytest.mark.skip(reason="fill unsupported")
def test_fill():
a = Tensor(np.zeros((2, 3), dtype=np.float32))
a.fill(3)
np.testing.assert_allclose(a.numpy(), np.full((2, 3), 3, dtype=np.float32))
a.fill(124.568)
np.testing.assert_allclose(a.numpy(), np.full((2, 3), 124.568, dtype=np.float32))


# TODO: remove or rewrite following test
# def test_attach():
# p_ = np.random.random((2, 3)).astype(np.float32)

# with Graph() as g:
# g.set_option('eager_evaluation', False)
# p = Parameter(p_)
# v = p * 2
# f = compile(v, None)

# out, = f()
# np.testing.assert_allclose(out, p_ * 2)

# F.add_update(p, p)
# out, = f()
# np.testing.assert_allclose(out, p_ * 4)


# TODO: remove or rewrite following test
# def test_module_attach():
# v = np.random.random((1, 3, 64, 64)).astype(np.float32)
# net = Conv2d(3, 16, 3)

# with Graph() as g:
# g.set_option('eager_evaluation', False)

# data0 = Input("data")
# f = compile(net(data0), None)

# out0, = f(data=v)

# data1 = Input("data", value=v)
# out1 = net(data1)

# np.testing.assert_allclose(out0, out1.numpy())


# def test_shape_warning():
# with Graph() as cg:
# cg.set_option("eager_evaluation", False)
# b = Tensor(np.ones((2, 3)).astype(np.float32))
# with pytest.warns(None) as record:
# print(b.shape)
# if len(record) != 0:
# raise ValueError(
# "Getting the shape of a constant Tensor should throw no Warning"
# )

Loading…
Cancel
Save