Browse Source

fix(mge/tensor): fix const target shape in reshape

GitOrigin-RevId: 7c04a9efba
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
b111baf1b6
2 changed files with 18 additions and 3 deletions
  1. +1
    -3
      imperative/python/megengine/core/tensor/tensor_wrapper.py
  2. +17
    -0
      imperative/python/test/unit/test_tracing.py

+ 1
- 3
imperative/python/megengine/core/tensor/tensor_wrapper.py View File

@@ -74,9 +74,7 @@ def _reshape(x, shape):
raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i)) raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
unspec_axis = i unspec_axis = i


if not isinstance(shape, (TensorBase, TensorWrapperBase)):
# TODO: device should be None (cpu)
(shape,) = Const(shape, dtype=np.int32, device=x.device)(x)
shape = utils.astensor1d(shape, x, dtype="int32", device=x.device)


if unspec_axis is None: if unspec_axis is None:
op = builtin.Reshape() op = builtin.Reshape()


+ 17
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -266,3 +266,20 @@ def test_trace_cvt_bool():


for i in range(3): for i in range(3):
np.testing.assert_equal(f(x).numpy()[0], False) np.testing.assert_equal(f(x).numpy()[0], False)


def test_trace_reshape():
for symbolic in [False, True]:
set_tensor_shape(True)
x1 = tensor(np.random.randn(2, 10, 10))
x2 = tensor(np.random.randn(4, 10, 10))
x3 = tensor(np.random.randn(8, 10, 10))

@trace(symbolic=symbolic, capture_as_const=True)
def f(x):
y = x.reshape(x.shape[0], 100)
return y

f(x1)
f(x2)
f(x3)

Loading…
Cancel
Save