You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_modification.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import pickle
  9. import numpy as np
  10. import megengine.functional as F
  11. import megengine.module as M
  12. from megengine.module.identity import Identity
  13. from megengine.traced_module import trace_module
  14. from megengine.traced_module.expr import CallFunction, Expr, GetAttr
  15. from megengine.traced_module.node import Node
  16. class IdentityMod(M.Module):
  17. def forward(self, x):
  18. return x
  19. class MyBlock(M.Module):
  20. def __init__(self, in_channels=3, channels=3):
  21. super(MyBlock, self).__init__()
  22. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  23. self.bn1 = M.BatchNorm2d(channels)
  24. self.nothing = IdentityMod()
  25. def forward(self, x):
  26. x = self.conv1(x)
  27. x = self.bn1(x)
  28. x = F.relu(x) + 1
  29. x = self.nothing(x)
  30. return x
  31. class MyModule(M.Module):
  32. def __init__(self):
  33. super(MyModule, self).__init__()
  34. self.block0 = MyBlock()
  35. self.block1 = MyBlock()
  36. self.nothing = IdentityMod()
  37. def forward(self, x):
  38. x = self.block0(x)
  39. x = self.block1(x)
  40. x = self.nothing(x)
  41. return x
  42. class NewModule(M.Module):
  43. def __init__(self, traced_module):
  44. super(NewModule, self).__init__()
  45. self.module = traced_module
  46. def forward(self, x):
  47. x = x - 1
  48. x = self.module(x)
  49. x = x + 1
  50. return x
  51. def _init_cls(cls):
  52. module = cls()
  53. x = F.ones((1, 3, 3, 3))
  54. y = module(x)
  55. traced_module = trace_module(module, x)
  56. return traced_module, x, y
  57. def _init_block():
  58. return _init_cls(MyBlock)
  59. def _init_module():
  60. return _init_cls(MyModule)
  61. def test_search():
  62. traced_module, *_ = _init_block()
  63. graph = traced_module.graph
  64. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  65. assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
  66. def test_insert():
  67. traced_module, x, expect = _init_block()
  68. graph = traced_module.graph
  69. relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
  70. with graph.insert_exprs():
  71. neg_out = F.neg(relu_out)
  72. graph.replace_node({relu_out: neg_out})
  73. graph.compile()
  74. np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
  75. def test_delete():
  76. traced_module, x, expect = _init_block()
  77. graph = traced_module.graph
  78. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  79. node = relu_expr.outputs
  80. repl_node = relu_expr.inputs
  81. graph.replace_node({node[0]: repl_node[0]})
  82. graph.compile()
  83. np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
  84. # clear graph
  85. graph.replace_node({graph.outputs[0]: graph.inputs[1]})
  86. graph.compile()
  87. np.testing.assert_equal(len(list(graph._exprs)), 0)
  88. np.testing.assert_equal(traced_module(x).numpy(), x.numpy())
  89. def test_flatten():
  90. traced_module, x, expect = _init_module()
  91. traced_module = traced_module.flatten()
  92. traced_module.graph.compile()
  93. assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
  94. assert len(traced_module.graph._exprs) == 12
  95. np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
  96. def test_id_and_name():
  97. def _check_id(traced_module):
  98. _total_ids = traced_module.graph._total_ids
  99. node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
  100. assert len(set(node_ids)) == len(node_ids)
  101. assert max(node_ids) + 1 == len(node_ids)
  102. expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
  103. assert len(set(expr_ids)) == len(expr_ids)
  104. assert max(expr_ids) + 1 == _total_ids[1]
  105. def _check_name(flatened_module):
  106. node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
  107. assert len(set(node_names)) == len(node_names)
  108. traced_module, x, expect = _init_module()
  109. _check_id(traced_module)
  110. flattened_module = traced_module.flatten()
  111. _check_id(flattened_module)
  112. _check_name(flattened_module)
  113. # pickle check
  114. obj = pickle.dumps(traced_module)
  115. traced_module = pickle.loads(obj)
  116. Node._set_next_id(159)
  117. Expr._set_next_id(1024)
  118. graph = traced_module.graph
  119. for expr in graph.get_function_by_type(F.relu).as_list():
  120. relu_out = expr.outputs[0]
  121. cur_graph = expr.top_graph
  122. with cur_graph.insert_exprs():
  123. neg_out = F.neg(relu_out)
  124. cur_graph.replace_node({relu_out: neg_out})
  125. cur_graph.compile()
  126. _check_id(traced_module)
  127. flattened_module = traced_module.flatten()
  128. _check_id(flattened_module)
  129. _check_name(flattened_module)
  130. # check trace TracedModule
  131. obj = pickle.dumps(traced_module)
  132. traced_module = pickle.loads(obj)
  133. module = NewModule(traced_module)
  134. traced_module = trace_module(module, x)
  135. _check_id(traced_module)
  136. flattened_module = traced_module.flatten()
  137. _check_id(flattened_module)
  138. _check_name(flattened_module)
  139. def test_set_name():
  140. traced_module, x, expect = _init_module()
  141. graph = traced_module.graph
  142. output_node = graph.outputs[0]
  143. def rename(name):
  144. output_node.name = name
  145. np.testing.assert_raises(AssertionError, rename, "block1_out")
  146. rename("output")
  147. np.testing.assert_equal(str(graph.outputs[0]), "output")
  148. def test_extra_block():
  149. class PostProcess(M.Module):
  150. def forward(self, x):
  151. return x * 2
  152. class Net(M.Module):
  153. def __init__(self, traced_module):
  154. super().__init__()
  155. self.post_process = PostProcess()
  156. self.traced_module = traced_module
  157. def forward(self, x):
  158. x = self.traced_module(x)
  159. x = self.post_process(x)
  160. return x
  161. traced_module, x, expect = _init_block()
  162. module = Net(traced_module)
  163. np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
  164. traced_module = trace_module(module, x)
  165. np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。