You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trace_module.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from collections import OrderedDict
  2. import numpy as np
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Tensor
  6. from megengine.core._imperative_rt.core2 import apply
  7. from megengine.core.ops import builtin
  8. from megengine.module import Module
  9. from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
  10. from megengine.traced_module.expr import Apply, CallFunction, Constant
class MyModule1(M.Module):
    """Trace fixture that returns ``(x + 2, Tensor(x) + 1)``.

    The second output is produced by an in-place ``+=`` on a tensor
    constructed from the input, before ``x`` itself is rebound.
    NOTE(review): presumably exercises that ``Tensor(x)`` is traced as a
    separate value from ``x`` so the in-place update does not alias it — confirm.
    """

    def forward(self, x):
        y = Tensor(x)
        # In-place update on the constructed tensor, not on the input.
        y += 1
        x = x + 2
        return x, y
class MyModule2(M.Module):
    """Trace fixture that builds a Tensor from a list mixing ints and a Tensor.

    Returns ``(x + 2, Tensor([1, x, 1]) + 1)``. In test_trace_module ``x`` is a
    scalar Tensor, so the list contains a Tensor element between two ints.
    """

    def forward(self, x):
        # Construct a tensor whose middle element is the traced input.
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y
class MyModule3(M.Module):
    """Trace fixture keeping sub-modules in a plain Python list.

    The list holds two binary ADD modules, an OrderedDict of two more ADD
    modules, and two unary RELU modules. test_trace_module checks that the
    traced module exposes these entries under string-index keys ("0", "2",
    "3") and that the OrderedDict entry becomes a TracedModule with child ``a``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this instance attribute shadows the ``Module.modules()``
        # method for this object; presumably intentional, since the test reads
        # ``tm3.modules`` as a container — confirm.
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]),
            M.Elemwise("RELU"),
            M.Elemwise("RELU"),
        ]

    def forward(self, a, b):
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        # The loop below applies the dict's modules in iteration order, so the
        # key order is asserted explicitly.
        assert list(self.modules[2].keys()) == ["a", "b"]
        for _, m in self.modules[2].items():
            y = m(x, y)
        # Remaining list entries are the unary RELU modules.
        for m in self.modules[3:]:
            y = m(y)
        return y
class MyModule4(M.Module):
    """Trace fixture that stores the functional ``F.add`` as an attribute.

    ``forward`` returns ``self.add(x, y)``. test_trace_module asserts the
    traced graph of this module contains exactly one ``CallFunction`` expr,
    and calls it with positional and keyword forms of ``x`` / ``y``.
    """

    def __init__(self):
        super().__init__()
        # A plain function (not a Module) held as an attribute.
        self.add = F.add

    def forward(self, x, y):
        return self.add(x, y)
  48. def test_trace_module():
  49. enable_expr_checker()
  50. x = Tensor(1)
  51. m1 = MyModule1()
  52. tm1 = trace_module(m1, x)
  53. m2 = MyModule2()
  54. tm2 = trace_module(m2, x)
  55. inp = Tensor(2)
  56. gt = m1(inp)
  57. output = tm1(inp)
  58. for a, b in zip(output, gt):
  59. np.testing.assert_equal(a.numpy(), b.numpy())
  60. gt1 = m2(inp)
  61. output1 = tm2(inp)
  62. for a, b in zip(output1, gt1):
  63. np.testing.assert_equal(a.numpy(), b.numpy())
  64. a, b = Tensor(1), Tensor(2)
  65. m3 = MyModule3()
  66. gt = m3(a, b)
  67. tm3 = trace_module(m3, a, b)
  68. out = tm3(a, b)
  69. np.testing.assert_equal(out.numpy(), gt.numpy())
  70. assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
  71. assert isinstance(tm3.modules.__dict__["2"], TracedModule)
  72. assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
  73. assert isinstance(tm3.modules.__dict__["3"], M.Elemwise)
  74. m4 = MyModule4()
  75. tm4 = trace_module(m4, a, b)
  76. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  77. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  78. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  79. tm4 = trace_module(m4, a, y=b)
  80. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  81. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  82. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  83. tm4 = trace_module(m4, x=a, y=b)
  84. np.testing.assert_equal(tm4(a, b).numpy(), 3)
  85. np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
  86. np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
  87. tm5 = trace_module(tm4, a, b)
  88. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  89. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  90. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  91. tm5 = trace_module(tm4, a, y=b)
  92. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  93. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  94. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  95. tm5 = trace_module(tm4, x=a, y=b)
  96. np.testing.assert_equal(tm5(a, b).numpy(), 3)
  97. np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
  98. np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
  99. assert len(tm4.graph._exprs) == 1
  100. assert isinstance(tm4.graph._exprs[0], CallFunction)
  101. class MyModule5(Module):
  102. def __init__(self):
  103. super().__init__()
  104. self.m1 = tm4
  105. def forward(self, x, y):
  106. return self.m1(x, y)
  107. tm6 = trace_module(MyModule5(), a, b)
  108. assert tm6.m1.argspec is None
  109. assert tm6.m1._is_top is False
  110. def test_trace_module_2():
  111. class Model(M.Module):
  112. def __init__(self):
  113. super().__init__()
  114. def forward(self, x):
  115. out = x.shape
  116. out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1))
  117. return out
  118. traced_model = trace_module(Model(), Tensor(([1,])))
  119. assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance(
  120. traced_model.graph._exprs[0].opdef, builtin.GetVarShape
  121. )
  122. assert isinstance(traced_model.graph._exprs[1], Constant)
  123. assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance(
  124. traced_model.graph._exprs[2].opdef, builtin.Elemwise
  125. )
  126. assert int(traced_model(Tensor([1, 2]))[0]) == 3