You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

xornet.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import numpy as np
  2. import megengine as mge
  3. import megengine.autodiff as ad
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.optimizer as optim
  7. from megengine.jit import trace
  8. def minibatch_generator(batch_size):
  9. while True:
  10. inp_data = np.zeros((batch_size, 2))
  11. label = np.zeros(batch_size, dtype=np.int32)
  12. for i in range(batch_size):
  13. inp_data[i, :] = np.random.rand(2) * 2 - 1
  14. label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
  15. yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
  16. class XORNet(M.Module):
  17. def __init__(self):
  18. self.mid_dim = 14
  19. self.num_class = 2
  20. super().__init__()
  21. self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
  22. self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
  23. self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
  24. def forward(self, x):
  25. x = self.fc0(x)
  26. x = F.tanh(x)
  27. x = self.fc1(x)
  28. x = F.tanh(x)
  29. x = self.fc2(x)
  30. return x
  31. def main():
  32. if not mge.is_cuda_available():
  33. mge.set_default_device("cpux")
  34. net = XORNet()
  35. gm = ad.GradManager().attach(net.parameters())
  36. opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
  37. batch_size = 64
  38. train_dataset = minibatch_generator(batch_size)
  39. val_dataset = minibatch_generator(batch_size)
  40. def train_fun(data, label):
  41. opt.clear_grad()
  42. with gm:
  43. pred = net(data)
  44. loss = F.loss.cross_entropy(pred, label)
  45. gm.backward(loss)
  46. opt.step()
  47. return pred, loss
  48. def val_fun(data, label):
  49. pred = net(data)
  50. loss = F.loss.cross_entropy(pred, label)
  51. return pred, loss
  52. @trace(symbolic=True, capture_as_const=True)
  53. def pred_fun(data):
  54. pred = net(data)
  55. pred_normalized = F.softmax(pred)
  56. return pred_normalized
  57. data = np.random.random((batch_size, 2)).astype(np.float32)
  58. label = np.zeros((batch_size,)).astype(np.int32)
  59. train_loss = []
  60. val_loss = []
  61. for step, minibatch in enumerate(train_dataset):
  62. if step > 1000:
  63. break
  64. data = mge.tensor(minibatch["data"])
  65. label = mge.tensor(minibatch["label"])
  66. net.train()
  67. _, loss = train_fun(data, label)
  68. train_loss.append((step, loss.numpy()))
  69. if step % 50 == 0:
  70. minibatch = next(val_dataset)
  71. net.eval()
  72. _, loss = val_fun(data, label)
  73. loss = loss.numpy()
  74. val_loss.append((step, loss))
  75. print("Step: {} loss={}".format(step, loss))
  76. opt.step()
  77. test_data = np.array(
  78. [
  79. (0.5, 0.5),
  80. (0.3, 0.7),
  81. (0.1, 0.9),
  82. (-0.5, -0.5),
  83. (-0.3, -0.7),
  84. (-0.9, -0.1),
  85. (0.5, -0.5),
  86. (0.3, -0.7),
  87. (0.9, -0.1),
  88. (-0.5, 0.5),
  89. (-0.3, 0.7),
  90. (-0.1, 0.9),
  91. ]
  92. )
  93. # tracing only accepts tensor as input
  94. data = mge.tensor(test_data, dtype=np.float32)
  95. net.eval()
  96. out = pred_fun(data)
  97. pred_output = out.numpy()
  98. pred_label = np.argmax(pred_output, 1)
  99. print("Test data")
  100. print(test_data)
  101. with np.printoptions(precision=4, suppress=True):
  102. print("Predicated probability:")
  103. print(pred_output)
  104. print("Predicated label")
  105. print(pred_label)
  106. model_name = "xornet_deploy.mge"
  107. print("Dump model as {}".format(model_name))
  108. pred_fun.dump(model_name, arg_names=["data"])
  109. model_with_testcase_name = "xornet_with_testcase.mge"
  110. print("Dump model with testcase as {}".format(model_with_testcase_name))
  111. pred_fun.dump(model_with_testcase_name, arg_names=["data"], input_data=["#rand(0.1, 0.8, 4, 2)"])
  112. if __name__ == "__main__":
  113. main()