|
|
@@ -223,6 +223,34 @@ def test_reshape(is_varnode): |
|
|
|
np.testing.assert_equal(yy.numpy(), y) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_varnode", [True, False]) |
|
|
|
def test_broadcast_auto_infer(is_varnode): |
|
|
|
if is_varnode: |
|
|
|
network = Network() |
|
|
|
else: |
|
|
|
network = None |
|
|
|
|
|
|
|
x = np.random.random((1, 2, 3)).astype(np.float32) |
|
|
|
xx = make_tensor(x, network) |
|
|
|
|
|
|
|
for shape in [ |
|
|
|
(1, 2, 3), |
|
|
|
(1, None, 3), |
|
|
|
]: |
|
|
|
yy = F.broadcast_to(xx, shape) |
|
|
|
np.testing.assert_equal(yy.numpy(), x) |
|
|
|
|
|
|
|
with pytest.raises(ValueError): |
|
|
|
F.broadcast_to(xx, (1, -1, 3)) |
|
|
|
|
|
|
|
with pytest.raises(ValueError): |
|
|
|
F.broadcast_to(xx, (None, 1, 2, 3)) |
|
|
|
|
|
|
|
F.broadcast_to(xx, (1, None, 2, 3)) |
|
|
|
t = tensor(2, dtype=np.int32) |
|
|
|
F.broadcast_to(xx, (t, None, 2, 3)) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("is_trace", [True, False]) |
|
|
|
def test_reshape_on_empty_tensor(is_trace): |
|
|
|
input1_shape = (100, 0, 1) |
|
|
|