|
|
@@ -611,6 +611,25 @@ def test_subtensor_on_empty_tensor(symbolic): |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("symbolic", [True, False, None]) |
|
|
|
def test_indexingMultiAxisVec_on_empty_tensor(symbolic): |
|
|
|
np_x = np.array([], dtype=np.float32).reshape(10, 10, 0) |
|
|
|
mge_x = megengine.tensor(np_x) |
|
|
|
|
|
|
|
def run_test(fn): |
|
|
|
out_ref = fn(np_x) |
|
|
|
if symbolic is not None: |
|
|
|
fn = jit.trace(symbolic=symbolic)(fn) |
|
|
|
for i in range(3): |
|
|
|
out = fn(mge_x) |
|
|
|
np.testing.assert_equal(out.numpy(), out_ref) |
|
|
|
|
|
|
|
run_test(lambda x: x[[1, 2, 3]]) |
|
|
|
run_test(lambda x: x[[1, 2, 3], [4, 5, 6]]) |
|
|
|
run_test(lambda x: x[[]]) |
|
|
|
run_test(lambda x: x[[], [], []]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("symbolic", [True, False, None]) |
|
|
|
def test_setsubtensor_on_empty_tensor(symbolic): |
|
|
|
def run_test(inp_shp, fn): |
|
|
|
np_x = np.random.randn(*inp_shp).astype(np.float32) |
|
|
@@ -655,3 +674,39 @@ def test_setsubtensor_on_empty_tensor(symbolic): |
|
|
|
run_test((10, 10, 10), test4) |
|
|
|
run_test((10, 10, 10), test5) |
|
|
|
run_test((10, 10, 10), test6) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("symbolic", [True, False, None]) |
|
|
|
def test_indexingSetMultiAxisVec_on_empty_tensor(symbolic): |
|
|
|
def run_test(inp_shp, fn): |
|
|
|
np_x = np.random.randn(*inp_shp).astype(np.float32) |
|
|
|
mge_x = megengine.tensor(np_x) |
|
|
|
out_ref = fn(np_x) |
|
|
|
if symbolic is not None: |
|
|
|
fn = jit.trace(symbolic=symbolic)(fn) |
|
|
|
for i in range(3): |
|
|
|
out = fn(mge_x) |
|
|
|
np.testing.assert_equal(out.numpy(), out_ref) |
|
|
|
|
|
|
|
def test1(x): |
|
|
|
x[[1, 2, 3]] = x[[1, 2, 3]] |
|
|
|
return x |
|
|
|
|
|
|
|
def test2(x): |
|
|
|
x[[1, 2, 3], [1, 2, 3]] = x[[1, 2, 3], [1, 2, 3]] |
|
|
|
return x |
|
|
|
|
|
|
|
def test3(x): |
|
|
|
x[[]] = x[[]] |
|
|
|
return x |
|
|
|
|
|
|
|
def test4(x): |
|
|
|
x[[], [], []] = x[[], [], []] |
|
|
|
return x |
|
|
|
|
|
|
|
run_test((10, 10, 0), test1) |
|
|
|
run_test((10, 10, 0), test2) |
|
|
|
run_test((10, 10, 0), test3) |
|
|
|
run_test((10, 10, 0), test4) |
|
|
|
run_test((10, 10, 10), test3) |
|
|
|
run_test((10, 10, 10), test4) |