You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import numpy as np
  2. import megengine.functional as F
  3. import megengine.module as M
  4. from megengine import Tensor
  5. from megengine.traced_module import TracedModule, trace_module
  6. from megengine.traced_module.expr import CallFunction
class MyModule1(M.Module):
    """Test fixture: builds a new ``Tensor`` from the traced input in ``forward``."""
    def forward(self, x):
        """Return ``(x, y)`` after ``y = Tensor(x)``, ``y += 1`` and ``x = x + 2``.

        NOTE(review): whether ``Tensor(x)`` copies or aliases ``x`` decides
        whether the in-place ``y += 1`` is also visible through ``x``; the test
        only requires that eager and traced execution agree.
        """
        y = Tensor(x)
        # In-place update on the tensor constructed from the input; kept as an
        # augmented assignment on purpose, since that is what gets traced.
        y += 1
        x = x + 2
        return x, y
class MyModule2(M.Module):
    """Test fixture: builds a ``Tensor`` from a Python list containing the traced input."""
    def forward(self, x):
        """Return ``(x + 2, y)`` where ``y`` is ``Tensor([1, x, 1])`` incremented in place by 1."""
        # Python scalars mixed with the input tensor inside a list literal.
        # NOTE(review): assumes x is a scalar (0-d) tensor so this forms a 1-D
        # tensor of length 3 — the test passes Tensor(1) / Tensor(2).
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y
class MyModule3(M.Module):
    """Test fixture: sub-modules held in a plain Python list and a nested dict.

    The test checks that ``trace_module`` keeps modules stored in containers
    (rather than as direct attributes) reachable on the traced module.
    """
    def __init__(self):
        super().__init__()
        # Two "ADD" Elemwise modules followed by a dict holding two more.
        # NOTE(review): the attribute name ``modules`` may shadow a
        # ``Module.modules`` method if the base class defines one — confirm
        # this is intended (the test relies on this exact name).
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            {"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")},
        ]
    def forward(self, a, b):
        """Chain the contained "ADD" modules and return the final result.

        ``x`` and ``y`` both come from ``(a, b)`` through the two list
        entries; ``y`` is then combined with ``x`` twice, through the dict
        entries ``"a"`` and ``"b"``.
        """
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        y = self.modules[2]["a"](x, y)
        y = self.modules[2]["b"](x, y)
        return y
class MyModule4(M.Module):
    """Test fixture: a functional op (``F.add``) stored as a module attribute."""
    def __init__(self):
        super().__init__()
        # A plain function rather than a sub-module; the test expects calling
        # it to be recorded as a single CallFunction expression.
        self.add = F.add
    def forward(self, x, y):
        """Return ``F.add(x, y)`` through the stored attribute."""
        return self.add(x, y)
  39. def test_trace_module():
  40. x = Tensor(1)
  41. m1 = MyModule1()
  42. tm1 = trace_module(m1, x)
  43. m2 = MyModule2()
  44. tm2 = trace_module(m2, x)
  45. inp = Tensor(2)
  46. gt = m1(inp)
  47. output = tm1(inp)
  48. for a, b in zip(output, gt):
  49. np.testing.assert_equal(a.numpy(), b.numpy())
  50. gt1 = m2(inp)
  51. output1 = tm2(inp)
  52. for a, b in zip(output1, gt1):
  53. np.testing.assert_equal(a.numpy(), b.numpy())
  54. a, b = Tensor(1), Tensor(2)
  55. m3 = MyModule3()
  56. gt = m3(a, b)
  57. tm3 = trace_module(m3, a, b)
  58. out = tm3(a, b)
  59. np.testing.assert_equal(out.numpy(), gt.numpy())
  60. assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
  61. assert isinstance(tm3.modules.__dict__["2"], TracedModule)
  62. assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
  63. m4 = MyModule4()
  64. tm4 = trace_module(m4, a, b)
  65. assert len(tm4.graph._exprs) == 1
  66. assert isinstance(tm4.graph._exprs[0], CallFunction)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。