|
|
@@ -18,14 +18,7 @@ from ..core.ops import builtin |
|
|
|
from ..core.ops.builtin import Copy, Identity |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor.array_method import _broadcast, _remove_axis |
|
|
|
from ..core.tensor.utils import ( |
|
|
|
astensor1d, |
|
|
|
convert_inputs, |
|
|
|
get_device, |
|
|
|
isscalar, |
|
|
|
setscalar, |
|
|
|
subgraph_fn, |
|
|
|
) |
|
|
|
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn |
|
|
|
from ..device import get_default_device |
|
|
|
from ..tensor import Tensor |
|
|
|
from .elemwise import ceil |
|
|
@@ -821,8 +814,6 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: |
|
|
|
|
|
|
|
where = _get_where_op(dtype=dtype, device=device) |
|
|
|
(oup,) = where(mask, x, y) |
|
|
|
if isscalar(mask): |
|
|
|
setscalar(oup) |
|
|
|
return oup |
|
|
|
|
|
|
|
|
|
|
|