diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 8f617364..54b913d9 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -279,7 +279,16 @@ class trace: # Const op is represented by a str assert isinstance(op_, str) and op_ == "Const" - eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy()) + expected = self._tinfo[ohandles[0]].bound_data.numpy() + shape = value.shape + if shape != expected.shape or dtype != expected.dtype: + eq = False + elif shape == (): + eq = expected.item() == value.item() + elif shape == (1,): + eq = expected[0] == value[0] + else: + eq = np.all(value == expected) if not eq: raise TraceMismatchError( "const tensor violated: got a different tensor this time"