diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index 2cc58a57..eff4d34a 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -353,8 +353,9 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0 ) out = x.flatten() - out = mgb.opr.set_advanced_indexing(out, v0)[index0] - out = mgb.opr.set_advanced_indexing(out, v1)[index1] + index = mgb.opr.concat(index0, index1, axis=0) + v = mgb.opr.concat(v0, v1, axis=0) + out = mgb.opr.set_advanced_indexing(out, v)[index] out = out.reshape(x.shape) return out diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py index b9f0cebf..b4e81129 100644 --- a/python_module/test/unit/functional/test_functional.py +++ b/python_module/test/unit/functional/test_functional.py @@ -53,7 +53,24 @@ def test_where(): xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) - cases = [{"input": [maskv0, xv0, yv0]}, {"input": [maskv1, xv1, yv1]}] + cases = [ + {"input": [maskv0, xv0, yv0]}, + {"input": [maskv1, xv1, yv1]}, + ] + opr_test(cases, F.where, ref_fn=np.where) + + maskv2 = np.array([1, 1, 1], dtype=np.int32) + xv2 = np.array([1, 3, 2], dtype=np.float32) + yv2 = np.array([5, 6, 9], dtype=np.float32) + + maskv3 = np.array([0, 0, 0], dtype=np.int32) + xv3 = np.array([1, 3, 2], dtype=np.float32) + yv3 = np.array([5, 6, 9], dtype=np.float32) + + cases = [ + {"input": [maskv2, xv2, yv2]}, + {"input": [maskv3, xv3, yv3]}, + ] opr_test(cases, F.where, ref_fn=np.where)