Browse Source

fix(mge/batch_norm): fix batch_norm check when trace(symbolic=True)

GitOrigin-RevId: 2032eb5f7d
release-1.1
Megvii Engine Team 4 years ago
parent
commit
7cd846a570
2 changed files with 25 additions and 2 deletions
  1. +1
    -1
      imperative/python/megengine/jit/tracing.py
  2. +24
    -1
      imperative/python/test/integration/test_bn.py

+ 1
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -394,6 +394,7 @@ class trace:


def _apply_graph_options(self, graph):


graph.options.no_force_inplace = True
graph.options.seq_opt.enable_seq_comp_node_opt = False
# graph opt level
if self._graph_opt_level is not None:
@@ -417,7 +418,6 @@ class trace:


def _compile(self):
graph = self._graph = G.Graph()
graph.options.no_force_inplace = True
graph.options.async_exec_level = 0b100
self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0


+ 24
- 1
imperative/python/test/integration/test_bn.py View File

@@ -13,7 +13,8 @@ import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import BatchNorm2d
from megengine.jit import trace
from megengine.module import BatchNorm2d, Module




def test_frozen_bn():
@@ -89,3 +90,25 @@ def test_bn_no_track_stat3():
data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with pytest.raises(Exception):
    m(data)


def test_trace_bn_forward_twice():
    """Regression test: feeding an input through the *same* BatchNorm2d
    module twice inside one ``trace(symbolic=True)`` step must compile and
    run (the batch_norm check previously rejected this under symbolic
    tracing).
    """

    class TwoPassBN(Module):
        # Minimal net that routes its input through a single BN layer twice.
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(1)

        def forward(self, inp):
            # Deliberately reuse the same module instance back-to-back.
            return self.bn(self.bn(inp))

    @trace(symbolic=True)
    def run_train_step(inp, net=None):
        net.train()
        return net(inp)

    data = np.ones((1, 1, 32, 32), dtype=np.float32)
    out = run_train_step(data, net=TwoPassBN())
    # A constant input has zero variance per channel, so batch norm maps
    # it to all zeros (scale starts at 1, bias at 0).
    np.testing.assert_equal(out.numpy(), 0)

Loading…
Cancel
Save