|
|
@@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import ( |
|
|
|
) |
|
|
|
from ..core._trace_option import set_symbolic_shape |
|
|
|
from ..core._wrap import device as as_device |
|
|
|
from ..core.ops.builtin import BackwardGraph, OpDef |
|
|
|
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor import megbrain_graph as G |
|
|
|
from ..core.tensor.utils import setscalar |
|
|
@@ -833,6 +833,10 @@ class trace: |
|
|
|
if isinstance(op, BackwardGraph): |
|
|
|
ovars = G.apply_backward_varnode(op, *ivars) |
|
|
|
else: |
|
|
|
if isinstance(op, BatchNorm): |
|
|
|
assert ( |
|
|
|
op.fwd_mode == BatchNorm.FwdMode.INFERENCE |
|
|
|
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" |
|
|
|
ovars = G.apply_normal_varnode(op, *ivars) |
|
|
|
|
|
|
|
AutoNaming.record_opnode(ovars[0].op) |
|
|
|