Browse Source

fix(mge/tensor): where support x or y is empty

GitOrigin-RevId: 2b69ee2e79
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
80d9e56b38
2 changed files with 21 additions and 3 deletions
  1. +3
    -2
      python_module/megengine/functional/tensor.py
  2. +18
    -1
      python_module/test/unit/functional/test_functional.py

+ 3
- 2
python_module/megengine/functional/tensor.py View File

@@ -353,8 +353,9 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0 y, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=0
) )
out = x.flatten() out = x.flatten()
out = mgb.opr.set_advanced_indexing(out, v0)[index0]
out = mgb.opr.set_advanced_indexing(out, v1)[index1]
index = mgb.opr.concat(index0, index1, axis=0)
v = mgb.opr.concat(v0, v1, axis=0)
out = mgb.opr.set_advanced_indexing(out, v)[index]
out = out.reshape(x.shape) out = out.reshape(x.shape)
return out return out




+ 18
- 1
python_module/test/unit/functional/test_functional.py View File

@@ -53,7 +53,24 @@ def test_where():
xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)


cases = [{"input": [maskv0, xv0, yv0]}, {"input": [maskv1, xv1, yv1]}]
cases = [
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
]
opr_test(cases, F.where, ref_fn=np.where)

maskv2 = np.array([1, 1, 1], dtype=np.int32)
xv2 = np.array([1, 3, 2], dtype=np.float32)
yv2 = np.array([5, 6, 9], dtype=np.float32)

maskv3 = np.array([0, 0, 0], dtype=np.int32)
xv3 = np.array([1, 3, 2], dtype=np.float32)
yv3 = np.array([5, 6, 9], dtype=np.float32)

cases = [
{"input": [maskv2, xv2, yv2]},
{"input": [maskv3, xv3, yv3]},
]
opr_test(cases, F.where, ref_fn=np.where) opr_test(cases, F.where, ref_fn=np.where)






Loading…
Cancel
Save