|
@@ -122,34 +122,34 @@ def test_flatten(): |
|
|
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) |
|
|
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# def test_where(): |
|
|
|
|
|
# maskv0 = np.array([[1, 0], [0, 1]], dtype=np.int32) |
|
|
|
|
|
# xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) |
|
|
|
|
|
# yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32) |
|
|
|
|
|
|
|
|
def test_where(): |
|
|
|
|
|
maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_) |
|
|
|
|
|
xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) |
|
|
|
|
|
yv0 = np.array([[5, 6], [7, 8]], dtype=np.float32) |
|
|
|
|
|
|
|
|
# maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.int32) |
|
|
|
|
|
# xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) |
|
|
|
|
|
# yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) |
|
|
|
|
|
|
|
|
maskv1 = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 0]], dtype=np.bool_) |
|
|
|
|
|
xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32) |
|
|
|
|
|
yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32) |
|
|
|
|
|
|
|
|
# cases = [ |
|
|
|
|
|
# {"input": [maskv0, xv0, yv0]}, |
|
|
|
|
|
# {"input": [maskv1, xv1, yv1]}, |
|
|
|
|
|
# ] |
|
|
|
|
|
# opr_test(cases, F.where, ref_fn=np.where) |
|
|
|
|
|
|
|
|
cases = [ |
|
|
|
|
|
{"input": [maskv0, xv0, yv0]}, |
|
|
|
|
|
{"input": [maskv1, xv1, yv1]}, |
|
|
|
|
|
] |
|
|
|
|
|
opr_test(cases, F.where, ref_fn=np.where) |
|
|
|
|
|
|
|
|
# maskv2 = np.array([1, 1, 1], dtype=np.int32) |
|
|
|
|
|
# xv2 = np.array([1, 3, 2], dtype=np.float32) |
|
|
|
|
|
# yv2 = np.array([5, 6, 9], dtype=np.float32) |
|
|
|
|
|
|
|
|
maskv2 = np.array([1, 1, 1], dtype=np.bool_) |
|
|
|
|
|
xv2 = np.array([1, 3, 2], dtype=np.float32) |
|
|
|
|
|
yv2 = np.array([5, 6, 9], dtype=np.float32) |
|
|
|
|
|
|
|
|
# maskv3 = np.array([0, 0, 0], dtype=np.int32) |
|
|
|
|
|
# xv3 = np.array([1, 3, 2], dtype=np.float32) |
|
|
|
|
|
# yv3 = np.array([5, 6, 9], dtype=np.float32) |
|
|
|
|
|
|
|
|
maskv3 = np.array([0, 0, 0], dtype=np.bool_) |
|
|
|
|
|
xv3 = np.array([1, 3, 2], dtype=np.float32) |
|
|
|
|
|
yv3 = np.array([5, 6, 9], dtype=np.float32) |
|
|
|
|
|
|
|
|
# cases = [ |
|
|
|
|
|
# {"input": [maskv2, xv2, yv2]}, |
|
|
|
|
|
# {"input": [maskv3, xv3, yv3]}, |
|
|
|
|
|
# ] |
|
|
|
|
|
# opr_test(cases, F.where, ref_fn=np.where) |
|
|
|
|
|
|
|
|
cases = [ |
|
|
|
|
|
{"input": [maskv2, xv2, yv2]}, |
|
|
|
|
|
{"input": [maskv3, xv3, yv3]}, |
|
|
|
|
|
] |
|
|
|
|
|
opr_test(cases, F.where, ref_fn=np.where) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_matmul(): |
|
|
def test_matmul(): |
|
|