Browse Source

fix(mge/jit): error out if dump bn in training mode

GitOrigin-RevId: edc7ea2962
release-1.4
Megvii Engine Team 4 years ago
parent
commit
acb239927b
2 changed files with 18 additions and 1 deletions
  1. +5
    -1
      imperative/python/megengine/jit/tracing.py
  2. +13
    -0
      imperative/python/test/integration/test_trace_dump.py

+ 5
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import (
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
@@ -833,6 +833,10 @@ class trace:
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
if isinstance(op, BatchNorm):
assert (
op.fwd_mode == BatchNorm.FwdMode.INFERENCE
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars = G.apply_normal_varnode(op, *ivars)

AutoNaming.record_opnode(ovars[0].op)


+ 13
- 0
imperative/python/test/integration/test_trace_dump.py View File

@@ -11,6 +11,7 @@ import os
import tempfile

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
@@ -140,3 +141,15 @@ def test_xornet_trace_dump():

with mkstemp() as out:
pred_fun.dump(out, arg_names=["data"], output_names=["label"])


def test_dump_bn_train_mode():
@trace(symbolic=True, capture_as_const=True)
def bn_train(data):
pred = M.BatchNorm2d(10)(data).sum()
return pred

data = mge.tensor(np.random.random((10, 10, 10, 10)))
bn_train(data)
with pytest.raises(AssertionError):
bn_train.dump("test.mge")

Loading…
Cancel
Save