|
|
@@ -10,6 +10,7 @@ import weakref |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ..core._imperative_rt import GraphProfiler |
|
|
|
from ..core._imperative_rt.ops import OprAttr |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor import megbrain_graph as G |
|
|
|
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply |
|
|
@@ -127,8 +128,12 @@ class trace: |
|
|
|
record = self._seq[self._pc] |
|
|
|
op_, ihandles, ohandles = record |
|
|
|
if op != op_: |
|
|
|
if op.type == "UniformRNG": |
|
|
|
pass |
|
|
|
# FIXME: will be removed once better rng implementation is done |
|
|
|
if isinstance(op, OprAttr) and ( |
|
|
|
op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type |
|
|
|
): |
|
|
|
if op.param[8:] != op_.param[8:]: |
|
|
|
raise TraceMismatchError("op different from last time") |
|
|
|
else: |
|
|
|
raise TraceMismatchError("op different from last time") |
|
|
|
if len(ihandles) != len(args): |
|
|
|