
test_converge.py

# -*- coding: utf-8 -*-
import itertools

import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.module import Linear, Module
from megengine.optimizer import SGD
from megengine.traced_module import trace_module

batch_size = 64
data_shape = (batch_size, 2)
label_shape = (batch_size,)


def minibatch_generator():
    while True:
        inp_data = np.zeros((batch_size, 2))
        label = np.zeros(batch_size, dtype=np.int32)
        for i in range(batch_size):
            # [x0, x1], sampled from U[-1, 1]
            inp_data[i, :] = np.random.rand(2) * 2 - 1
            label[i] = 0 if np.prod(inp_data[i]) < 0 else 1
        yield inp_data.astype(np.float32), label.astype(np.int32)


def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
    """Calculate precision for given data and prediction.

    :type data: [[x, y], ...]
    :param data: Input data
    :type pred: [[x_pred, y_pred], ...]
    :param pred: Network output data
    """
    correct = 0
    assert len(data) == len(pred)
    for inp_data, pred_output in zip(data, pred):
        label = 0 if np.prod(inp_data) < 0 else 1
        pred_label = np.argmax(pred_output)
        if pred_label == label:
            correct += 1
    return float(correct) / len(data)


class XORNet(Module):
    def __init__(self):
        self.mid_layers = 14
        self.num_class = 2
        super().__init__()

        self.fc0 = Linear(self.num_class, self.mid_layers, bias=True)
        self.fc1 = Linear(self.mid_layers, self.mid_layers, bias=True)
        self.fc2 = Linear(self.mid_layers, self.num_class, bias=True)

    def forward(self, x):
        x = self.fc0(x)
        x = F.tanh(x)
        x = self.fc1(x)
        x = F.tanh(x)
        x = self.fc2(x)
        return x


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_training_converge(test_traced_module):
    net = XORNet()
    if test_traced_module:
        inp = Tensor(np.random.random((14, 2)))
        net = trace_module(net, inp)
    opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    gm = ad.GradManager().attach(net.parameters())

    def train(data, label):
        with gm:
            pred = net(data)
            loss = F.nn.cross_entropy(pred, label)
            gm.backward(loss)
        return loss

    def infer(data):
        return net(data)

    train_dataset = minibatch_generator()
    losses = []

    for data, label in itertools.islice(train_dataset, 2000):
        data = Tensor(data, dtype=np.float32)
        label = Tensor(label, dtype=np.int32)
        opt.clear_grad()
        loss = train(data, label)
        opt.step()
        losses.append(loss.numpy())
    assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"

    ngrid = 10
    x = np.linspace(-1.0, 1.0, ngrid)
    xx, yy = np.meshgrid(x, x)
    xx = xx.reshape((ngrid * ngrid, 1))
    yy = yy.reshape((ngrid * ngrid, 1))
    data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
    pred = infer(data)
    precision = calculate_precision(data.numpy(), pred.numpy())
    assert precision == 1.0, "Test precision must be high enough, get {}".format(
        precision
    )
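
Note: this test trains a small MLP on the XOR-style task defined in minibatch_generator (label 0 when x0 * x1 < 0, label 1 otherwise), asserts that the mean of the last 100 training losses drops below 0.1, and then checks that the trained net classifies a 10x10 grid over [-1, 1]^2 with precision 1.0. The pytest parametrization runs the same check twice: once on the plain Module and once on its trace_module conversion. Assuming pytest and MegEngine are installed, a typical way to run it locally is:

    pytest -v test_converge.py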