@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. | r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. | ||||
If ``axis`` is a list of axises, reduce over all of them. | If ``axis`` is a list of axises, reduce over all of them. | ||||
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`). | |||||
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`). | |||||
Same for prod/mean/max/min. | Same for prod/mean/max/min. | ||||
@@ -19,7 +19,7 @@ from ..core.tensor import utils | |||||
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | from ..core.tensor.core import TensorBase, TensorWrapperBase, apply | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .elemwise import clip, exp, log, log1p | from .elemwise import clip, exp, log, log1p | ||||
from .tensor import add_axis, remove_axis, reshape | |||||
from .tensor import reshape, squeeze | |||||
__all__ = [ | __all__ = [ | ||||
"argmax", | "argmax", | ||||
@@ -459,7 +459,7 @@ def argmin( | |||||
(inp,) = apply(op, inp) | (inp,) = apply(op, inp) | ||||
if not keepdims: | if not keepdims: | ||||
inp = remove_axis(inp, ai) | |||||
inp = squeeze(inp, ai) | |||||
return inp | return inp | ||||
@@ -471,7 +471,7 @@ def argmin( | |||||
op = builtin.Argmin(axis=axis) | op = builtin.Argmin(axis=axis) | ||||
(result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
if not keepdims: | if not keepdims: | ||||
result = remove_axis(result, axis) | |||||
result = squeeze(result, axis) | |||||
return result | return result | ||||
@@ -517,7 +517,7 @@ def argmax( | |||||
(inp,) = apply(op, inp) | (inp,) = apply(op, inp) | ||||
if not keepdims: | if not keepdims: | ||||
inp = remove_axis(inp, ai) | |||||
inp = squeeze(inp, ai) | |||||
return inp | return inp | ||||
@@ -529,7 +529,7 @@ def argmax( | |||||
op = builtin.Argmax(axis=axis) | op = builtin.Argmax(axis=axis) | ||||
(result,) = apply(op, inp) | (result,) = apply(op, inp) | ||||
if not keepdims: | if not keepdims: | ||||
result = remove_axis(result, axis) | |||||
result = squeeze(result, axis) | |||||
return result | return result | ||||
@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy | |||||
from .distributed import all_reduce_sum | from .distributed import all_reduce_sum | ||||
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | ||||
from .math import argsort, max, sum | from .math import argsort, max, sum | ||||
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros | |||||
from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros | |||||
from .types import _pair, _pair_nonzero | from .types import _pair, _pair_nonzero | ||||
__all__ = [ | __all__ = [ | ||||
@@ -542,7 +542,7 @@ def logsumexp( | |||||
if keepdims: | if keepdims: | ||||
return max_value + log(sum(exp(inp - max_value), axis, keepdims)) | return max_value + log(sum(exp(inp - max_value), axis, keepdims)) | ||||
else: | else: | ||||
return remove_axis(max_value, axis=None) + log( | |||||
return squeeze(max_value, axis=None) + log( | |||||
sum(exp(inp - max_value), axis, keepdims) | sum(exp(inp - max_value), axis, keepdims) | ||||
) | ) | ||||
@@ -640,7 +640,7 @@ def batch_norm2d( | |||||
def expand_or_full(x, value): | def expand_or_full(x, value): | ||||
if x is None: | if x is None: | ||||
return full_value(value) | return full_value(value) | ||||
return add_axis(x, [0, 2, 3]) | |||||
return expand_dims(x, [0, 2, 3]) | |||||
def make_full_if_none(x, value): | def make_full_if_none(x, value): | ||||
if x is None: | if x is None: | ||||
@@ -998,10 +998,10 @@ def matmul( | |||||
else: | else: | ||||
if dim1 == 1: | if dim1 == 1: | ||||
shp = (inp2.shape[1],) | shp = (inp2.shape[1],) | ||||
inp1 = add_axis(inp1, 0) | |||||
inp1 = expand_dims(inp1, 0) | |||||
if dim2 == 1: | if dim2 == 1: | ||||
shp = (inp1.shape[0],) | shp = (inp1.shape[0],) | ||||
inp2 = add_axis(inp2, 1) | |||||
inp2 = expand_dims(inp2, 1) | |||||
op = builtin.MatrixMul( | op = builtin.MatrixMul( | ||||
transposeA=transpose_a, | transposeA=transpose_a, | ||||
transposeB=transpose_b, | transposeB=transpose_b, | ||||
@@ -1135,7 +1135,7 @@ def interpolate( | |||||
align_corners = False | align_corners = False | ||||
if mode == "LINEAR": | if mode == "LINEAR": | ||||
inp = add_axis(inp, 3) | |||||
inp = expand_dims(inp, 3) | |||||
if inp.ndim != 4: | if inp.ndim != 4: | ||||
raise ValueError("shape of input tensor must correspond to the operartion mode") | raise ValueError("shape of input tensor must correspond to the operartion mode") | ||||
@@ -1452,7 +1452,7 @@ def indexing_one_hot( | |||||
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | ||||
(result,) = apply(op, src, index) | (result,) = apply(op, src, index) | ||||
if not keepdims: | if not keepdims: | ||||
result = remove_axis(result, axis) | |||||
result = squeeze(result, axis) | |||||
return result | return result | ||||
@@ -32,12 +32,12 @@ from ..tensor import Tensor | |||||
from .elemwise import ceil | from .elemwise import ceil | ||||
__all__ = [ | __all__ = [ | ||||
"add_axis", | |||||
"arange", | "arange", | ||||
"broadcast", | "broadcast", | ||||
"concat", | "concat", | ||||
"cond_take", | "cond_take", | ||||
"dimshuffle", | "dimshuffle", | ||||
"expand_dims", | |||||
"eye", | "eye", | ||||
"flatten", | "flatten", | ||||
"full", | "full", | ||||
@@ -50,7 +50,6 @@ __all__ = [ | |||||
"param_pack_concat", | "param_pack_concat", | ||||
"param_pack_split", | "param_pack_split", | ||||
"reshape", | "reshape", | ||||
"remove_axis", | |||||
"split", | "split", | ||||
"squeeze", | "squeeze", | ||||
"stack", | "stack", | ||||
@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor: | |||||
print(out.numpy()) | print(out.numpy()) | ||||
Outputs: | Outputs: | ||||
.. testoutput:: | .. testoutput:: | ||||
[[0 0 0] | [[0 0 0] | ||||
@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None): | |||||
if len(shapes) != 1: | if len(shapes) != 1: | ||||
raise ValueError("All input tensors must have the same shape") | raise ValueError("All input tensors must have the same shape") | ||||
inps = [add_axis(inp, axis=axis) for inp in inps] | |||||
inps = [expand_dims(inp, axis=axis) for inp in inps] | |||||
return concat(inps, axis=axis, device=device) | return concat(inps, axis=axis, device=device) | ||||
@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||||
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | ||||
r"""Writes all values from the tensor source into input tensor | |||||
r"""Writes all values from the tensor source into input tensor | |||||
at the indices specified in the index tensor. | at the indices specified in the index tensor. | ||||
For each value in source, its output index is specified by its index | For each value in source, its output index is specified by its index | ||||
@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||||
Swaps shapes and strides according to given pattern. | Swaps shapes and strides according to given pattern. | ||||
:param inp: input tensor. | :param inp: input tensor. | ||||
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, | |||||
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, | |||||
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | ||||
* (``'x'``) -> make a 0d (scalar) into a 1d vector | * (``'x'``) -> make a 0d (scalar) into a 1d vector | ||||
@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: | |||||
return inp.reshape(*target_shape) | return inp.reshape(*target_shape) | ||||
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
r""" | r""" | ||||
Adds dimension before given axis. | Adds dimension before given axis. | ||||
@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
import megengine.functional as F | import megengine.functional as F | ||||
x = tensor([1, 2]) | x = tensor([1, 2]) | ||||
out = F.add_axis(x, 0) | |||||
out = F.expand_dims(x, 0) | |||||
print(out.shape) | print(out.shape) | ||||
Outputs: | Outputs: | ||||
@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||||
return result | return result | ||||
expand_dims = add_axis | |||||
def remove_axis( | |||||
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None | |||||
) -> Tensor: | |||||
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor: | |||||
r""" | r""" | ||||
Removes dimension of shape 1. | Removes dimension of shape 1. | ||||
@@ -883,7 +877,7 @@ def remove_axis( | |||||
import megengine.functional as F | import megengine.functional as F | ||||
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | ||||
out = F.remove_axis(x, 3) | |||||
out = F.squeeze(x, 3) | |||||
print(out.shape) | print(out.shape) | ||||
Outputs: | Outputs: | ||||
@@ -896,9 +890,6 @@ def remove_axis( | |||||
return _remove_axis(inp, axis) | return _remove_axis(inp, axis) | ||||
squeeze = remove_axis | |||||
def linspace( | def linspace( | ||||
start: Union[int, float, Tensor], | start: Union[int, float, Tensor], | ||||
stop: Union[int, float, Tensor], | stop: Union[int, float, Tensor], | ||||
@@ -925,7 +916,7 @@ def linspace( | |||||
print(a.numpy()) | print(a.numpy()) | ||||
Outputs: | Outputs: | ||||
.. testoutput:: | .. testoutput:: | ||||
[ 3. 4.75 6.5 8.25 10. ] | [ 3. 4.75 6.5 8.25 10. ] | ||||
@@ -967,7 +958,7 @@ def arange( | |||||
a = F.arange(5) | a = F.arange(5) | ||||
print(a.numpy()) | print(a.numpy()) | ||||
Outputs: | Outputs: | ||||
Outputs: | Outputs: | ||||
@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: | |||||
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) | b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) | ||||
print(b.numpy()) | print(b.numpy()) | ||||
print(c.numpy()) | print(c.numpy()) | ||||
Outputs: | Outputs: | ||||
.. testoutput:: | .. testoutput:: | ||||
[1] | [1] | ||||
@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: | |||||
offsets = tensor(offsets_val, np.int32) | offsets = tensor(offsets_val, np.int32) | ||||
c = F.param_pack_concat([a, b], offsets, offsets_val) | c = F.param_pack_concat([a, b], offsets, offsets_val) | ||||
print(c.numpy()) | print(c.numpy()) | ||||
Outputs: | Outputs: | ||||
.. testoutput:: | .. testoutput:: | ||||
[1 1 1 1 1 1 1 1 1 1] | [1 1 1 1 1 1 1 1 1 1] | ||||
@@ -306,7 +306,7 @@ def test_AxisAddRemove(): | |||||
x = TensorWrapper(x_np) | x = TensorWrapper(x_np) | ||||
grad = Grad().wrt(x, callback=save_to(x)) | grad = Grad().wrt(x, callback=save_to(x)) | ||||
y = F.remove_axis(F.add_axis(x, 2), 0) | |||||
y = F.squeeze(F.expand_dims(x, 2), 0) | |||||
grad(y, F.ones_like(y)) | grad(y, F.ones_like(y)) | ||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
@@ -100,7 +100,7 @@ def test_squeeze(): | |||||
for axis in [None, 3, -4, (3, -4)]: | for axis in [None, 3, -4, (3, -4)]: | ||||
y = np.squeeze(x, axis) | y = np.squeeze(x, axis) | ||||
yy = F.remove_axis(xx, axis) | |||||
yy = F.squeeze(xx, axis) | |||||
np.testing.assert_equal(y, yy.numpy()) | np.testing.assert_equal(y, yy.numpy()) | ||||
@@ -110,7 +110,7 @@ def test_expand_dims(): | |||||
for axis in [2, -3, (3, -4), (1, -4)]: | for axis in [2, -3, (3, -4), (1, -4)]: | ||||
y = np.expand_dims(x, axis) | y = np.expand_dims(x, axis) | ||||
yy = F.add_axis(xx, axis) | |||||
yy = F.expand_dims(xx, axis) | |||||
np.testing.assert_equal(y, yy.numpy()) | np.testing.assert_equal(y, yy.numpy()) | ||||