Browse Source

feat(mge/functional): add repeat and tile opr

GitOrigin-RevId: a20d4b6fb0
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
fa4bf16800
11 changed files with 214 additions and 17 deletions
  1. +2
    -2
      imperative/python/megengine/autodiff/grad_manager.py
  2. +1
    -1
      imperative/python/megengine/core/tensor/indexing.py
  3. +4
    -4
      imperative/python/megengine/core/tensor/utils.py
  4. +1
    -1
      imperative/python/megengine/distributed/helper.py
  5. +2
    -2
      imperative/python/megengine/functional/inplace.py
  6. +143
    -0
      imperative/python/megengine/functional/tensor.py
  7. +0
    -4
      imperative/python/megengine/tensor.py
  8. +8
    -2
      imperative/python/src/tensor.cpp
  9. +1
    -0
      imperative/python/src/tensor.h
  10. +50
    -0
      imperative/python/test/unit/functional/test_tensor.py
  11. +2
    -1
      imperative/python/test/unit/test_tracing.py

+ 2
- 2
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -279,8 +279,8 @@ class GradManager:
tensor.grad = grad tensor.grad = grad
else: else:
tensor.grad += grad tensor.grad += grad
if tensor.isscalar() and tensor.grad is not None:
tensor.grad.setscalar()
if tensor._isscalar() and tensor.grad is not None:
tensor.grad._setscalar()
finally: finally:
self.release() self.release()
backwarding_grad_manager = cache backwarding_grad_manager = cache


+ 1
- 1
imperative/python/megengine/core/tensor/indexing.py View File

@@ -225,7 +225,7 @@ def getitem(tensor, index):
op = builtin.IndexingMultiAxisVec(items=items) op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors) (result,) = apply(op, tensor, *tensors)
if ret_scalar: if ret_scalar:
result.setscalar()
result._setscalar()
return result return result






+ 4
- 4
imperative/python/megengine/core/tensor/utils.py View File

@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype): def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype): if not is_dtype_equal(x.dtype, dtype):
isscalar = x.isscalar()
isscalar = x._isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
if isscalar: if isscalar:
x.setscalar()
x._setscalar()
return x return x




@@ -98,14 +98,14 @@ def result_type(*args):
def isscalar(x): def isscalar(x):


if isinstance(x, Tensor): if isinstance(x, Tensor):
return x.isscalar()
return x._isscalar()


return np.isscalar(x) return np.isscalar(x)




def setscalar(x): def setscalar(x):
if isinstance(x, Tensor): if isinstance(x, Tensor):
x.setscalar()
x._setscalar()
else: else:
raise NotImplementedError("Unsupport type {}".format(type(x))) raise NotImplementedError("Unsupport type {}".format(type(x)))




+ 1
- 1
imperative/python/megengine/distributed/helper.py View File

@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
outputs = apply(op, inp) outputs = apply(op, inp)
for s, x in zip(shapes, outputs): for s, x in zip(shapes, outputs):
if not s: if not s:
x.setscalar()
x._setscalar()
return outputs return outputs






+ 2
- 2
imperative/python/megengine/functional/inplace.py View File

@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd




def _inplace_add_(dest, delta, alpha, beta): def _inplace_add_(dest, delta, alpha, beta):
isscalar = dest.isscalar()
isscalar = dest._isscalar()
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
if isscalar: if isscalar:
dest.setscalar()
dest._setscalar()
return dest return dest

+ 143
- 0
imperative/python/megengine/functional/tensor.py View File

