You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_converge.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # -*- coding: utf-8 -*-
  2. import itertools
  3. import numpy as np
  4. import pytest
  5. import megengine as mge
  6. import megengine.autodiff as ad
  7. import megengine.functional as F
  8. import megengine.optimizer as optim
  9. from megengine import Tensor
  10. from megengine.core import set_option
  11. from megengine.module import Linear, Module
  12. from megengine.optimizer import SGD
  13. from megengine.traced_module import trace_module
  14. batch_size = 64
  15. data_shape = (batch_size, 2)
  16. label_shape = (batch_size,)
  17. def minibatch_generator():
  18. while True:
  19. inp_data = np.zeros((batch_size, 2))
  20. label = np.zeros(batch_size, dtype=np.int32)
  21. for i in range(batch_size):
  22. # [x0, x1], sampled from U[-1, 1]
  23. inp_data[i, :] = np.random.rand(2) * 2 - 1
  24. label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
  25. yield inp_data.astype(np.float32), label.astype(np.int32)
  26. def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
  27. """ Calculate precision for given data and prediction.
  28. :type data: [[x, y], ...]
  29. :param data: Input data
  30. :type pred: [[x_pred, y_pred], ...]
  31. :param pred: Network output data
  32. """
  33. correct = 0
  34. assert len(data) == len(pred)
  35. for inp_data, pred_output in zip(data, pred):
  36. label = 0 if np.prod(inp_data) < 0 else 1
  37. pred_label = np.argmax(pred_output)
  38. if pred_label == label:
  39. correct += 1
  40. return float(correct) / len(data)
  41. class XORNet(Module):
  42. def __init__(self):
  43. self.mid_layers = 14
  44. self.num_class = 2
  45. super().__init__()
  46. self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
  47. self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
  48. self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)
  49. def forward(self, x):
  50. x = self.fc0(x)
  51. x = F.tanh(x)
  52. x = self.fc1(x)
  53. x = F.tanh(x)
  54. x = self.fc2(x)
  55. return x
  56. @pytest.mark.parametrize(
  57. "test_traced_module, with_drop, grad_clip",
  58. [(False, False, False), (True, True, True)],
  59. )
  60. def test_training_converge(test_traced_module, with_drop, grad_clip):
  61. if with_drop:
  62. set_option("enable_drop", 1)
  63. net = XORNet()
  64. if test_traced_module:
  65. inp = Tensor(np.random.random((14, 2)))
  66. net = trace_module(net, inp)
  67. opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
  68. gm = ad.GradManager().attach(net.parameters())
  69. def train(data, label):
  70. with gm:
  71. pred = net(data)
  72. loss = F.nn.cross_entropy(pred, label)
  73. gm.backward(loss)
  74. if grad_clip:
  75. optim.clip_grad_norm(net.parameters(), max_norm=0.2, ord=2.0)
  76. return loss
  77. def infer(data):
  78. return net(data)
  79. train_dataset = minibatch_generator()
  80. losses = []
  81. for data, label in itertools.islice(train_dataset, 1500):
  82. data = Tensor(data, dtype=np.float32)
  83. label = Tensor(label, dtype=np.int32)
  84. opt.clear_grad()
  85. loss = train(data, label)
  86. if grad_clip:
  87. optim.clip_grad_value(net.parameters(), lower=-0.1, upper=0.1)
  88. opt.step()
  89. losses.append(loss.numpy())
  90. assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
  91. ngrid = 10
  92. x = np.linspace(-1.0, 1.0, ngrid)
  93. xx, yy = np.meshgrid(x, x)
  94. xx = xx.reshape((ngrid * ngrid, 1))
  95. yy = yy.reshape((ngrid * ngrid, 1))
  96. data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
  97. pred = infer(data)
  98. precision = calculate_precision(data.numpy(), pred.numpy())
  99. assert precision == 1.0, "Test precision must be high enough, get {}".format(
  100. precision
  101. )
  102. if with_drop:
  103. set_option("enable_drop", 0)