Browse Source

refactor(mge/jit): skip seed when checking equal rng op

GitOrigin-RevId: dae2086b36
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
2df1ab9633
1 changed files with 7 additions and 2 deletions
  1. +7
    -2
      imperative/python/megengine/jit/tracing.py

+ 7
- 2
imperative/python/megengine/jit/tracing.py View File

@@ -10,6 +10,7 @@ import weakref
import numpy as np

from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
@@ -127,8 +128,12 @@ class trace:
record = self._seq[self._pc]
op_, ihandles, ohandles = record
if op != op_:
if op.type == "UniformRNG":
pass
# FIXME: will be removed once better rng implementation is done
if isinstance(op, OprAttr) and (
op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
):
if op.param[8:] != op_.param[8:]:
raise TraceMismatchError("op different from last time")
else:
raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args):


Loading…
Cancel
Save