@@ -279,8 +279,8 @@ class GradManager:
                     tensor.grad = grad
                 else:
                     tensor.grad += grad
-                if tensor.isscalar() and tensor.grad is not None:
-                    tensor.grad.setscalar()
+                if tensor._isscalar() and tensor.grad is not None:
+                    tensor.grad._setscalar()
         finally:
             self.release()
         backwarding_grad_manager = cache
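This hunk is part of a repo-wide rename: the scalar-flag accessors move from the public `isscalar`/`setscalar` to the underscore-prefixed `_isscalar`/`_setscalar` on the C++-backed tensor wrapper. The same save/restore pattern recurs in the hunks below (`getitem`, `astype`, `param_pack_split`, `_inplace_add_`): since most ops return non-scalar results, the flag is queried before an op and restored on the result afterwards. A minimal sketch of that pattern, with a hypothetical helper name (not part of this diff; `apply` and the underscore methods come from the PR):

    def apply_preserving_scalar(op, x):
        was_scalar = x._isscalar()  # query the flag before the op
        (y,) = apply(op, x)         # op results default to non-scalar
        if was_scalar:
            y._setscalar()          # restore the flag on the result
        return y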
@@ -225,7 +225,7 @@ def getitem(tensor, index):
     op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
     if ret_scalar:
-        result.setscalar()
+        result._setscalar()
     return result
@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_dtype_equal(x.dtype, dtype):
-        isscalar = x.isscalar()
+        isscalar = x._isscalar()
         (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
         if isscalar:
-            x.setscalar()
+            x._setscalar()
     return x
@@ -98,14 +98,14 @@ def result_type(*args):
 def isscalar(x):
     if isinstance(x, Tensor):
-        return x.isscalar()
+        return x._isscalar()
     return np.isscalar(x)


 def setscalar(x):
     if isinstance(x, Tensor):
-        x.setscalar()
+        x._setscalar()
     else:
         raise NotImplementedError("Unsupport type {}".format(type(x)))
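The module-level `isscalar`/`setscalar` wrappers keep a usable Python-facing API after the Tensor methods go private, and they also accept plain Python and NumPy values. A small usage sketch, with behavior inferred from the hunk above:

    import numpy as np

    print(isscalar(3.14))           # True, via np.isscalar
    print(isscalar(np.arange(3)))   # False
    # For a megengine Tensor t, isscalar(t) dispatches to t._isscalar(),
    # and setscalar(t) to t._setscalar(); setscalar raises
    # NotImplementedError for any non-Tensor argument.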
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     outputs = apply(op, inp)
     for s, x in zip(shapes, outputs):
         if not s:
-            x.setscalar()
+            x._setscalar()
     return outputs
@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd
 def _inplace_add_(dest, delta, alpha, beta):
-    isscalar = dest.isscalar()
+    isscalar = dest._isscalar()
     dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
     if isscalar:
-        dest.setscalar()
+        dest._setscalar()
     return dest
@@ -44,11 +44,13 @@ __all__ = [
     "linspace",
     "ones",
     "ones_like",
+    "repeat",
     "reshape",
     "split",
     "squeeze",
     "stack",
     "scatter",
+    "tile",
     "transpose",
     "where",
     "zeros",
@@ -987,3 +989,144 @@ def arange(
     if np.dtype(dtype) == np.int32:
         return result.astype(dtype)
     return result
+
+
+def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
+    """
+    Repeat elements of an array.
+
+    :param inp: input tensor.
+    :param repeats: the number of repetitions for each element.
+    :param axis: the axis along which to repeat values. By default, use the
+        flattened input array, and return a flat output array.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.repeat(x, 2, axis=0)
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [1 2]
+         [3 4]
+         [3 4]]
+
+    """
+    if axis is None:
+        inp = inp.reshape(-1)  # flatten
+        axis = 0
+    if inp._isscalar():
+        inp._unsetscalar()
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+    assert axis >= 0 and axis <= max_axis
+    assert repeats >= 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        target_shape.append(shape[:axis])
+    base_shape.extend([shape[: axis + 1], [1,]])
+    bcast_shape.extend([shape[: axis + 1], [repeats,]])
+    target_shape.extend(
+        [shape[axis] * repeats,]
+    )
+    if axis + 1 <= max_axis:
+        base_shape.append(shape[axis + 1 :])
+        bcast_shape.append(shape[axis + 1 :])
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
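`repeat` never copies element by element; it relies on a reshape/broadcast/reshape trick: insert a length-1 axis right after `axis`, broadcast it up to `repeats`, then fold it back in. A NumPy sketch of the same three steps for the docstring's example (axis=0, repeats=2 on a (2, 2) input):

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype=np.int32)
    base = x.reshape(2, 1, 2)                 # shape[: axis + 1] + [1] + trailing dims
    bcast = np.broadcast_to(base, (2, 2, 2))  # length-1 axis broadcast to repeats
    out = bcast.reshape(4, 2)                 # shape[axis] * repeats, trailing dims
    assert (out == np.repeat(x, 2, axis=0)).all()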
+
+
+def _tile_one_dim(inp, rep, axis):
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        base_shape.append(shape[:axis])
+        bcast_shape.append(shape[:axis])
+        target_shape.append(shape[:axis])
+    base_shape.extend([[1,], shape[axis:]])
+    bcast_shape.extend([rep, shape[axis:]])
+    target_shape.append(shape[axis] * rep)
+    if axis + 1 <= max_axis:
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
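`_tile_one_dim` uses the same trick as `repeat`, but inserts the length-1 axis *before* the block starting at `axis`, so whole blocks repeat (tile semantics) rather than individual elements (repeat semantics). A NumPy sketch for axis=0, rep=2 on a (2, 2) input:

    import numpy as np

    x = np.array([[1, 2], [3, 4]], dtype=np.int32)
    base = x.reshape(1, 2, 2)                 # [1] + shape[axis:]
    bcast = np.broadcast_to(base, (2, 2, 2))  # [rep] + shape[axis:]
    out = bcast.reshape(4, 2)                 # shape[axis] * rep, trailing dims
    assert (out == np.tile(x, (2, 1))).all()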
+
+
+def tile(inp: Tensor, reps: Iterable[int]):
+    """
+    Construct an array by repeating ``inp`` the number of times given by ``reps``.
+    If ``reps`` has length ``d``, the result will have dimension of ``max(d, inp.ndim)``.
+    It is required that ``d >= inp.ndim``. If ``inp.ndim < d``, ``inp`` is promoted
+    to be ``d``-dimensional by prepending new axes.
+
+    :param inp: input tensor.
+    :param reps: the number of repetitions of ``inp`` along each axis.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.tile(x, (2, 1))
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [3 4]
+         [1 2]
+         [3 4]]
+
+    """
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
+    l_shape = len(shape)
+    l_reps = len(reps)
+    assert (
+        l_reps >= l_shape
+    ), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
+
+    for i in range(l_shape):
+        rep = reps[i + (l_reps - l_shape)]
+        inp = _tile_one_dim(inp, rep, i)
+
+    if l_reps > l_shape:
+        shape = inp.shape
+        extra = reps[:-l_shape]
+        extra_ones = ones_like(extra)
+        base_shape = concat([extra_ones, shape])
+        bcast_shape = concat([extra, shape])
+        target_shape = concat([extra, shape])
+        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
+
+    return inp
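When `len(reps) > inp.ndim`, the per-axis loop consumes the trailing entries of `reps`, and the final `broadcast_to` block handles the extra leading ones. A hedged usage sketch, with the output inferred from the matching `np.tile` behavior that the tests below check against:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor([1, 2], np.int32)   # ndim 1
    y = F.tile(x, (2, 3))          # promoted to 2-D; shape (2, 6)
    print(y.numpy())
    # [[1 2 1 2 1 2]
    #  [1 2 1 2 1 2]]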
@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
             cn = device._cn

         if isinstance(data, _Tensor):
-            if dtype is not None:
-                logger.warning(
-                    "dtype does not work when creating a new Tensor with another Tensor"
-                )
             obj = _Tensor.__new__(cls, data)
         else:
             if isinstance(data, np.ndarray):
@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
 }

+void TensorWrapper::unsetscalar() {
+    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
+}
+
 struct TensorWeakRef {
     std::weak_ptr<Tensor> wptr;
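`unsetscalar` clears just the SCALAR bit with the usual AND-with-complement idiom, leaving any other flag bits intact. The same trick in Python, with a hypothetical flag value (the real enum lives in the C++ `Tensor::Flags`):

    SCALAR = 1 << 0          # hypothetical bit position, for illustration only
    flags = 0b101            # SCALAR plus some other flag
    flags &= ~SCALAR
    assert flags == 0b100    # other flag untouched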
@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::dtype>("dtype")
         .def_getset<&TensorWrapper::device>("device")
         .def<&TensorWrapper::reset>("_reset")
-        .def<&TensorWrapper::isscalar>("isscalar")
-        .def<&TensorWrapper::setscalar>("setscalar")
+        .def<&TensorWrapper::isscalar>("_isscalar")
+        .def<&TensorWrapper::setscalar>("_setscalar")
+        .def<&TensorWrapper::unsetscalar>("_unsetscalar")
         .def<&TensorWrapper::detach>("detach")
         .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
         .def<&TensorWrapper::_swap_out>("_swap_out")
@@ -153,6 +153,7 @@ struct TensorWrapper {
     PyObject* detach();
     PyObject* isscalar();
     void setscalar();
+    void unsetscalar();
     PyObject* _dev_tensor();
     void _swap_in();
     void _swap_out();
@@ -406,3 +406,53 @@ def test_copy_d2h():
 def test_copy_d2d():
     copy_test("gpu0", "gpu1")
     copy_test("gpu0:0", "gpu0:1")
+
+
+@pytest.mark.parametrize(
+    "shape, repeats, axis",
+    [
+        ((2,), 2, 0),
+        ((2, 3, 4, 5), 3, 0),
+        ((2, 3, 4, 5), 4, 3),
+        ((2,), 2, None),
+        ((2, 3, 4, 5), 3, None),
+        ((), 1, None),
+        ((), 10, None),
+    ],
+)
+def test_repeat(shape, repeats, axis):
+    def repeat_func(inp):
+        return F.repeat(inp=inp, repeats=repeats, axis=axis)
+
+    if shape != ():
+        cases = [
+            {"input": np.random.randn(*shape).astype("float32")},
+        ]
+    else:
+        cases = [{"input": np.array(1.23)}]
+
+    opr_test(
+        cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
+    )
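The shape-`()` cases above are the reason `repeat` needs `_unsetscalar`: NumPy flattens a 0-d input to shape `(repeats,)`, so the MegEngine result must stop carrying the scalar flag. For reference:

    import numpy as np

    print(np.repeat(np.array(1.23), 3))  # [1.23 1.23 1.23], shape (3,)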
+
+
+@pytest.mark.parametrize(
+    "shape, reps",
+    [
+        ((2,), (2,)),
+        ((2, 3, 4, 5), (1, 1, 1, 1)),
+        ((2, 3, 4, 5), (1, 2, 3, 4)),
+        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
+    ],
+)
+def test_tile(shape, reps):
+    def tile_func(inp):
+        return F.tile(inp=inp, reps=reps)
+
+    cases = [
+        {"input": np.random.randn(*shape).astype("float32")},
+    ]
+
+    opr_test(
+        cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
+    )
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+import itertools
 from tempfile import mkstemp

 import numpy as np
@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
         np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
         return out

-    for i in range(1):
+    for i in range(3):
         f(x, M)