You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_trace_module.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from collections import OrderedDict
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Tensor
  6. from megengine.module.module import Module
  7. from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
  8. from megengine.traced_module.expr import CallFunction
  class MyModule1(M.Module):
      """Fixture module whose ``forward`` builds a new Tensor from its input.

      NOTE(review): presumably exists to exercise tracing of a Tensor that is
      constructed inside ``forward`` -- confirm against the tracer's tests.
      """

      def forward(self, x):
          """Return the 2-tuple ``(x + 2, y)`` where ``y`` is ``Tensor(x)`` plus 1."""
          # A second Tensor is constructed from the input inside forward.
          y = Tensor(x)
          # NOTE(review): the augmented assignment looks deliberate (tracing of
          # ``+=``) -- confirm before rewriting it as ``y = y + 1``.
          y += 1
          x = x + 2
          return x, y
  class MyModule2(M.Module):
      """Fixture module whose ``forward`` builds a Tensor from a list holding the input.

      NOTE(review): presumably exercises tracing of a Tensor constructed from a
      Python list that mixes literals with the input -- confirm.
      """

      def forward(self, x):
          """Return the 2-tuple ``(x + 2, y)`` where ``y`` is ``Tensor([1, x, 1])`` plus 1."""
          # The input is embedded between two literal 1s in a Python list.
          y = Tensor([1, x, 1])
          y += 1
          x = x + 2
          return x, y
  class MyModule3(M.Module):
      """Fixture module that keeps its sub-modules in a plain list.

      ``self.modules`` holds, in order: two ``Elemwise("ADD")`` modules, an
      ``OrderedDict`` with keys ``"a"`` and ``"b"`` (each an ``Elemwise("ADD")``),
      and two ``Elemwise("RELU")`` modules.
      """

      def __init__(self):
          super().__init__()
          # Mixed container: list entries are modules except index 2, which is
          # an OrderedDict of modules.
          self.modules = [
              M.Elemwise("ADD"),
              M.Elemwise("ADD"),
              OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
              M.Elemwise("RELU"),
              M.Elemwise("RELU"),
          ]

      def forward(self, a, b):
          """Chain the stored modules over ``a`` and ``b`` and return the result.

          The first two list entries are each applied to ``(a, b)``; every module
          in the dict then maps ``(x, y)`` to a new ``y``; finally each module
          from index 3 onward is applied to ``y`` alone.
          """
          x = self.modules[0](a, b)
          y = self.modules[1](a, b)
          # Guard that the dict entry still exposes its keys in insertion order.
          assert list(self.modules[2].keys()) == ["a", "b"]
          for _, m in self.modules[2].items():
              y = m(x, y)
          # Remaining entries (the two RELU modules) take a single input.
          for m in self.modules[3:]:
              y = m(y)
          return y
  class MyModule4(M.Module):
      """Fixture module that stores the function ``F.add`` as an instance attribute."""

      def __init__(self):
          super().__init__()
          # A plain function object (not a Module) is kept on the instance.
          self.add = F.add

      def forward(self, x, y):
          """Return ``self.add(x, y)``.

          Callers in this file pass ``x`` and ``y`` both positionally and by
          keyword, so the parameter names are part of the interface.
          """
          return self.add(x, y)
  def test_trace_module():
      """Trace the fixture modules and compare traced results with eager ones.

      Steps, in order:

      * enable the expression checker;
      * trace ``MyModule1`` / ``MyModule2`` and compare their outputs with the
        untraced modules on a new input;
      * trace ``MyModule3`` and check the types found under ``tm3.modules``;
      * trace ``MyModule4`` with positional, mixed and keyword arguments,
        re-trace the resulting traced module the same three ways, and call
        each result with all three argument styles;
      * check that the last ``tm4`` graph holds a single ``CallFunction``;
      * wrap ``tm4`` in another module, trace it, and check that the inner
        traced module has ``argspec is None`` and ``_is_top is False``.
      """
      enable_expr_checker()

      def check_pairwise(traced_outs, eager_outs):
          # Compare two output tuples element by element.
          for got, want in zip(traced_outs, eager_outs):
              np.testing.assert_equal(got.numpy(), want.numpy())

      def check_sum_is_three(mod, lhs, rhs):
          # Same call made positionally, mixed, and fully by keyword.
          np.testing.assert_equal(mod(lhs, rhs).numpy(), 3)
          np.testing.assert_equal(mod(lhs, y=rhs).numpy(), 3)
          np.testing.assert_equal(mod(x=lhs, y=rhs).numpy(), 3)

      # Both modules are traced before either one is run on the second input.
      sample = Tensor(1)
      m1 = MyModule1()
      tm1 = trace_module(m1, sample)
      m2 = MyModule2()
      tm2 = trace_module(m2, sample)

      inp = Tensor(2)
      for eager, traced in ((m1, tm1), (m2, tm2)):
          expected = eager(inp)
          check_pairwise(traced(inp), expected)

      a, b = Tensor(1), Tensor(2)
      m3 = MyModule3()
      gt = m3(a, b)
      tm3 = trace_module(m3, a, b)
      np.testing.assert_equal(tm3(a, b).numpy(), gt.numpy())
      assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
      assert isinstance(tm3.modules.__dict__["2"], TracedModule)
      assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
      assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)

      # The three ways the example inputs are handed to trace_module.
      trace_signatures = (
          ((a, b), {}),
          ((a,), {"y": b}),
          ((), {"x": a, "y": b}),
      )

      m4 = MyModule4()
      for pos, kw in trace_signatures:
          tm4 = trace_module(m4, *pos, **kw)
          check_sum_is_three(tm4, a, b)

      # Re-trace the last tm4 (the keyword-traced one) the same three ways.
      for pos, kw in trace_signatures:
          tm5 = trace_module(tm4, *pos, **kw)
          check_sum_is_three(tm5, a, b)

      exprs = tm4.graph._exprs
      assert len(exprs) == 1
      assert isinstance(exprs[0], CallFunction)

      class MyModule5(Module):
          """Wrapper that holds the traced ``tm4`` as a sub-module."""

          def __init__(self):
              super().__init__()
              self.m1 = tm4

          def forward(self, x, y):
              return self.m1(x, y)

      tm6 = trace_module(MyModule5(), a, b)
      inner = tm6.m1
      assert inner.argspec is None
      assert inner._is_top is False