|
|
@@ -19,6 +19,7 @@ from ..core.ops import builtin |
|
|
|
from ..core.ops._internal import param_defs as P |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply |
|
|
|
from ..core.tensor.tensor_wrapper import _remove_axis |
|
|
|
from ..core.tensor.utils import ( |
|
|
|
astensor1d, |
|
|
|
convert_inputs, |
|
|
@@ -231,9 +232,7 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: |
|
|
|
[3. 4. 5.]]] |
|
|
|
|
|
|
|
""" |
|
|
|
shape = astensor1d(shape, inp, dtype="int32", device=inp.device) |
|
|
|
(result,) = apply(builtin.Broadcast(), inp, shape) |
|
|
|
return result |
|
|
|
return inp.broadcast(shape) |
|
|
|
|
|
|
|
|
|
|
|
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: |
|
|
@@ -730,10 +729,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: |
|
|
|
[1 0]] |
|
|
|
|
|
|
|
""" |
|
|
|
op = builtin.Dimshuffle(pattern) |
|
|
|
(inp,) = convert_inputs(inp) |
|
|
|
(result,) = apply(op, inp) |
|
|
|
return result |
|
|
|
return inp.transpose(pattern) |
|
|
|
|
|
|
|
|
|
|
|
dimshuffle = transpose |
|
|
@@ -773,26 +769,7 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: |
|
|
|
[10 11]]] |
|
|
|
|
|
|
|
""" |
|
|
|
if isinstance(target_shape, (TensorBase, TensorWrapperBase)): |
|
|
|
target_shape = target_shape.numpy() |
|
|
|
target_shape = tuple(map(int, target_shape)) |
|
|
|
unspec_axis = None |
|
|
|
for i, s in enumerate(target_shape): |
|
|
|
if s < 0: |
|
|
|
if s != -1: |
|
|
|
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) |
|
|
|
if unspec_axis is not None: |
|
|
|
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) |
|
|
|
unspec_axis = i |
|
|
|
|
|
|
|
# TODO: device should be None (cpu) |
|
|
|
(target_shape,) = Const(target_shape, dtype="int32", device=inp.device)(inp) |
|
|
|
if unspec_axis is None: |
|
|
|
op = builtin.Reshape() |
|
|
|
else: |
|
|
|
op = builtin.Reshape(unspec_axis=unspec_axis) |
|
|
|
(x,) = apply(op, inp, target_shape) |
|
|
|
return x |
|
|
|
return inp.reshape(target_shape) |
|
|
|
|
|
|
|
|
|
|
|
AxisAddRemove = builtin.AxisAddRemove |
|
|
@@ -915,25 +892,7 @@ def remove_axis( |
|
|
|
(1, 1, 2) |
|
|
|
|
|
|
|
""" |
|
|
|
Param = builtin.AxisAddRemove.Param |
|
|
|
|
|
|
|
def get_axes(): |
|
|
|
if axis is None: |
|
|
|
return [i for i, s in enumerate(inp.shape) if s == 1] |
|
|
|
try: |
|
|
|
return [int(axis)] |
|
|
|
except (TypeError, ValueError): |
|
|
|
pass |
|
|
|
return list(map(int, axis)) |
|
|
|
|
|
|
|
axis = get_axes() |
|
|
|
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) |
|
|
|
axis = [a - i for i, a in enumerate(axis)] |
|
|
|
|
|
|
|
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) |
|
|
|
op = builtin.AxisAddRemove(param=param) |
|
|
|
(result,) = apply(op, inp) |
|
|
|
return result |
|
|
|
return _remove_axis(inp, axis) |
|
|
|
|
|
|
|
|
|
|
|
def linspace( |
|
|
|