@@ -279,8 +279,8 @@ class GradManager:
                     tensor.grad = grad
                 else:
                     tensor.grad += grad
-                if tensor.isscalar() and tensor.grad is not None:
-                    tensor.grad.setscalar()
+                if tensor._isscalar() and tensor.grad is not None:
+                    tensor.grad._setscalar()
         finally:
             self.release()
             backwarding_grad_manager = cache
@@ -225,7 +225,7 @@ def getitem(tensor, index):
     op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
     if ret_scalar:
-        result.setscalar()
+        result._setscalar()
     return result
@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_dtype_equal(x.dtype, dtype):
-        isscalar = x.isscalar()
+        isscalar = x._isscalar()
         (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
         if isscalar:
-            x.setscalar()
+            x._setscalar()
     return x
@@ -98,14 +98,14 @@ def result_type(*args):
 def isscalar(x):
     if isinstance(x, Tensor):
-        return x.isscalar()
+        return x._isscalar()
     return np.isscalar(x)


 def setscalar(x):
     if isinstance(x, Tensor):
-        x.setscalar()
+        x._setscalar()
     else:
         raise NotImplementedError("Unsupport type {}".format(type(x)))
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     outputs = apply(op, inp)
     for s, x in zip(shapes, outputs):
         if not s:
-            x.setscalar()
+            x._setscalar()
     return outputs
@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd
 def _inplace_add_(dest, delta, alpha, beta):
-    isscalar = dest.isscalar()
+    isscalar = dest._isscalar()
     dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
     if isscalar:
-        dest.setscalar()
+        dest._setscalar()
     return dest
@@ -44,11 +44,13 @@ __all__ = [
     "linspace",
     "ones",
     "ones_like",
+    "repeat",
     "reshape",
     "split",
     "squeeze",
     "stack",
     "scatter",
+    "tile",
     "transpose",
     "where",
     "zeros",
@@ -987,3 +989,144 @@ def arange(
     if np.dtype(dtype) == np.int32:
         return result.astype(dtype)
     return result
+
+
+def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
+    """
+    Repeat elements of an array.
+
+    :param inp: input tensor.
+    :param repeats: the number of repetitions for each element.
+    :param axis: the axis along which to repeat values. By default, use the
+        flattened input array, and return a flat output array.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.repeat(x, 2, axis=0)
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [1 2]
+         [3 4]
+         [3 4]]
+
+    """
+    if axis is None:
+        inp = inp.reshape(-1)  # flatten
+        axis = 0
+    if inp._isscalar():
+        inp._unsetscalar()
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+    assert axis >= 0 and axis <= max_axis
+    assert repeats >= 1
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        target_shape.append(shape[:axis])
+    base_shape.extend([shape[: axis + 1], [1,]])
+    bcast_shape.extend([shape[: axis + 1], [repeats,]])
+    target_shape.extend(
+        [shape[axis] * repeats,]
+    )
+    if axis + 1 <= max_axis:
+        base_shape.append(shape[axis + 1 :])
+        bcast_shape.append(shape[axis + 1 :])
+        target_shape.append(shape[axis + 1 :])
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
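Review note: repeat needs no dedicated kernel; it rewrites the operation as reshape -> broadcast_to -> reshape. A minimal NumPy sketch of the same shape algebra for a concrete axis=1 case (NumPy is used only to illustrate the trick; the three shapes mirror base_shape, bcast_shape, and target_shape above):

    import numpy as np

    x = np.arange(6, dtype=np.int32).reshape(2, 3)  # input of shape (2, 3)
    repeats, axis = 2, 1

    # base_shape: insert a length-1 axis right after `axis`   -> (2, 3, 1)
    base = x.reshape(2, 3, 1)
    # bcast_shape: broadcast that axis up to `repeats`        -> (2, 3, 2)
    bcast = np.broadcast_to(base, (2, 3, 2))
    # target_shape: fold the broadcast axis back into `axis`  -> (2, 6)
    out = bcast.reshape(2, 6)

    assert (out == np.repeat(x, repeats, axis=axis)).all()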
+
+
+def _tile_one_dim(inp, rep, axis):
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        base_shape.append(shape[:axis])
+        bcast_shape.append(shape[:axis])
+        target_shape.append(shape[:axis])
+    base_shape.extend([[1,], shape[axis:]])
+    bcast_shape.extend([rep, shape[axis:]])
+    target_shape.append(shape[axis] * rep)
+    if axis + 1 <= max_axis:
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
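Review note: _tile_one_dim uses the same reshape/broadcast/reshape trick, but inserts the length-1 axis *before* `axis`, so whole blocks repeat ([1 2 1 2]) rather than individual elements ([1 1 2 2]). A NumPy sketch for axis=0:

    import numpy as np

    x = np.arange(6, dtype=np.int32).reshape(2, 3)
    rep, axis = 2, 0

    # base_shape: the length-1 axis goes before `axis`  -> (1, 2, 3)
    base = x.reshape(1, 2, 3)
    # bcast_shape: broadcast it up to `rep`             -> (2, 2, 3)
    bcast = np.broadcast_to(base, (2, 2, 3))
    # target_shape: merge it into `axis`                -> (4, 3)
    out = bcast.reshape(4, 3)

    assert (out == np.tile(x, (2, 1))).all()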
+
+
+def tile(inp: Tensor, reps: Iterable[int]):
+    """
+    Construct an array by repeating ``inp`` the number of times given by ``reps``. If ``reps`` has length ``d``,
+    the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.ndim``. If ``inp.ndim < d``,
+    ``inp`` is promoted to be ``d``-dimensional by prepending new axes.
+
+    :param inp: input tensor.
+    :param reps: the number of repetitions of ``inp`` along each axis.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.tile(x, (2, 1))
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [3 4]
+         [1 2]
+         [3 4]]
+
+    """
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
+    l_shape = len(shape)
+    l_reps = len(reps)
+    assert (
+        l_reps >= l_shape
+    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
+
+    for i in range(l_shape):
+        rep = reps[i + (l_reps - l_shape)]
+        inp = _tile_one_dim(inp, rep, i)
+
+    if l_reps > l_shape:
+        shape = inp.shape
+        extra = reps[:-l_shape]
+        extra_ones = ones_like(extra)
+        base_shape = concat([extra_ones, shape])
+        bcast_shape = concat([extra, shape])
+        target_shape = concat([extra, shape])
+        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
+
+    return inp
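Usage note: when len(reps) exceeds inp.ndim, the loop above tiles the existing axes and the l_reps > l_shape branch prepends the extra axes with one final broadcast. A hedged usage sketch (the expected output shape follows np.tile semantics, which the new tests check against):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(6, dtype=np.int32).reshape(2, 3))
    y = F.tile(x, (2, 1, 1))    # len(reps) == 3 > inp.ndim == 2
    print(y.numpy().shape)      # (2, 2, 3): one new leading axis of size 2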
@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
         cn = device._cn
         if isinstance(data, _Tensor):
-            if dtype is not None:
-                logger.warning(
-                    "dtype does not work when creating a new Tensor with another Tensor"
-                )
             obj = _Tensor.__new__(cls, data)
         else:
             if isinstance(data, np.ndarray):
@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
 }

+void TensorWrapper::unsetscalar() {
+    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
+}
+

 struct TensorWeakRef {
     std::weak_ptr<Tensor> wptr;
@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::dtype>("dtype")
         .def_getset<&TensorWrapper::device>("device")
         .def<&TensorWrapper::reset>("_reset")
-        .def<&TensorWrapper::isscalar>("isscalar")
-        .def<&TensorWrapper::setscalar>("setscalar")
+        .def<&TensorWrapper::isscalar>("_isscalar")
+        .def<&TensorWrapper::setscalar>("_setscalar")
+        .def<&TensorWrapper::unsetscalar>("_unsetscalar")
         .def<&TensorWrapper::detach>("detach")
         .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
         .def<&TensorWrapper::_swap_out>("_swap_out")
@@ -153,6 +153,7 @@ struct TensorWrapper {
     PyObject* detach();
     PyObject* isscalar();
     void setscalar();
+    void unsetscalar();
     PyObject* _dev_tensor();
     void _swap_in();
     void _swap_out();
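Review note: the new TensorWrapper::unsetscalar is what the 0-d branch of repeat above relies on: a 0-d input carries the SCALAR flag, and clearing it lets the reshape/broadcast pipeline treat the tensor as an ordinary array. A sketch of the flag round-trip through the renamed bindings (assuming, as the `()` test cases below suggest, that a 0-d tensor is constructed with the flag set):

    from megengine import tensor

    x = tensor(1.23)         # 0-d input; SCALAR flag assumed set on construction
    assert x._isscalar()
    x._unsetscalar()         # new binding: clear the flag before reshape/broadcast
    assert not x._isscalar()
    x._setscalar()           # restore the flag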
@@ -406,3 +406,53 @@ def test_copy_d2h():
 def test_copy_d2d():
     copy_test("gpu0", "gpu1")
     copy_test("gpu0:0", "gpu0:1")
+
+
+@pytest.mark.parametrize(
+    "shape, repeats, axis",
+    [
+        ((2,), 2, 0),
+        ((2, 3, 4, 5), 3, 0),
+        ((2, 3, 4, 5), 4, 3),
+        ((2,), 2, None),
+        ((2, 3, 4, 5), 3, None),
+        ((), 1, None),
+        ((), 10, None),
+    ],
+)
+def test_repeat(shape, repeats, axis):
+    def repeat_func(inp):
+        return F.repeat(inp=inp, repeats=repeats, axis=axis)
+
+    if shape != ():
+        cases = [
+            {"input": np.random.randn(*shape).astype("float32")},
+        ]
+    else:
+        cases = [{"input": np.array(1.23)}]
+
+    opr_test(
+        cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
+    )
+
+
+@pytest.mark.parametrize(
+    "shape, reps",
+    [
+        ((2,), (2,)),
+        ((2, 3, 4, 5), (1, 1, 1, 1)),
+        ((2, 3, 4, 5), (1, 2, 3, 4)),
+        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
+    ],
+)
+def test_tile(shape, reps):
+    def tile_func(inp):
+        return F.tile(inp=inp, reps=reps)
+
+    cases = [
+        {"input": np.random.randn(*shape).astype("float32")},
+    ]
+
+    opr_test(
+        cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
+    )
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+import itertools
 from tempfile import mkstemp

 import numpy as np
@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
         np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
         return out

-    for i in range(1):
+    for i in range(3):
         f(x, M)