From a398d4b5a5d4e2e1b11914b8a578dd948593444e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Sep 2020 00:16:15 +0800 Subject: [PATCH] chore(mge/functional): remove duplicated code GitOrigin-RevId: f9efea46cb996b1583d8b8d038a9edf21d7ac83c --- .../python/megengine/core/tensor/tensor_wrapper.py | 3 +- imperative/python/megengine/functional/elemwise.py | 7 --- imperative/python/megengine/functional/tensor.py | 51 +++------------------- 3 files changed, 7 insertions(+), 54 deletions(-) diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index d32eb9c3..afe180b3 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -210,7 +210,8 @@ def _todo(*_): def _expand_args(args): if len(args) == 1: if isinstance( - args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase) + args[0], + (collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray), ): args = args[0] return args diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3781ae64..ea6d001a 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -88,13 +88,6 @@ def _elwise(*args, mode): return result -def _logical(*args, mode): - op = builtin.CondExecPredLogical(mode=mode) - args = utils.convert_inputs(*args) - (result,) = apply(op, *args) - return result - - def _elemwise_multi_type(*args, mode, **kwargs): op = builtin.ElemwiseMultiType(mode=mode, **kwargs) args = utils.convert_inputs(*args) diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 31053b78..4cdebc15 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -19,6 +19,7 @@ from ..core.ops import builtin from ..core.ops._internal import param_defs as P from ..core.ops.special import Const from ..core.tensor.core import TensorBase, TensorWrapperBase, apply +from ..core.tensor.tensor_wrapper import _remove_axis from ..core.tensor.utils import ( astensor1d, convert_inputs, @@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: [3. 4. 5.]]] """ - shape = astensor1d(shape, inp, dtype="int32", device=inp.device) - (result,) = apply(builtin.Broadcast(), inp, shape) - return result + return inp.broadcast(shape) def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: @@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: [1 0]] """ - op = builtin.Dimshuffle(pattern) - (inp,) = convert_inputs(inp) - (result,) = apply(op, inp) - return result + return inp.transpose(pattern) dimshuffle = transpose @@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: [10 11]]] """ - if isinstance(target_shape, (TensorBase, TensorWrapperBase)): - target_shape = target_shape.numpy() - target_shape = tuple(map(int, target_shape)) - unspec_axis = None - for i, s in enumerate(target_shape): - if s < 0: - if s != -1: - raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) - if unspec_axis is not None: - raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) - unspec_axis = i - - # TODO: device should be None (cpu) - (target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp) - if unspec_axis is None: - op = builtin.Reshape() - else: - op = builtin.Reshape(unspec_axis=unspec_axis) - (x,) = apply(op, inp, target_shape) - return x + return inp.reshape(target_shape) AxisAddRemove = builtin.AxisAddRemove @@ -915,25 +892,7 @@ def remove_axis( (1, 1, 2) """ - Param = builtin.AxisAddRemove.Param - - def get_axes(): - if axis is None: - return [i for i, s in enumerate(inp.shape) if s == 1] - try: - return [int(axis)] - except (TypeError, ValueError): - pass - return list(map(int, axis)) - - axis = get_axes() - axis = sorted(i + inp.ndim if i < 0 else i for i in axis) - axis = [a - i for i, a in enumerate(axis)] - - param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) - op = builtin.AxisAddRemove(param=param) - (result,) = apply(op, inp) - return result + return _remove_axis(inp, axis) def linspace(