You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import numpy as np
  2. import megengine.module as M
  3. from megengine import Tensor
  4. from megengine.traced_module import TracedModule, trace_module
  class MyModule1(M.Module):
      """Module whose forward wraps its input in a fresh ``Tensor``."""

      def forward(self, x):
          """Return ``(x + 2, Tensor(x) incremented by 1)``."""
          # Build a new Tensor from the incoming value, then bump it with an
          # augmented assignment (kept as ``+=`` on purpose).
          wrapped = Tensor(x)
          wrapped += 1
          shifted = x + 2
          return shifted, wrapped
  class MyModule2(M.Module):
      """Module whose forward builds a ``Tensor`` from a list holding its input."""

      def forward(self, x):
          """Return ``(x + 2, Tensor([1, x, 1]) incremented by 1)``."""
          # The input is placed between two literal ones inside a Python list
          # before being handed to the Tensor constructor.
          packed = Tensor([1, x, 1])
          packed += 1
          shifted = x + 2
          return shifted, packed
  class MyModule3(M.Module):
      """Module that keeps its sub-modules in a list with a nested dict."""

      def __init__(self):
          """Create four ``Elemwise("ADD")`` sub-modules stored under ``self.modules``."""
          super().__init__()
          # Same creation order as the container layout: two list entries first,
          # then the dict entries "a" and "b".
          first_add = M.Elemwise("ADD")
          second_add = M.Elemwise("ADD")
          named_adds = {"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")}
          self.modules = [first_add, second_add, named_adds]

      def forward(self, a, b):
          """Chain the four stored sub-modules over ``a`` and ``b`` and return the result."""
          left = self.modules[0](a, b)
          right = self.modules[1](a, b)
          # Each dict-held module combines ``left`` with the running result.
          after_a = self.modules[2]["a"](left, right)
          after_b = self.modules[2]["b"](left, after_a)
          return after_b
  def test_trace_module():
      """Check that traced modules reproduce the outputs of the modules they trace."""
      trace_input = Tensor(1)
      eager1 = MyModule1()
      traced1 = trace_module(eager1, trace_input)
      eager2 = MyModule2()
      traced2 = trace_module(eager2, trace_input)

      # Both modules were traced with Tensor(1); compare them on a different input.
      probe = Tensor(2)
      for eager, traced in ((eager1, traced1), (eager2, traced2)):
          expected = eager(probe)
          actual = traced(probe)
          for got, want in zip(actual, expected):
              np.testing.assert_equal(got.numpy(), want.numpy())

      lhs, rhs = Tensor(1), Tensor(2)
      eager3 = MyModule3()
      expected3 = eager3(lhs, rhs)
      traced3 = trace_module(eager3, lhs, rhs)
      actual3 = traced3(lhs, rhs)
      np.testing.assert_equal(actual3.numpy(), expected3.numpy())

      # The traced container is inspected through its attribute dict: list
      # entries are keyed by their index as a string.
      children = traced3.modules.__dict__
      assert isinstance(children["0"], M.Elemwise)
      assert isinstance(children["2"], TracedModule)
      assert isinstance(children["2"].a, M.Elemwise)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。