You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from collections import OrderedDict
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Tensor
  6. from megengine.module.module import Module
  7. from megengine.traced_module import TracedModule, trace_module
  8. from megengine.traced_module.expr import CallFunction
  9. class MyModule1(M.Module):
  10. def forward(self, x):
  11. y = Tensor(x)
  12. y += 1
  13. x = x + 2
  14. return x, y
  15. class MyModule2(M.Module):
  16. def forward(self, x):
  17. y = Tensor([1, x, 1])
  18. y += 1
  19. x = x + 2
  20. return x, y
  21. class MyModule3(M.Module):
  22. def __init__(self):
  23. super().__init__()
  24. self.modules = [
  25. M.Elemwise("ADD"),
  26. M.Elemwise("ADD"),
  27. OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
  28. M.Elemwise("RELU"),
  29. M.Elemwise("RELU"),
  30. ]
  31. def forward(self, a, b):
  32. x = self.modules[0](a, b)
  33. y = self.modules[1](a, b)
  34. assert list(self.modules[2].keys()) == ["a", "b"]
  35. for _, m in self.modules[2].items():
  36. y = m(x, y)
  37. for m in self.modules[3:]:
  38. y = m(y)
  39. return y
  40. class MyModule4(M.Module):
  41. def __init__(self):
  42. super().__init__()
  43. self.add = F.add
  44. def forward(self, x, y):
  45. return self.add(x, y)
  46. def test_trace_module():
  47. x = Tensor(1)
  48. m1 = MyModule1()
  49. tm1 = trace_module(m1, x)
  50. m2 = MyModule2()
  51. tm2 = trace_module(m2, x)
  52. inp = Tensor(2)
  53. gt = m1(inp)
  54. output = tm1(inp)
  55. for a, b in zip(output, gt):
  56. np.testing.assert_equal(a.numpy(), b.numpy())
  57. gt1 = m2(inp)
  58. output1 = tm2(inp)
  59. for a, b in zip(output1, gt1):
  60. np.testing.assert_equal(a.numpy(), b.numpy())
  61. a, b = Tensor(1), Tensor(2)
  62. m3 = MyModule3()
  63. gt = m3(a, b)
  64. tm3 = trace_module(m3, a, b)
  65. out = tm3(a, b)
  66. np.testing.assert_equal(out.numpy(), gt.numpy())
  67. assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
  68. assert isinstance(tm3.modules.__dict__["2"], TracedModule)
  69. assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
  70. assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)
  71. m4 = MyModule4()
  72. tm4 = trace_module(m4, a, b)
  73. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  74. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  75. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  76. tm4 = trace_module(m4, a, y=b)
  77. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  78. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  79. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  80. tm4 = trace_module(m4, x=a, y=b)
  81. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  82. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  83. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  84. tm5 = trace_module(tm4, a, b)
  85. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  86. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  87. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  88. tm5 = trace_module(tm4, a, y=b)
  89. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  90. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  91. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  92. tm5 = trace_module(tm4, x=a, y=b)
  93. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  94. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  95. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  96. assert len(tm4.graph._exprs) == 1
  97. assert isinstance(tm4.graph._exprs[0], CallFunction)
  98. class MyModule5(Module):
  99. def __init__(self):
  100. super().__init__()
  101. self.m1 = tm4
  102. def forward(self, x, y):
  103. return self.m1(x, y)
  104. tm6 = trace_module(MyModule5(), a, b)
  105. assert tm6.m1.argspec is None
  106. assert tm6.m1._is_top is False

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。