You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pytorch.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import torch
  11. from helpers import randomTorch
  12. import megengine as mge
  13. import megengine._internal as mgb
  14. import megengine.functional
  15. import megengine.optimizer as optimizer
  16. from megengine import get_default_device, set_default_device
  17. from megengine.core import Parameter, tensor
  18. from megengine.jit import trace
  19. from megengine.module import Module as MGEModule
  20. from megengine.module.pytorch import PyTorchModule
  21. from megengine.test import assertTensorClose
  22. def test_pytorch_forward():
  23. class APlusB(torch.nn.Module):
  24. def __init__(self):
  25. super(APlusB, self).__init__()
  26. def forward(self, a, b):
  27. return a + b
  28. a = randomTorch(15, 15)
  29. b = randomTorch(15, 15)
  30. def get_pytorch_forward():
  31. return APlusB()(a, b)
  32. def get_mge_forward():
  33. mge_module = PyTorchModule(APlusB())
  34. mge_a = tensor(a.numpy(), dtype=np.float32)
  35. mge_b = tensor(b.numpy(), dtype=np.float32)
  36. return mge_module(mge_a, mge_b)
  37. assertTensorClose(get_pytorch_forward().numpy(), get_mge_forward().numpy())
  38. def test_pytorch_backward():
  39. class APlusB(torch.nn.Module):
  40. def __init__(self):
  41. super(APlusB, self).__init__()
  42. def forward(self, a, b):
  43. return a + b
  44. a = randomTorch(15, 15)
  45. b = randomTorch(15, 15)
  46. def get_pytorch_backward():
  47. parameter_a = a.clone()
  48. parameter_a.requires_grad = True
  49. c = APlusB()(parameter_a, b)
  50. d = APlusB()(c, b)
  51. e = torch.sum(d)
  52. e.backward()
  53. return parameter_a.grad
  54. def get_mge_backward():
  55. mge_module = PyTorchModule(APlusB())
  56. mge_a = Parameter(a.numpy(), dtype=np.float32)
  57. mge_b = tensor(b.numpy(), dtype=np.float32)
  58. mge_c = mge_module(mge_a, mge_b)
  59. mge_d = mge_module(mge_c, mge_b)
  60. mge_e = mge.functional.sum(mge_d)
  61. return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False)
  62. assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy())
  63. def test_pytorch_mixed():
  64. init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
  65. lr = 1.0
  66. class Mixed(MGEModule):
  67. class SubModule(torch.nn.Module):
  68. def __init__(self):
  69. super().__init__()
  70. self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))
  71. def forward(self, inp):
  72. return inp * self.multiplier
  73. def __init__(self):
  74. super().__init__()
  75. self.torch_module = PyTorchModule(self.SubModule())
  76. self.multiplier = Parameter(init_param[1], dtype=np.float32)
  77. def forward(self, inp):
  78. return self.torch_module(inp) * self.multiplier
  79. def run(step, enable_trace, use_symbolic):
  80. def train_func(data, net=None, opt=None):
  81. pred = net(data)
  82. opt.backward(pred)
  83. return pred
  84. if enable_trace:
  85. train_func = trace(train_func, symbolic=use_symbolic)
  86. net = Mixed()
  87. data = tensor()
  88. opt = optimizer.SGD(net.parameters(), lr=lr)
  89. saved_param = init_param
  90. for i in range(step):
  91. opt.zero_grad()
  92. data.set_value([i + 1.0])
  93. output = train_func(data, net=net, opt=opt)
  94. opt.step()
  95. expect_param = (
  96. saved_param[0] - lr * saved_param[1] * data.numpy(),
  97. saved_param[1] - lr * saved_param[0] * data.numpy(),
  98. )
  99. assertTensorClose(
  100. output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
  101. )
  102. torch_param = net.torch_module._torch_params[0].detach().cpu()
  103. assertTensorClose(torch_param.numpy(), expect_param[0])
  104. assertTensorClose(net.multiplier.numpy(), expect_param[1])
  105. saved_param = expect_param
  106. run(1, False, False)
  107. run(1, True, True)
  108. run(1, True, False)
  109. run(2, False, False)
  110. run(2, True, True)
  111. run(2, True, False)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。

Contributors (1)