Browse Source

feat(mge): remove add_axis and remove_axis

GitOrigin-RevId: 59611d43f9
release-1.1
Megvii Engine Team 4 years ago
parent
commit
1e2117f66d
6 changed files with 31 additions and 40 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +5
    -5
      imperative/python/megengine/functional/math.py
  3. +7
    -7
      imperative/python/megengine/functional/nn.py
  4. +15
    -24
      imperative/python/megengine/functional/tensor.py
  5. +1
    -1
      imperative/python/test/unit/core/test_autodiff.py
  6. +2
    -2
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them.

If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`).

Same for prod/mean/max/min.



+ 5
- 5
imperative/python/megengine/functional/math.py View File

@@ -19,7 +19,7 @@ from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape
from .tensor import reshape, squeeze

__all__ = [
"argmax",
@@ -459,7 +459,7 @@ def argmin(
(inp,) = apply(op, inp)

if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)

return inp

@@ -471,7 +471,7 @@ def argmin(
op = builtin.Argmin(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result


@@ -517,7 +517,7 @@ def argmax(
(inp,) = apply(op, inp)

if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)

return inp

@@ -529,7 +529,7 @@ def argmax(
op = builtin.Argmax(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result




+ 7
- 7
imperative/python/megengine/functional/nn.py View File

@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, sum
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros
from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
from .types import _pair, _pair_nonzero

__all__ = [
@@ -542,7 +542,7 @@ def logsumexp(
if keepdims:
return max_value + log(sum(exp(inp - max_value), axis, keepdims))
else:
return remove_axis(max_value, axis=None) + log(
return squeeze(max_value, axis=None) + log(
sum(exp(inp - max_value), axis, keepdims)
)

@@ -640,7 +640,7 @@ def batch_norm2d(
def expand_or_full(x, value):
if x is None:
return full_value(value)
return add_axis(x, [0, 2, 3])
return expand_dims(x, [0, 2, 3])

def make_full_if_none(x, value):
if x is None:
@@ -998,10 +998,10 @@ def matmul(
else:
if dim1 == 1:
shp = (inp2.shape[1],)
inp1 = add_axis(inp1, 0)
inp1 = expand_dims(inp1, 0)
if dim2 == 1:
shp = (inp1.shape[0],)
inp2 = add_axis(inp2, 1)
inp2 = expand_dims(inp2, 1)
op = builtin.MatrixMul(
transposeA=transpose_a,
transposeB=transpose_b,
@@ -1135,7 +1135,7 @@ def interpolate(
align_corners = False

if mode == "LINEAR":
inp = add_axis(inp, 3)
inp = expand_dims(inp, 3)

if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode")
@@ -1452,7 +1452,7 @@ def indexing_one_hot(
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result




+ 15
- 24
imperative/python/megengine/functional/tensor.py View File

@@ -32,12 +32,12 @@ from ..tensor import Tensor
from .elemwise import ceil

__all__ = [
"add_axis",
"arange",
"broadcast",
"concat",
"cond_take",
"dimshuffle",
"expand_dims",
"eye",
"flatten",
"full",
@@ -50,7 +50,6 @@ __all__ = [
"param_pack_concat",
"param_pack_split",
"reshape",
"remove_axis",
"split",
"squeeze",
"stack",
@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor:
print(out.numpy())

Outputs:
.. testoutput::

[[0 0 0]
@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None):
if len(shapes) != 1:
raise ValueError("All input tensors must have the same shape")

inps = [add_axis(inp, axis=axis) for inp in inps]
inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device)


@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:


def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
r"""Writes all values from the tensor source into input tensor
r"""Writes all values from the tensor source into input tensor
at the indices specified in the index tensor.

For each value in source, its output index is specified by its index
@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
Swaps shapes and strides according to given pattern.

:param inp: input tensor.
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:

* (``'x'``) -> make a 0d (scalar) into a 1d vector
@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
return inp.reshape(*target_shape)


def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
r"""
Adds dimension before given axis.

@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
import megengine.functional as F

x = tensor([1, 2])
out = F.add_axis(x, 0)
out = F.expand_dims(x, 0)
print(out.shape)

Outputs:
@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return result


expand_dims = add_axis


def remove_axis(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
r"""
Removes dimension of shape 1.

@@ -883,7 +877,7 @@ def remove_axis(
import megengine.functional as F

x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, 3)
out = F.squeeze(x, 3)
print(out.shape)

Outputs:
@@ -896,9 +890,6 @@ def remove_axis(
return _remove_axis(inp, axis)


squeeze = remove_axis


def linspace(
start: Union[int, float, Tensor],
stop: Union[int, float, Tensor],
@@ -925,7 +916,7 @@ def linspace(
print(a.numpy())

Outputs:
.. testoutput::

[ 3. 4.75 6.5 8.25 10. ]
@@ -967,7 +958,7 @@ def arange(

a = F.arange(5)
print(a.numpy())
Outputs:

Outputs:
@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
Outputs:
.. testoutput::

[1]
@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
offsets = tensor(offsets_val, np.int32)
c = F.param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs:
.. testoutput::

[1 1 1 1 1 1 1 1 1 1]


+ 1
- 1
imperative/python/test/unit/core/test_autodiff.py View File

@@ -306,7 +306,7 @@ def test_AxisAddRemove():
x = TensorWrapper(x_np)

grad = Grad().wrt(x, callback=save_to(x))
y = F.remove_axis(F.add_axis(x, 2), 0)
y = F.squeeze(F.expand_dims(x, 2), 0)

grad(y, F.ones_like(y))
np.testing.assert_equal(


+ 2
- 2
imperative/python/test/unit/functional/test_tensor.py View File

@@ -100,7 +100,7 @@ def test_squeeze():

for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis)
yy = F.remove_axis(xx, axis)
yy = F.squeeze(xx, axis)
np.testing.assert_equal(y, yy.numpy())


@@ -110,7 +110,7 @@ def test_expand_dims():

for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis)
yy = F.add_axis(xx, axis)
yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy())




Loading…
Cancel
Save