|
@@ -399,8 +399,6 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: |
|
|
if len(inps) == 1: |
|
|
if len(inps) == 1: |
|
|
return inps[0] |
|
|
return inps[0] |
|
|
|
|
|
|
|
|
# FIXME: remove this convert_inputs |
|
|
|
|
|
inps = convert_inputs(*inps, device=device) |
|
|
|
|
|
if device is None: |
|
|
if device is None: |
|
|
device = get_device(inps) |
|
|
device = get_device(inps) |
|
|
device = as_device(device) |
|
|
device = as_device(device) |
|
@@ -1168,9 +1166,10 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None): |
|
|
bcast_shape.append(shape[axis + 1 :]) |
|
|
bcast_shape.append(shape[axis + 1 :]) |
|
|
target_shape.append(shape[axis + 1 :]) |
|
|
target_shape.append(shape[axis + 1 :]) |
|
|
|
|
|
|
|
|
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( |
|
|
|
|
|
concat(target_shape) |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
base_shape = astensor1d(base_shape) |
|
|
|
|
|
bcast_shape = astensor1d(bcast_shape) |
|
|
|
|
|
target_shape = astensor1d(target_shape) |
|
|
|
|
|
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1191,9 +1190,10 @@ def _tile_one_dim(inp, rep, axis): |
|
|
if axis + 1 <= max_axis: |
|
|
if axis + 1 <= max_axis: |
|
|
target_shape.append(shape[axis + 1 :]) |
|
|
target_shape.append(shape[axis + 1 :]) |
|
|
|
|
|
|
|
|
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape( |
|
|
|
|
|
concat(target_shape) |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
base_shape = astensor1d(base_shape) |
|
|
|
|
|
bcast_shape = astensor1d(bcast_shape) |
|
|
|
|
|
target_shape = astensor1d(target_shape) |
|
|
|
|
|
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape) |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|