
chore(mge/functional): remove duplicated code

GitOrigin-RevId: f9efea46cb
release-1.1
Megvii Engine Team, 4 years ago
Parent commit: a398d4b5a5
3 changed files with 7 additions and 54 deletions:

  1. imperative/python/megengine/core/tensor/tensor_wrapper.py  (+2, -1)
  2. imperative/python/megengine/functional/elemwise.py  (+0, -7)
  3. imperative/python/megengine/functional/tensor.py  (+5, -46)

imperative/python/megengine/core/tensor/tensor_wrapper.py  (+2, -1)

@@ -210,7 +210,8 @@ def _todo(*_):
 def _expand_args(args):
     if len(args) == 1:
         if isinstance(
-            args[0], (collections.abc.Sequence, TensorBase, TensorWrapperBase)
+            args[0],
+            (collections.abc.Sequence, TensorBase, TensorWrapperBase, np.ndarray),
         ):
             args = args[0]
     return args
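
For orientation, a minimal sketch of what the widened isinstance check enables. This is a self-contained toy version: the real helper also accepts TensorBase/TensorWrapperBase, which are dropped here so the snippet runs with only numpy installed.

    import collections.abc
    import numpy as np

    def _expand_args(args):
        # A single sequence-like argument is unpacked, so passing an
        # np.ndarray of axes/shape now behaves the same as a list or tuple.
        if len(args) == 1 and isinstance(args[0], (collections.abc.Sequence, np.ndarray)):
            args = args[0]
        return args

    print(_expand_args((np.array([2, 3]),)))  # [2 3] -- the array itself, not a 1-tuple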


imperative/python/megengine/functional/elemwise.py  (+0, -7)

@@ -88,13 +88,6 @@ def _elwise(*args, mode):
     return result


-def _logical(*args, mode):
-    op = builtin.CondExecPredLogical(mode=mode)
-    args = utils.convert_inputs(*args)
-    (result,) = apply(op, *args)
-    return result
-
-
 def _elemwise_multi_type(*args, mode, **kwargs):
     op = builtin.ElemwiseMultiType(mode=mode, **kwargs)
     args = utils.convert_inputs(*args)


imperative/python/megengine/functional/tensor.py  (+5, -46)

@@ -19,6 +19,7 @@ from ..core.ops import builtin
 from ..core.ops._internal import param_defs as P
 from ..core.ops.special import Const
 from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
+from ..core.tensor.tensor_wrapper import _remove_axis
 from ..core.tensor.utils import (
     astensor1d,
     convert_inputs,
@@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
          [3. 4. 5.]]]

     """
-    shape = astensor1d(shape, inp, dtype="int32", device=inp.device)
-    (result,) = apply(builtin.Broadcast(), inp, shape)
-    return result
+    return inp.broadcast(shape)


 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
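
The functional wrapper now simply delegates to the tensor method. A small illustrative call (values and shapes are made up, not taken from the docstring above; assumes megengine is installed):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.arange(3, dtype="float32"))  # shape (3,)
    y = F.broadcast(x, (2, 1, 3))                  # now just inp.broadcast(shape)
    print(y.numpy().shape)                         # (2, 1, 3)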
@@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
       [1 0]]

     """
-    op = builtin.Dimshuffle(pattern)
-    (inp,) = convert_inputs(inp)
-    (result,) = apply(op, inp)
-    return result
+    return inp.transpose(pattern)


 dimshuffle = transpose
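
Likewise for transpose (and its dimshuffle alias); a brief illustrative call with made-up data:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.arange(6, dtype="int32").reshape(2, 3))
    y = F.transpose(x, (1, 0))    # now just inp.transpose(pattern)
    print(y.numpy().shape)        # (3, 2)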
@@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
        [10 11]]]

     """
-    if isinstance(target_shape, (TensorBase, TensorWrapperBase)):
-        target_shape = target_shape.numpy()
-    target_shape = tuple(map(int, target_shape))
-    unspec_axis = None
-    for i, s in enumerate(target_shape):
-        if s < 0:
-            if s != -1:
-                raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
-            if unspec_axis is not None:
-                raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
-            unspec_axis = i
-
-    # TODO: device should be None (cpu)
-    (target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp)
-    if unspec_axis is None:
-        op = builtin.Reshape()
-    else:
-        op = builtin.Reshape(unspec_axis=unspec_axis)
-    (x,) = apply(op, inp, target_shape)
-    return x
+    return inp.reshape(target_shape)


 AxisAddRemove = builtin.AxisAddRemove
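
The -1 ("infer this axis") handling that was open-coded here is presumably covered by inp.reshape now; an illustrative call with made-up data:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.arange(12, dtype="int32"))
    y = F.reshape(x, (2, -1, 2))  # now just inp.reshape(target_shape)
    print(y.numpy().shape)        # (2, 3, 2)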
@@ -915,25 +892,7 @@ def remove_axis(
     (1, 1, 2)

     """
-    Param = builtin.AxisAddRemove.Param
-
-    def get_axes():
-        if axis is None:
-            return [i for i, s in enumerate(inp.shape) if s == 1]
-        try:
-            return [int(axis)]
-        except (TypeError, ValueError):
-            pass
-        return list(map(int, axis))
-
-    axis = get_axes()
-    axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
-    axis = [a - i for i, a in enumerate(axis)]
-
-    param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis))
-    op = builtin.AxisAddRemove(param=param)
-    (result,) = apply(op, inp)
-    return result
+    return _remove_axis(inp, axis)


 def linspace(
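
remove_axis now forwards to the shared _remove_axis helper imported above; an illustrative call with made-up data (assumes the public behavior is unchanged by the refactor):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.zeros((1, 1, 2), dtype="float32"))
    y = F.remove_axis(x, 0)       # now just _remove_axis(inp, axis)
    print(y.numpy().shape)        # (1, 2)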

