|
|
@@ -180,7 +180,7 @@ def try_condtake(tensor, index): |
|
|
|
if index.dtype != np.bool_ or index.shape != tensor.shape: |
|
|
|
return [] |
|
|
|
if isinstance(index, np.ndarray): |
|
|
|
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) |
|
|
|
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) |
|
|
|
assert isinstance(index, (TensorBase, TensorWrapperBase)) |
|
|
|
if not isinstance(tensor, (TensorWrapperBase, TensorBase)): |
|
|
|
raise TypeError("input must be a tensor") |
|
|
|