@@ -44,11 +44,13 @@ __all__ = [
"linspace", "linspace",
"ones", "ones",
"ones_like", "ones_like",
"repeat",
"reshape", "reshape",
"split", "split",
"squeeze", "squeeze",
"stack", "stack",
"scatter", "scatter",
"tile",
"transpose", "transpose",
"where", "where",
"zeros", "zeros",
@@ -987,3 +989,144 @@ def arange(
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:
return result.astype(dtype) return result.astype(dtype)
return result return result


def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
    """
    Repeat each element of a tensor ``repeats`` times along one axis.

    :param inp: input tensor.
    :param repeats: the number of repetitions for each element; must be at
        least 1.
    :param axis: the axis along which to repeat values; must be a
        non-negative index inside the dimensions of ``inp``. By default, the
        flattened input is used and a flat output tensor is returned.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        x = tensor([[1, 2], [3, 4]], np.int32)
        y = F.repeat(x, 2, axis=0)
        print(y.numpy())

    Outputs:

    .. testoutput::

        [[1 2]
         [1 2]
         [3 4]
         [3 4]]

    """
    if axis is None:
        # no axis requested: work on the flattened tensor
        inp = inp.reshape(-1)
        axis = 0
    if inp._isscalar():
        inp._unsetscalar()
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1
    assert 0 <= axis <= last_axis
    assert repeats >= 1

    # The input is viewed as (..., d_axis, 1, rest), broadcast to
    # (..., d_axis, repeats, rest) and then merged back into
    # (..., d_axis * repeats, rest).
    target_parts = [shape[:axis]] if axis != 0 else []
    base_parts = [shape[: axis + 1], [1,]]
    bcast_parts = [shape[: axis + 1], [repeats,]]
    target_parts.append(shape[axis] * repeats)
    if axis < last_axis:
        for parts in (base_parts, bcast_parts, target_parts):
            parts.append(shape[axis + 1 :])

    expanded = inp.reshape(concat(base_parts))
    replicated = broadcast_to(expanded, concat(bcast_parts))
    return replicated.reshape(concat(target_parts))


def _tile_one_dim(inp, rep, axis):
    """
    Tile ``inp`` ``rep`` times along a single ``axis``.

    The input is viewed as (..., 1, d_axis, rest), broadcast to
    (..., rep, d_axis, rest) and then merged into (..., d_axis * rep, rest).

    :param inp: input tensor.
    :param rep: number of copies to lay out along ``axis``.
    :param axis: index of the dimension to tile.
    :return: output tensor.
    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    # assume inp.ndim is not changed during trace
    last_axis = len(shape) - 1

    base_parts, bcast_parts, target_parts = [], [], []

    if axis != 0:
        # dimensions before ``axis`` are carried through untouched
        for parts in (base_parts, bcast_parts, target_parts):
            parts.append(shape[:axis])
    base_parts += [[1,], shape[axis:]]
    bcast_parts += [rep, shape[axis:]]
    target_parts.append(shape[axis] * rep)
    if axis < last_axis:
        target_parts.append(shape[axis + 1 :])

    widened = inp.reshape(concat(base_parts))
    return broadcast_to(widened, concat(bcast_parts)).reshape(concat(target_parts))


def tile(inp: Tensor, reps: Iterable[int]):
    """
    Construct a tensor by repeating ``inp`` the number of times given by ``reps``.

    If ``reps`` has length ``d``, the result has ``d`` dimensions. It is
    required that ``d >= inp.ndim``. If ``inp.ndim < d``, ``inp`` is promoted
    to be ``d``-dimensional by prepending new axes.

    :param inp: input tensor.
    :param reps: the number of repetitions of ``inp`` along each axis.
    :return: output tensor.

    Examples:

    .. testcode::

        import numpy as np
        import megengine.functional as F
        from megengine import tensor

        x = tensor([[1, 2], [3, 4]], np.int32)
        y = F.tile(x, (2,1))
        print(y.numpy())

    Outputs:

    .. testoutput::

        [[1 2]
         [3 4]
         [1 2]
         [3 4]]

    """
    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
    l_shape = len(shape)
    l_reps = len(reps)
    assert (
        l_reps >= l_shape
    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"

    # The trailing ``l_shape`` entries of ``reps`` line up with the existing
    # axes of ``inp``; the leading ``n_extra`` entries create new axes.
    n_extra = l_reps - l_shape
    for axis in range(l_shape):
        inp = _tile_one_dim(inp, reps[axis + n_extra], axis)

    if n_extra > 0:
        shape = inp.shape
        # Slice with an explicit stop index: ``reps[:-l_shape]`` would be
        # ``reps[:-0]`` (an empty slice) when ``l_shape`` is 0, dropping the
        # leading repetitions.
        extra = reps[:n_extra]
        extra_ones = ones_like(extra)
        base_shape = concat([extra_ones, shape])
        bcast_shape = concat([extra, shape])
        target_shape = concat([extra, shape])
        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)

    return inp

+ 0
- 4
imperative/python/megengine/tensor.py View File

@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
cn = device._cn cn = device._cn


if isinstance(data, _Tensor): if isinstance(data, _Tensor):
if dtype is not None:
logger.warning(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj = _Tensor.__new__(cls, data) obj = _Tensor.__new__(cls, data)
else: else:
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):


+ 8
- 2
imperative/python/src/tensor.cpp View File

@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
} }




void TensorWrapper::unsetscalar() {
m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
}


struct TensorWeakRef { struct TensorWeakRef {
std::weak_ptr<Tensor> wptr; std::weak_ptr<Tensor> wptr;


@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
.def_getset<&TensorWrapper::dtype>("dtype") .def_getset<&TensorWrapper::dtype>("dtype")
.def_getset<&TensorWrapper::device>("device") .def_getset<&TensorWrapper::device>("device")
.def<&TensorWrapper::reset>("_reset") .def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::isscalar>("_isscalar")
.def<&TensorWrapper::setscalar>("_setscalar")
.def<&TensorWrapper::unsetscalar>("_unsetscalar")
.def<&TensorWrapper::detach>("detach") .def<&TensorWrapper::detach>("detach")
.def<&TensorWrapper::_dev_tensor>("_dev_tensor") .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
.def<&TensorWrapper::_swap_out>("_swap_out") .def<&TensorWrapper::_swap_out>("_swap_out")


+ 1
- 0
imperative/python/src/tensor.h View File

@@ -153,6 +153,7 @@ struct TensorWrapper {
PyObject* detach(); PyObject* detach();
PyObject* isscalar(); PyObject* isscalar();
void setscalar(); void setscalar();
void unsetscalar();
PyObject* _dev_tensor(); PyObject* _dev_tensor();
void _swap_in(); void _swap_in();
void _swap_out(); void _swap_out();


+ 50
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -406,3 +406,53 @@ def test_copy_d2h():
def test_copy_d2d(): def test_copy_d2d():
copy_test("gpu0", "gpu1") copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1") copy_test("gpu0:0", "gpu0:1")


@pytest.mark.parametrize(
    "shape, repeats, axis",
    [
        ((2,), 2, 0),
        ((2, 3, 4, 5), 3, 0),
        ((2, 3, 4, 5), 4, 3),
        ((2,), 2, None),
        ((2, 3, 4, 5), 3, None),
        ((), 1, None),
        ((), 10, None),
    ],
)
def test_repeat(shape, repeats, axis):
    """Check ``F.repeat`` against ``np.repeat`` for each parametrized case."""

    def run_repeat(inp):
        return F.repeat(inp=inp, repeats=repeats, axis=axis)

    def reference(inp):
        return np.repeat(inp, repeats, axis)

    # the empty shape stands for a scalar input; everything else is random data
    if shape == ():
        data = np.array(1.23)
    else:
        data = np.random.randn(*shape).astype("float32")

    opr_test([{"input": data}], run_repeat, ref_fn=reference)


@pytest.mark.parametrize(
    "shape, reps",
    [
        ((2,), (2,)),
        ((2, 3, 4, 5), (1, 1, 1, 1)),
        ((2, 3, 4, 5), (1, 2, 3, 4)),
        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
    ],
)
def test_tile(shape, reps):
    """Check ``F.tile`` against ``np.tile`` for each parametrized case."""

    def run_tile(inp):
        return F.tile(inp=inp, reps=reps)

    def reference(inp):
        return np.tile(inp, reps)

    data = np.random.randn(*shape).astype("float32")
    opr_test([{"input": data}], run_tile, ref_fn=reference)

+ 2
- 1
imperative/python/test/unit/test_tracing.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io import io
import itertools
from tempfile import mkstemp from tempfile import mkstemp


import numpy as np import numpy as np
@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2])) np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out return out


for i in range(1):
for i in range(3):
f(x, M) f(x, M)






Loading…
Cancel
Save