|
|
@@ -65,10 +65,10 @@ def _broadcast(inp, shape): |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
if isinstance(src, (Tensor, TensorWrapperBase)): |
|
|
|
if isinstance(src, (TensorBase, TensorWrapperBase)): |
|
|
|
src = src.numpy() |
|
|
|
|
|
|
|
if isinstance(tar, (Tensor, TensorWrapperBase)): |
|
|
|
if isinstance(tar, (TensorBase, TensorWrapperBase)): |
|
|
|
tar = tar.numpy() |
|
|
|
|
|
|
|
if len(src) > len(tar): |
|
|
@@ -78,8 +78,8 @@ def _broadcast(inp, shape): |
|
|
|
if src[-i - 1] != 1 and src[-i - 1] != tar[-i - 1]: |
|
|
|
failed() |
|
|
|
|
|
|
|
valid_broadcast(inp.shape, shape) |
|
|
|
shape = utils.astensor1d(shape, inp, dtype="int32", device=inp.device) |
|
|
|
valid_broadcast(inp.shape, shape) |
|
|
|
(result,) = apply(builtin.Broadcast(), inp, shape) |
|
|
|
return result |
|
|
|
|
|
|
|