You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from collections import OrderedDict
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Tensor
  6. from megengine.traced_module import TracedModule, trace_module
  7. from megengine.traced_module.expr import CallFunction
  8. class MyModule1(M.Module):
  9. def forward(self, x):
  10. y = Tensor(x)
  11. y += 1
  12. x = x + 2
  13. return x, y
  14. class MyModule2(M.Module):
  15. def forward(self, x):
  16. y = Tensor([1, x, 1])
  17. y += 1
  18. x = x + 2
  19. return x, y
  20. class MyModule3(M.Module):
  21. def __init__(self):
  22. super().__init__()
  23. self.modules = [
  24. M.Elemwise("ADD"),
  25. M.Elemwise("ADD"),
  26. OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
  27. M.Elemwise("RELU"),
  28. M.Elemwise("RELU"),
  29. ]
  30. def forward(self, a, b):
  31. x = self.modules[0](a, b)
  32. y = self.modules[1](a, b)
  33. assert list(self.modules[2].keys()) == ["a", "b"]
  34. for _, m in self.modules[2].items():
  35. y = m(x, y)
  36. for m in self.modules[3:]:
  37. y = m(y)
  38. return y
  39. class MyModule4(M.Module):
  40. def __init__(self):
  41. super().__init__()
  42. self.add = F.add
  43. def forward(self, x, y):
  44. return self.add(x, y)
  45. def test_trace_module():
  46. x = Tensor(1)
  47. m1 = MyModule1()
  48. tm1 = trace_module(m1, x)
  49. m2 = MyModule2()
  50. tm2 = trace_module(m2, x)
  51. inp = Tensor(2)
  52. gt = m1(inp)
  53. output = tm1(inp)
  54. for a, b in zip(output, gt):
  55. np.testing.assert_equal(a.numpy(), b.numpy())
  56. gt1 = m2(inp)
  57. output1 = tm2(inp)
  58. for a, b in zip(output1, gt1):
  59. np.testing.assert_equal(a.numpy(), b.numpy())
  60. a, b = Tensor(1), Tensor(2)
  61. m3 = MyModule3()
  62. gt = m3(a, b)
  63. tm3 = trace_module(m3, a, b)
  64. out = tm3(a, b)
  65. np.testing.assert_equal(out.numpy(), gt.numpy())
  66. assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
  67. assert isinstance(tm3.modules.__dict__["2"], TracedModule)
  68. assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
  69. assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)
  70. m4 = MyModule4()
  71. tm4 = trace_module(m4, a, b)
  72. assert len(tm4.graph._exprs) == 1
  73. assert isinstance(tm4.graph._exprs[0], CallFunction)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。