Browse Source

fix(imperative): remove convert_inputs from concat

GitOrigin-RevId: 1511cb4b43
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
69673f14f7
1 changed files with 8 additions and 8 deletions
  1. +8
    -8
      imperative/python/megengine/functional/tensor.py

+ 8
- 8
imperative/python/megengine/functional/tensor.py View File

@@ -399,8 +399,6 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
if len(inps) == 1: if len(inps) == 1:
return inps[0] return inps[0]


# FIXME: remove this convert_inputs
inps = convert_inputs(*inps, device=device)
if device is None: if device is None:
device = get_device(inps) device = get_device(inps)
device = as_device(device) device = as_device(device)
@@ -1168,9 +1166,10 @@ def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
bcast_shape.append(shape[axis + 1 :]) bcast_shape.append(shape[axis + 1 :])
target_shape.append(shape[axis + 1 :]) target_shape.append(shape[axis + 1 :])


out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
base_shape = astensor1d(base_shape)
bcast_shape = astensor1d(bcast_shape)
target_shape = astensor1d(target_shape)
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
return out return out




@@ -1191,9 +1190,10 @@ def _tile_one_dim(inp, rep, axis):
if axis + 1 <= max_axis: if axis + 1 <= max_axis:
target_shape.append(shape[axis + 1 :]) target_shape.append(shape[axis + 1 :])


out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
base_shape = astensor1d(base_shape)
bcast_shape = astensor1d(bcast_shape)
target_shape = astensor1d(target_shape)
out = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
return out return out






Loading…
Cancel
Save