# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import contextlib
import os
import tempfile

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace


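# Context-manager helper: yields a temporary file path and deletes the file
# on exit, so dump tests leave no artifacts behind.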
@contextlib.contextmanager
def mkstemp():
    fd, path = tempfile.mkstemp()
    try:
        os.close(fd)
        yield path
    finally:
        os.remove(path)


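# Endless generator of XOR minibatches: 2-D inputs drawn uniformly from
# [-1, 1); the label is 1 exactly when the two coordinates differ in sign.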
def minibatch_generator(batch_size):
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for i in range(batch_size):
            inp_data[i, :] = np.random.rand(2) * 2 - 1
            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}


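# A small MLP for the XOR problem: two tanh hidden layers with BatchNorm,
# followed by a 2-way classification head.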
class XORNet(M.Module):
    def __init__(self):
        super().__init__()
        self.mid_dim = 14
        self.num_class = 2
        # The input has 2 features; num_class (also 2) matches only by coincidence.
        self.fc0 = M.Linear(2, self.mid_dim, bias=True)
        self.bn0 = M.BatchNorm1d(self.mid_dim)
        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
        self.bn1 = M.BatchNorm1d(self.mid_dim)
        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

    def forward(self, x):
        x = self.fc0(x)
        x = self.bn0(x)
        x = F.tanh(x)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x


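# End-to-end check: train the XOR net under trace, then dump the symbolic
# inference graph with named inputs and outputs.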
def test_xornet_trace_dump():
    net = XORNet()
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    gm = GradManager().attach(net.parameters())
    batch_size = 64
    train_dataset = minibatch_generator(batch_size)
    val_dataset = minibatch_generator(batch_size)

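    # train_fun/val_fun are traced with the default settings; pred_fun is
    # traced symbolically with capture_as_const=True so that the trained
    # weights are baked into the graph and it can be dumped for inference.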
    @trace
    def train_fun(data, label):
        with gm:
            net.train()
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        return pred, loss

    @trace
    def val_fun(data, label):
        net.eval()
        pred = net(data)
        loss = F.nn.cross_entropy(pred, label)
        return pred, loss

    @trace(symbolic=True, capture_as_const=True)
    def pred_fun(data):
        net.eval()
        pred = net(data)
        pred_normalized = F.softmax(pred)
        return pred_normalized

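    # Run 101 training steps; every 50th step also evaluates one validation
    # minibatch.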
    train_loss = []
    val_loss = []
    for step, minibatch in enumerate(train_dataset):
        if step > 100:
            break
        data = tensor(minibatch["data"])
        label = tensor(minibatch["label"])
        opt.clear_grad()
        _, loss = train_fun(data, label)
        train_loss.append((step, loss.numpy()))
        if step % 50 == 0:
            # Validate on the freshly drawn minibatch, not the training batch.
            minibatch = next(val_dataset)
            val_data = tensor(minibatch["data"])
            val_label = tensor(minibatch["label"])
            _, loss = val_fun(val_data, val_label)
            loss = loss.numpy()
            val_loss.append((step, loss))
            print("Step: {} loss={}".format(step, loss))
        opt.step()

    test_data = np.array(
        [
            (0.5, 0.5),
            (0.3, 0.7),
            (0.1, 0.9),
            (-0.5, -0.5),
            (-0.3, -0.7),
            (-0.9, -0.1),
            (0.5, -0.5),
            (0.3, -0.7),
            (0.9, -0.1),
            (-0.5, 0.5),
            (-0.3, 0.7),
            (-0.1, 0.9),
        ]
    )

    data = tensor(test_data.astype(np.float32))
    # Run the traced function once so capture_as_const records the trained
    # weights, then dump the inference graph to a temporary file.
    pred_fun(data)

    with mkstemp() as fname:
        pred_fun.dump(fname, arg_names=["data"], output_names=["label"])


def test_dump_bn_train_mode():
    @trace(symbolic=True, capture_as_const=True)
    def bn_train(data):
        # The BatchNorm2d module is created inside the traced function, so it
        # runs in training mode and updates its running statistics.
        pred = M.BatchNorm2d(10)(data).sum()
        return pred

    data = mge.tensor(np.random.random((10, 10, 10, 10)))
    bn_train(data)
    # A graph that still mutates BatchNorm running statistics cannot be
    # serialized, so dump() is expected to raise.
    with pytest.raises(RuntimeError):
        bn_train.dump("test.mge")