Browse Source

feat(mge): remove add_axis and remove_axis

GitOrigin-RevId: 59611d43f9
release-1.1
Megvii Engine Team 4 years ago
parent
commit
1e2117f66d
6 changed files with 31 additions and 40 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +5
    -5
      imperative/python/megengine/functional/math.py
  3. +7
    -7
      imperative/python/megengine/functional/nn.py
  4. +15
    -24
      imperative/python/megengine/functional/tensor.py
  5. +1
    -1
      imperative/python/test/unit/core/test_autodiff.py
  6. +2
    -2
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 1
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -411,7 +411,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
If ``axis`` is a list of axises, reduce over all of them. If ``axis`` is a list of axises, reduce over all of them.


If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.remove_axis`).
If ``keepdims`` is ``True``, the shape of output tensor is the same as the input tensor, except in the dimension(s) ``axis`` where it is of size 1. Otherwise, ``axis`` is squeezed(see :meth:`~.functional.tensor.squeeze`).


Same for prod/mean/max/min. Same for prod/mean/max/min.




+ 5
- 5
imperative/python/megengine/functional/math.py View File

@@ -19,7 +19,7 @@ from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor from ..tensor import Tensor
from .elemwise import clip, exp, log, log1p from .elemwise import clip, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape
from .tensor import reshape, squeeze


__all__ = [ __all__ = [
"argmax", "argmax",
@@ -459,7 +459,7 @@ def argmin(
(inp,) = apply(op, inp) (inp,) = apply(op, inp)


if not keepdims: if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)


return inp return inp


@@ -471,7 +471,7 @@ def argmin(
op = builtin.Argmin(axis=axis) op = builtin.Argmin(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if not keepdims: if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result return result




@@ -517,7 +517,7 @@ def argmax(
(inp,) = apply(op, inp) (inp,) = apply(op, inp)


if not keepdims: if not keepdims:
inp = remove_axis(inp, ai)
inp = squeeze(inp, ai)


return inp return inp


@@ -529,7 +529,7 @@ def argmax(
op = builtin.Argmax(axis=axis) op = builtin.Argmax(axis=axis)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if not keepdims: if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result return result






+ 7
- 7
imperative/python/megengine/functional/nn.py View File

@@ -23,7 +23,7 @@ from .debug_param import get_conv_execution_strategy
from .distributed import all_reduce_sum from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
from .math import argsort, max, sum from .math import argsort, max, sum
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros
from .tensor import broadcast, concat, expand_dims, full, ones, reshape, squeeze, zeros
from .types import _pair, _pair_nonzero from .types import _pair, _pair_nonzero


__all__ = [ __all__ = [
@@ -542,7 +542,7 @@ def logsumexp(
if keepdims: if keepdims:
return max_value + log(sum(exp(inp - max_value), axis, keepdims)) return max_value + log(sum(exp(inp - max_value), axis, keepdims))
else: else:
return remove_axis(max_value, axis=None) + log(
return squeeze(max_value, axis=None) + log(
sum(exp(inp - max_value), axis, keepdims) sum(exp(inp - max_value), axis, keepdims)
) )


@@ -640,7 +640,7 @@ def batch_norm2d(
def expand_or_full(x, value): def expand_or_full(x, value):
if x is None: if x is None:
return full_value(value) return full_value(value)
return add_axis(x, [0, 2, 3])
return expand_dims(x, [0, 2, 3])


def make_full_if_none(x, value): def make_full_if_none(x, value):
if x is None: if x is None:
@@ -998,10 +998,10 @@ def matmul(
else: else:
if dim1 == 1: if dim1 == 1:
shp = (inp2.shape[1],) shp = (inp2.shape[1],)
inp1 = add_axis(inp1, 0)
inp1 = expand_dims(inp1, 0)
if dim2 == 1: if dim2 == 1:
shp = (inp1.shape[0],) shp = (inp1.shape[0],)
inp2 = add_axis(inp2, 1)
inp2 = expand_dims(inp2, 1)
op = builtin.MatrixMul( op = builtin.MatrixMul(
transposeA=transpose_a, transposeA=transpose_a,
transposeB=transpose_b, transposeB=transpose_b,
@@ -1135,7 +1135,7 @@ def interpolate(
align_corners = False align_corners = False


if mode == "LINEAR": if mode == "LINEAR":
inp = add_axis(inp, 3)
inp = expand_dims(inp, 3)


if inp.ndim != 4: if inp.ndim != 4:
raise ValueError("shape of input tensor must correspond to the operartion mode") raise ValueError("shape of input tensor must correspond to the operartion mode")
@@ -1452,7 +1452,7 @@ def indexing_one_hot(
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index) (result,) = apply(op, src, index)
if not keepdims: if not keepdims:
result = remove_axis(result, axis)
result = squeeze(result, axis)
return result return result






+ 15
- 24
imperative/python/megengine/functional/tensor.py View File

@@ -32,12 +32,12 @@ from ..tensor import Tensor
from .elemwise import ceil from .elemwise import ceil


__all__ = [ __all__ = [
"add_axis",
"arange", "arange",
"broadcast", "broadcast",
"concat", "concat",
"cond_take", "cond_take",
"dimshuffle", "dimshuffle",
"expand_dims",
"eye", "eye",
"flatten", "flatten",
"full", "full",
@@ -50,7 +50,6 @@ __all__ = [
"param_pack_concat", "param_pack_concat",
"param_pack_split", "param_pack_split",
"reshape", "reshape",
"remove_axis",
"split", "split",
"squeeze", "squeeze",
"stack", "stack",
@@ -160,7 +159,7 @@ def zeros_like(inp: Tensor) -> Tensor:
print(out.numpy()) print(out.numpy())


Outputs: Outputs:
.. testoutput:: .. testoutput::


[[0 0 0] [[0 0 0]
@@ -320,7 +319,7 @@ def stack(inps, axis=0, device=None):
if len(shapes) != 1: if len(shapes) != 1:
raise ValueError("All input tensors must have the same shape") raise ValueError("All input tensors must have the same shape")


inps = [add_axis(inp, axis=axis) for inp in inps]
inps = [expand_dims(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis, device=device) return concat(inps, axis=axis, device=device)




@@ -480,7 +479,7 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor:




def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor:
r"""Writes all values from the tensor source into input tensor
r"""Writes all values from the tensor source into input tensor
at the indices specified in the index tensor. at the indices specified in the index tensor.


For each value in source, its output index is specified by its index For each value in source, its output index is specified by its index
@@ -694,7 +693,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
Swaps shapes and strides according to given pattern. Swaps shapes and strides according to given pattern.


:param inp: input tensor. :param inp: input tensor.
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1,
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples:


* (``'x'``) -> make a 0d (scalar) into a 1d vector * (``'x'``) -> make a 0d (scalar) into a 1d vector
@@ -815,7 +814,7 @@ def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
return inp.reshape(*target_shape) return inp.reshape(*target_shape)




def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
r""" r"""
Adds dimension before given axis. Adds dimension before given axis.


@@ -832,7 +831,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
import megengine.functional as F import megengine.functional as F


x = tensor([1, 2]) x = tensor([1, 2])
out = F.add_axis(x, 0)
out = F.expand_dims(x, 0)
print(out.shape) print(out.shape)


Outputs: Outputs:
@@ -861,12 +860,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
return result return result




expand_dims = add_axis


def remove_axis(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
def squeeze(inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None) -> Tensor:
r""" r"""
Removes dimension of shape 1. Removes dimension of shape 1.


@@ -883,7 +877,7 @@ def remove_axis(
import megengine.functional as F import megengine.functional as F


x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1))
out = F.remove_axis(x, 3)
out = F.squeeze(x, 3)
print(out.shape) print(out.shape)


Outputs: Outputs:
@@ -896,9 +890,6 @@ def remove_axis(
return _remove_axis(inp, axis) return _remove_axis(inp, axis)




squeeze = remove_axis


def linspace( def linspace(
start: Union[int, float, Tensor], start: Union[int, float, Tensor],
stop: Union[int, float, Tensor], stop: Union[int, float, Tensor],
@@ -925,7 +916,7 @@ def linspace(
print(a.numpy()) print(a.numpy())


Outputs: Outputs:
.. testoutput:: .. testoutput::


[ 3. 4.75 6.5 8.25 10. ] [ 3. 4.75 6.5 8.25 10. ]
@@ -967,7 +958,7 @@ def arange(


a = F.arange(5) a = F.arange(5)
print(a.numpy()) print(a.numpy())
Outputs: Outputs:


Outputs: Outputs:
@@ -1018,9 +1009,9 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor:
b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)]) b, c = F.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy()) print(b.numpy())
print(c.numpy()) print(c.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::


[1] [1]
@@ -1059,9 +1050,9 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor:
offsets = tensor(offsets_val, np.int32) offsets = tensor(offsets_val, np.int32)
c = F.param_pack_concat([a, b], offsets, offsets_val) c = F.param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy()) print(c.numpy())
Outputs: Outputs:
.. testoutput:: .. testoutput::


[1 1 1 1 1 1 1 1 1 1] [1 1 1 1 1 1 1 1 1 1]


+ 1
- 1
imperative/python/test/unit/core/test_autodiff.py View File

@@ -306,7 +306,7 @@ def test_AxisAddRemove():
x = TensorWrapper(x_np) x = TensorWrapper(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.remove_axis(F.add_axis(x, 2), 0)
y = F.squeeze(F.expand_dims(x, 2), 0)


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(


+ 2
- 2
imperative/python/test/unit/functional/test_tensor.py View File

@@ -100,7 +100,7 @@ def test_squeeze():


for axis in [None, 3, -4, (3, -4)]: for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis) y = np.squeeze(x, axis)
yy = F.remove_axis(xx, axis)
yy = F.squeeze(xx, axis)
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




@@ -110,7 +110,7 @@ def test_expand_dims():


for axis in [2, -3, (3, -4), (1, -4)]: for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis) y = np.expand_dims(x, axis)
yy = F.add_axis(xx, axis)
yy = F.expand_dims(xx, axis)
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())






Loading…
Cancel
Save