|
|
@@ -266,3 +266,20 @@ def test_trace_cvt_bool(): |
|
|
|
|
|
|
|
for i in range(3): |
|
|
|
np.testing.assert_equal(f(x).numpy()[0], False) |
|
|
|
|
|
|
|
|
|
|
|
def test_trace_reshape(): |
|
|
|
for symbolic in [False, True]: |
|
|
|
set_tensor_shape(True) |
|
|
|
x1 = tensor(np.random.randn(2, 10, 10)) |
|
|
|
x2 = tensor(np.random.randn(4, 10, 10)) |
|
|
|
x3 = tensor(np.random.randn(8, 10, 10)) |
|
|
|
|
|
|
|
@trace(symbolic=symbolic, capture_as_const=True) |
|
|
|
def f(x): |
|
|
|
y = x.reshape(x.shape[0], 100) |
|
|
|
return y |
|
|
|
|
|
|
|
f(x1) |
|
|
|
f(x2) |
|
|
|
f(x3) |