Browse Source

fix(mge/batch_norm): fix batch_norm check when trace(symbolic=True)

GitOrigin-RevId: 2032eb5f7d
release-1.1
Megvii Engine Team 4 years ago
parent
commit
7cd846a570
2 changed files with 25 additions and 2 deletions
  1. +1
    -1
      imperative/python/megengine/jit/tracing.py
  2. +24
    -1
      imperative/python/test/integration/test_bn.py

+ 1
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -394,6 +394,7 @@ class trace:

def _apply_graph_options(self, graph):

graph.options.no_force_inplace = True
graph.options.seq_opt.enable_seq_comp_node_opt = False
# graph opt level
if self._graph_opt_level is not None:
@@ -417,7 +418,6 @@ class trace:

def _compile(self):
graph = self._graph = G.Graph()
graph.options.no_force_inplace = True
graph.options.async_exec_level = 0b100
self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0


+ 24
- 1
imperative/python/test/integration/test_bn.py View File

@@ -13,7 +13,8 @@ import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import BatchNorm2d
from megengine.jit import trace
from megengine.module import BatchNorm2d, Module


def test_frozen_bn():
@@ -89,3 +90,25 @@ def test_bn_no_track_stat3():
data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with pytest.raises(Exception):
m(data)


def test_trace_bn_forward_twice():
class Simple(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(1)

def forward(self, inp):
x = self.bn(inp)
x = self.bn(x)
return x

@trace(symbolic=True)
def train_bn(inp, net=None):
net.train()
pred = net(inp)
return pred

x = np.ones((1, 1, 32, 32), dtype=np.float32)
y = train_bn(x, net=Simple())
np.testing.assert_equal(y.numpy(), 0)

Loading…
Cancel
Save