|
|
@@ -279,7 +279,16 @@ class trace: |
|
|
|
# Const op is represented by a str |
|
|
|
assert isinstance(op_, str) and op_ == "Const" |
|
|
|
|
|
|
|
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy()) |
|
|
|
expected = self._tinfo[ohandles[0]].bound_data.numpy() |
|
|
|
shape = value.shape |
|
|
|
if shape != expected.shape or dtype != expected.dtype: |
|
|
|
eq = False |
|
|
|
elif shape == (): |
|
|
|
eq = expected.item() == value.item() |
|
|
|
elif shape == (1,): |
|
|
|
eq = expected[0] == value[0] |
|
|
|
else: |
|
|
|
eq = np.all(value == expected) |
|
|
|
if not eq: |
|
|
|
raise TraceMismatchError( |
|
|
|
"const tensor violated: got a different tensor this time" |
|
|
|