Browse Source

perf(trace): add fastpath for const value assert

GitOrigin-RevId: 9a966f257f
release-1.5
Megvii Engine Team huangxinda 3 years ago
parent
commit
a95f6d4f75
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      imperative/python/megengine/jit/tracing.py

+ 10
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -279,7 +279,16 @@ class trace:
# Const op is represented by a str
assert isinstance(op_, str) and op_ == "Const"

eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
expected = self._tinfo[ohandles[0]].bound_data.numpy()
shape = value.shape
if shape != expected.shape or dtype != expected.dtype:
eq = False
elif shape == ():
eq = expected.item() == value.item()
elif shape == (1,):
eq = expected[0] == value[0]
else:
eq = np.all(value == expected)
if not eq:
raise TraceMismatchError(
"const tensor violated: got a different tensor this time"


Loading…
Cancel
Save