You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_dump.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import contextlib
  2. import os
  3. import tempfile
  4. import numpy as np
  5. import pytest
  6. import megengine as mge
  7. import megengine.functional as F
  8. import megengine.module as M
  9. import megengine.optimizer as optim
  10. from megengine import tensor
  11. from megengine.autodiff import GradManager
  12. from megengine.jit import trace
  @contextlib.contextmanager
  def mkstemp(suffix=None, prefix=None):
      """Yield the path of a fresh temporary file and delete it afterwards.

      The file is created with ``tempfile.mkstemp`` and its OS-level file
      descriptor is closed right away, so the caller only works with the path
      (e.g. passing it to ``trace.dump``).

      Args:
          suffix: optional filename suffix forwarded to ``tempfile.mkstemp``.
          prefix: optional filename prefix forwarded to ``tempfile.mkstemp``.

      Yields:
          str: absolute path of the temporary file.
      """
      fd, path = tempfile.mkstemp(suffix=suffix, prefix=prefix)
      try:
          os.close(fd)
          yield path
      finally:
          # The body may already have removed the file; a missing file during
          # cleanup must not mask an exception raised inside the ``with`` block.
          with contextlib.suppress(FileNotFoundError):
              os.remove(path)
  def minibatch_generator(batch_size):
      """Endlessly yield random minibatches for the 2-D XOR classification task.

      Each sample is a point drawn from ``np.random.rand`` scaled to
      ``[-1, 1)`` per coordinate; its label is 1 when the product of its two
      coordinates is negative (opposite signs) and 0 otherwise.

      Args:
          batch_size: number of samples per minibatch.

      Yields:
          dict: ``{"data": float32 array of shape (batch_size, 2),
          "label": int32 array of shape (batch_size,)}``.
      """
      while True:
          # A single vectorized draw consumes the global NumPy RNG in the same
          # order as calling ``np.random.rand(2)`` once per row.
          inp_data = np.random.rand(batch_size, 2) * 2 - 1
          # Labels are computed on the float64 samples, before the float32 cast.
          label = (np.prod(inp_data, axis=1) < 0).astype(np.int32)
          yield {"data": inp_data.astype(np.float32), "label": label}
  class XORNet(M.Module):
      """Small MLP for the XOR task: (Linear -> BatchNorm1d -> tanh) x 2 -> Linear.

      Args:
          in_dim: number of input features (2 for the XOR data).
          mid_dim: width of the two hidden layers.
          num_class: number of output logits.
      """

      def __init__(self, in_dim=2, mid_dim=14, num_class=2):
          # Initialise the Module base before assigning any attributes.
          super().__init__()
          self.in_dim = in_dim
          self.mid_dim = mid_dim
          self.num_class = num_class
          # The first layer consumes input features, not classes; the two only
          # coincide for the default XOR setup.
          self.fc0 = M.Linear(self.in_dim, self.mid_dim, bias=True)
          self.bn0 = M.BatchNorm1d(self.mid_dim)
          self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
          self.bn1 = M.BatchNorm1d(self.mid_dim)
          self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

      def forward(self, x):
          """Return raw (un-normalized) class logits for input batch ``x``."""
          x = F.tanh(self.bn0(self.fc0(x)))
          x = F.tanh(self.bn1(self.fc1(x)))
          return self.fc2(x)
  def test_xornet_trace_dump():
      """Train XORNet under ``trace`` and dump the inference graph.

      ``train_fun`` and ``val_fun`` use the default ``trace``; ``pred_fun`` is
      traced with ``symbolic=True, capture_as_const=True`` and then dumped.
      """
      net = XORNet()
      opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
      gm = GradManager().attach(net.parameters())
      batch_size = 64
      train_dataset = minibatch_generator(batch_size)
      val_dataset = minibatch_generator(batch_size)

      @trace
      def train_fun(data, label):
          with gm:
              net.train()
              pred = net(data)
              loss = F.nn.cross_entropy(pred, label)
              gm.backward(loss)
          return pred, loss

      @trace
      def val_fun(data, label):
          net.eval()
          pred = net(data)
          loss = F.nn.cross_entropy(pred, label)
          return pred, loss

      @trace(symbolic=True, capture_as_const=True)
      def pred_fun(data):
          net.eval()
          pred = net(data)
          pred_normalized = F.softmax(pred)
          return pred_normalized

      train_loss = []
      val_loss = []
      for step, minibatch in enumerate(train_dataset):
          if step > 100:
              break
          data = tensor(minibatch["data"])
          label = tensor(minibatch["label"])
          opt.clear_grad()
          _, loss = train_fun(data, label)
          train_loss.append((step, loss.numpy()))
          if step % 50 == 0:
              # Evaluate on a freshly drawn validation batch, not on the
              # training batch that was just used for the update.
              val_batch = next(val_dataset)
              _, loss = val_fun(tensor(val_batch["data"]), tensor(val_batch["label"]))
              val_loss.append((step, loss.numpy()))
          opt.step()

      test_data = np.array(
          [
              (0.5, 0.5),
              (0.3, 0.7),
              (0.1, 0.9),
              (-0.5, -0.5),
              (-0.3, -0.7),
              (-0.9, -0.1),
              (0.5, -0.5),
              (0.3, -0.7),
              (0.9, -0.1),
              (-0.5, 0.5),
              (-0.3, 0.7),
              (-0.1, 0.9),
          ]
      )
      data = tensor(test_data.astype(np.float32))
      # Execute the traced function once before dumping it.
      pred_fun(data)
      with mkstemp() as path:
          pred_fun.dump(path, arg_names=["data"], output_names=["label"])
  def test_dump_bn_train_mode():
      """Dumping a traced graph with BatchNorm in training mode must raise RuntimeError."""

      @trace(symbolic=True, capture_as_const=True)
      def bn_train(data):
          # A freshly constructed BatchNorm2d is presumably in training mode by
          # default, which is what this test relies on.
          pred = M.BatchNorm2d(10)(data).sum()
          return pred

      data = mge.tensor(np.random.random((10, 10, 10, 10)))
      bn_train(data)
      # Dump into a managed temporary file rather than the working directory,
      # so nothing is left behind if the dump unexpectedly succeeds.
      with mkstemp() as path:
          with pytest.raises(RuntimeError):
              bn_train.dump(path)