GitOrigin-RevId: 7c04a9efba
tags/v1.0.0-rc1
@@ -74,9 +74,7 @@ def _reshape(x, shape): | |||||
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) | ||||
unspec_axis = i | unspec_axis = i | ||||
if not isinstance(shape, (TensorBase, TensorWrapperBase)): | |||||
# TODO: device should be None (cpu) | |||||
(shape,) = Const(shape, dtype=np.int32, device=x.device)(x) | |||||
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device) | |||||
if unspec_axis is None: | if unspec_axis is None: | ||||
op = builtin.Reshape() | op = builtin.Reshape() | ||||
@@ -266,3 +266,20 @@ def test_trace_cvt_bool(): | |||||
for i in range(3): | for i in range(3): | ||||
np.testing.assert_equal(f(x).numpy()[0], False) | np.testing.assert_equal(f(x).numpy()[0], False) | ||||
def test_trace_reshape(): | |||||
for symbolic in [False, True]: | |||||
set_tensor_shape(True) | |||||
x1 = tensor(np.random.randn(2, 10, 10)) | |||||
x2 = tensor(np.random.randn(4, 10, 10)) | |||||
x3 = tensor(np.random.randn(8, 10, 10)) | |||||
@trace(symbolic=symbolic, capture_as_const=True) | |||||
def f(x): | |||||
y = x.reshape(x.shape[0], 100) | |||||
return y | |||||
f(x1) | |||||
f(x2) | |||||
f(x3) |