You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

xornet.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import numpy as np
  2. import megengine as mge
  3. import megengine.functional as F
  4. import megengine.module as M
  5. import megengine.optimizer as optim
  6. from megengine.jit import trace
  def minibatch_generator(batch_size):
      """Yield an endless stream of random XOR-style minibatches.

      Every sample is a 2-D point whose coordinates are drawn uniformly from
      [-1, 1). Its label is 1 when the product of the two coordinates is
      negative (the coordinates have opposite signs) and 0 otherwise.

      Args:
          batch_size: Number of samples in each yielded minibatch.

      Yields:
          dict: ``"data"`` maps to a float32 array of shape ``(batch_size, 2)``
          and ``"label"`` maps to an int32 array of shape ``(batch_size,)``.
      """
      while True:
          # Draw the whole batch with one vectorized call instead of filling
          # it row by row in a Python loop.
          inp_data = np.random.rand(batch_size, 2) * 2 - 1
          # Labels are computed from the full-precision samples, before the
          # cast to float32 below.
          label = (np.prod(inp_data, axis=1) < 0).astype(np.int32)
          yield {"data": inp_data.astype(np.float32), "label": label}
  class XORNet(M.Module):
      """Small fully connected classifier: three linear layers with tanh.

      Args:
          in_dim: Number of input features per sample. Defaults to 2.
          mid_dim: Width of the two hidden layers. Defaults to 14.
          num_class: Number of output scores per sample. Defaults to 2.
      """

      def __init__(self, in_dim=2, mid_dim=14, num_class=2):
          # Initialize the Module base class before assigning any attribute.
          super().__init__()
          self.in_dim = in_dim
          self.mid_dim = mid_dim
          self.num_class = num_class
          # The first layer is sized by the input feature count, not by the
          # number of classes; the two only coincide for the default values.
          self.fc0 = M.Linear(self.in_dim, self.mid_dim, bias=True)
          self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
          self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)

      def forward(self, x):
          """Return the raw output of the last linear layer for ``x``.

          No softmax is applied here; the final layer's output is returned
          as-is.
          """
          x = F.tanh(self.fc0(x))
          x = F.tanh(self.fc1(x))
          return self.fc2(x)
  @trace(symbolic=True)
  def train_fun(data, label, net=None, opt=None):
      """Run one traced training pass: forward, loss, then backward.

      Puts ``net`` in training mode, computes the softmax cross-entropy of
      the network output against ``label``, and passes the loss to
      ``opt.backward``. The optimizer update itself (``opt.step()``) is not
      performed here.

      Args:
          data: Input batch fed to ``net``.
          label: Target labels for the batch.
          net: Network to train. Must be supplied despite the ``None`` default.
          opt: Optimizer whose ``backward`` is called with the loss. Must be
              supplied despite the ``None`` default.

      Returns:
          tuple: ``(pred, loss)`` - the network output and the loss.
      """
      net.train()
      logits = net(data)
      batch_loss = F.cross_entropy_with_softmax(logits, label)
      opt.backward(batch_loss)
      return logits, batch_loss
  @trace(symbolic=True)
  def val_fun(data, label, net=None):
      """Evaluate ``net`` on one batch without any backward pass.

      Puts ``net`` in evaluation mode and computes the softmax cross-entropy
      of its output against ``label``.

      Args:
          data: Input batch fed to ``net``.
          label: Target labels for the batch.
          net: Network to evaluate. Must be supplied despite the ``None``
              default.

      Returns:
          tuple: ``(pred, loss)`` - the network output and the loss.
      """
      net.eval()
      logits = net(data)
      batch_loss = F.cross_entropy_with_softmax(logits, label)
      return logits, batch_loss
  @trace(symbolic=True)
  def pred_fun(data, net=None):
      """Return softmax-normalized predictions of ``net`` for ``data``.

      Args:
          data: Input batch fed to ``net``.
          net: Network used for inference, switched to evaluation mode first.
              Must be supplied despite the ``None`` default.

      Returns:
          The result of applying ``F.softmax`` to the network output.
      """
      net.eval()
      logits = net(data)
      return F.softmax(logits)
  def main():
      """Train XORNet on random minibatches, print predictions, dump the model.

      Steps:
          1. Fall back to the CPU device when CUDA is not available.
          2. Train for steps 0..1000 with SGD, printing the validation loss
             every 50 steps.
          3. Run the trained network on twelve fixed test points and print the
             predicted probabilities and labels.
          4. Dump the traced prediction function to ``xornet_deploy.mge`` when
             tracing is enabled.
      """
      if not mge.is_cuda_available():
          mge.set_default_device("cpux")

      net = XORNet()
      opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)

      batch_size = 64
      train_dataset = minibatch_generator(batch_size)
      val_dataset = minibatch_generator(batch_size)

      # Tensors that are refilled with set_value() for every batch.
      data = mge.tensor()
      label = mge.tensor(np.zeros((batch_size,)), dtype=np.int32)

      # Loss histories; they are only appended to within this function.
      train_loss = []
      val_loss = []

      for step, minibatch in enumerate(train_dataset):
          # The generator never ends, so stop explicitly after step 1000.
          if step > 1000:
              break

          data.set_value(minibatch["data"])
          label.set_value(minibatch["label"])
          opt.zero_grad()
          _, loss = train_fun(data, label, net=net, opt=opt)
          train_loss.append((step, loss.numpy()))

          if step % 50 == 0:
              # Load the validation batch into the input tensors. Previously
              # the batch was fetched but never used, so the "validation"
              # loss was computed on the training batch again.
              val_batch = next(val_dataset)
              data.set_value(val_batch["data"])
              label.set_value(val_batch["label"])
              _, loss = val_fun(data, label, net=net)
              loss = loss.numpy()[0]
              val_loss.append((step, loss))
              print("Step: {} loss={}".format(step, loss))

          # Apply the update from the gradients produced by train_fun above.
          opt.step()

      # Three points per quadrant: same-sign pairs first, then opposite-sign.
      test_data = np.array(
          [
              (0.5, 0.5),
              (0.3, 0.7),
              (0.1, 0.9),
              (-0.5, -0.5),
              (-0.3, -0.7),
              (-0.9, -0.1),
              (0.5, -0.5),
              (0.3, -0.7),
              (0.9, -0.1),
              (-0.5, 0.5),
              (-0.3, 0.7),
              (-0.1, 0.9),
          ]
      )

      data.set_value(test_data)
      out = pred_fun(data, net=net)
      pred_output = out.numpy()
      pred_label = np.argmax(pred_output, 1)

      print("Test data")
      print(test_data)

      with np.printoptions(precision=4, suppress=True):
          print("Predicated probability:")
          print(pred_output)

      print("Predicated label")
      print(pred_label)

      model_name = "xornet_deploy.mge"

      # Dumping is only attempted when tracing of pred_fun is enabled.
      if pred_fun.enabled:
          print("Dump model as {}".format(model_name))
          pred_fun.dump(model_name, arg_names=["data"])
      else:
          print("pred_fun must be run with trace enabled in order to dump model")
  # Script entry point: run the training demo only when this file is executed
  # directly, not when it is imported as a module.
  if __name__ == "__main__":
      main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。

Contributors (1)