You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from collections import OrderedDict
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Tensor
  6. from megengine.core._imperative_rt.core2 import apply
  7. from megengine.core.ops import builtin
  8. from megengine.module import Module
  9. from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
  10. from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant
  class MyModule1(M.Module):
      """Return ``(x + 2, Tensor(x) + 1)``, where the second value is bumped in place."""

      def forward(self, x):
          # Build a new Tensor from the input, then increment that new tensor in place.
          copied = Tensor(x)
          copied += 1
          return x + 2, copied
  class MyModule2(M.Module):
      """Return ``(x + 2, Tensor([1, x, 1]) + 1)``; the tensor is built from a list that embeds ``x``."""

      def forward(self, x):
          # A Tensor assembled from a Python list containing the input tensor.
          packed = Tensor([1, x, 1])
          packed += 1
          return x + 2, packed
  class MyModule3(M.Module):
      """Module whose submodules live in a plain list that nests an ``OrderedDict``.

      Layout of ``self.modules``: two ADD elemwise modules, an ``OrderedDict``
      holding two more ADD modules under keys ``"a"`` and ``"b"``, then two
      RELU elemwise modules.
      """

      def __init__(self):
          super().__init__()
          # NOTE(review): this instance attribute likely shadows an inherited
          # ``modules`` method on ``M.Module``. That looks intentional, since the
          # test reads ``tm3.modules`` -- confirm before renaming.
          self.modules = [
              M.Elemwise("ADD"),
              M.Elemwise("ADD"),
              OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
              M.Elemwise("RELU"),
              M.Elemwise("RELU"),
          ]

      def forward(self, a, b):
          lhs = self.modules[0](a, b)
          acc = self.modules[1](a, b)
          # The nested OrderedDict must keep its insertion order.
          assert list(self.modules[2].keys()) == ["a", "b"]
          for _key, adder in self.modules[2].items():
              acc = adder(lhs, acc)
          # Apply the trailing activation modules in sequence.
          for activation in self.modules[3:]:
              acc = activation(acc)
          return acc
  class MyModule4(M.Module):
      """Hold the bare functional ``F.add`` (not a Module) as an attribute and call it."""

      def __init__(self):
          super().__init__()
          self.add = F.add

      def forward(self, x, y):
          add_fn = self.add
          return add_fn(x, y)
  class MyModule5(M.Module):
      """Compute ``x * (x + x)`` and give the output tensor the name ``"result"``."""

      def forward(self, x):
          doubled = x + x
          product = x * doubled
          # Set an explicit name on the tensor that is returned.
          product.name = "result"
          return product
  def test_trace_module():
      """End-to-end checks for ``trace_module`` on the module fixtures in this file.

      The test covers:
      - tensors built from inputs (MyModule1, MyModule2);
      - modules held in a list and a nested ``OrderedDict`` (MyModule3);
      - a bare functional attribute traced with positional, mixed and keyword
        arguments (MyModule4);
      - re-tracing an already traced module;
      - nesting a traced module inside another traced module.
      """
      enable_expr_checker()
      x = Tensor(1)
      m1 = MyModule1()
      tm1 = trace_module(m1, x)
      m2 = MyModule2()
      tm2 = trace_module(m2, x)

      # Traced modules must reproduce the eager outputs on a fresh input.
      inp = Tensor(2)
      for module, traced in ((m1, tm1), (m2, tm2)):
          gt = module(inp)
          output = traced(inp)
          # zip() alone would silently ignore a missing or extra output.
          assert len(output) == len(gt)
          for out_t, gt_t in zip(output, gt):
              np.testing.assert_equal(out_t.numpy(), gt_t.numpy())

      a, b = Tensor(1), Tensor(2)
      m3 = MyModule3()
      gt = m3(a, b)
      tm3 = trace_module(m3, a, b)
      out = tm3(a, b)
      np.testing.assert_equal(out.numpy(), gt.numpy())
      # List entries are exposed under their index as a string key. The nested
      # OrderedDict entry is expected to become a TracedModule whose attributes
      # are the dict keys.
      assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
      assert isinstance(tm3.modules.__dict__["2"], TracedModule)
      assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
      assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)

      def check_add(traced):
          """Call ``traced`` positionally, mixed and by keyword; each result must be 1 + 2."""
          np.testing.assert_equal(traced(a, b).numpy(), 3)
          np.testing.assert_equal(traced(a, y=b).numpy(), 3)
          np.testing.assert_equal(traced(x=a, y=b).numpy(), 3)

      # Every way of passing (x, y) at trace time, in the original order. The
      # last pattern (keywords only) is the one left bound to tm4 after the loop.
      trace_args = (
          ((a, b), {}),
          ((a,), {"y": b}),
          ((), {"x": a, "y": b}),
      )
      m4 = MyModule4()
      for args, kwargs in trace_args:
          tm4 = trace_module(m4, *args, **kwargs)
          check_add(tm4)
      # Re-tracing an already traced module must accept the same call styles.
      for args, kwargs in trace_args:
          tm5 = trace_module(tm4, *args, **kwargs)
          check_add(tm5)
      assert len(tm4.graph._exprs) == 1
      assert isinstance(tm4.graph._exprs[0], CallFunction)

      # Named distinctly so it does not shadow the module-level MyModule5 fixture.
      class WrapTraced(Module):
          def __init__(self):
              super().__init__()
              self.m1 = tm4

          def forward(self, x, y):
              return self.m1(x, y)

      tm6 = trace_module(WrapTraced(), a, b)
      # Once nested inside another trace, the inner traced module has no argspec
      # and is no longer marked as top-level.
      assert tm6.m1.argspec is None
      assert tm6.m1._is_top is False
  def test_trace_module_2():
      """Trace a model that feeds ``x.shape`` into a raw ``apply`` of an ADD op."""

      class Model(M.Module):
          def forward(self, x):
              shape = x.shape
              return apply(builtin.Elemwise(mode="ADD"), shape, Tensor(1))

      traced_model = trace_module(Model(), Tensor([1]))
      exprs = traced_model.graph._exprs
      # Expected graph: GetVarShape for x.shape, a Constant, then the Elemwise ADD.
      assert isinstance(exprs[0], Apply)
      assert isinstance(exprs[0].opdef, builtin.GetVarShape)
      assert isinstance(exprs[1], Constant)
      assert isinstance(exprs[2], Apply)
      assert isinstance(exprs[2].opdef, builtin.Elemwise)
      # Input shape (2,) plus 1 gives 3.
      assert int(traced_model(Tensor([1, 2]))[0]) == 3
  def test_rename():
      """Trace MyModule5, which names its output, and check the output's producing expr."""
      traced = trace_module(MyModule5(), Tensor(1))
      output_node = traced.graph.outputs[0]
      assert isinstance(output_node.expr, CallMethod)