You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_modification.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import pickle
  9. from itertools import chain
  10. import numpy as np
  11. import megengine.functional as F
  12. import megengine.module as M
  13. from megengine.module.identity import Identity
  14. from megengine.traced_module import trace_module
  15. from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
  16. from megengine.traced_module.node import ModuleNode, Node
class IdentityMod(M.Module):
    """Module whose ``forward`` returns its input unchanged.

    Used as the ``nothing`` submodule of ``MyBlock`` and ``MyModule``;
    presumably so traced graphs contain a call to a user-defined module
    that performs no computation.
    """

    def forward(self, x):
        # Pass-through: no operation is applied to ``x``.
        return x
  20. class MyBlock(M.Module):
  21. def __init__(self, in_channels=3, channels=3):
  22. super(MyBlock, self).__init__()
  23. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  24. self.bn1 = M.BatchNorm2d(channels)
  25. self.nothing = IdentityMod()
  26. def forward(self, x):
  27. x = self.conv1(x)
  28. x = self.bn1(x)
  29. x = F.relu(x) + 1
  30. x = self.nothing(x)
  31. return x
  32. class MyModule(M.Module):
  33. def __init__(self):
  34. super(MyModule, self).__init__()
  35. self.block0 = MyBlock()
  36. self.block1 = MyBlock()
  37. self.nothing = IdentityMod()
  38. def forward(self, x):
  39. x = self.block0(x)
  40. x = self.block1(x)
  41. x = self.nothing(x)
  42. return x
  43. class NewModule(M.Module):
  44. def __init__(self, traced_module):
  45. super(NewModule, self).__init__()
  46. self.module = traced_module
  47. def forward(self, x):
  48. x = x - 1
  49. x = self.module(x)
  50. x = x + 1
  51. return x
  52. def _init_cls(cls):
  53. module = cls()
  54. x = F.ones((1, 3, 3, 3))
  55. y = module(x)
  56. traced_module = trace_module(module, x)
  57. return traced_module, x, y
  58. def _init_block():
  59. return _init_cls(MyBlock)
  60. def _init_module():
  61. return _init_cls(MyModule)
  62. def test_search():
  63. traced_module, *_ = _init_block()
  64. graph = traced_module.graph
  65. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  66. assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
  67. conv_node = graph.get_module_by_type(M.Conv2d).as_unique()
  68. assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d
  69. add_expr = graph.get_method_by_type("__add__").as_unique()
  70. assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__"
  71. conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique()
  72. assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d
  73. def test_producer_and_users():
  74. traced_module, *_ = _init_module()
  75. def _check(exprs):
  76. for expr in exprs:
  77. for n in chain(expr.inputs, expr.outputs):
  78. if not isinstance(n.expr, Input):
  79. assert n.expr in exprs
  80. for e in n.users:
  81. assert e in exprs
  82. assert n in e.inputs
  83. for mod in traced_module.modules():
  84. if not hasattr(mod, "argdef_graph_map"):
  85. continue
  86. for g in mod.argdef_graph_map.values():
  87. _check(g._exprs)
  88. def test_insert():
  89. traced_module, x, expect = _init_block()
  90. graph = traced_module.graph
  91. relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
  92. with graph.insert_exprs():
  93. neg_out = F.neg(relu_out)
  94. graph.replace_node({relu_out: neg_out})
  95. graph.compile()
  96. np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
  97. def test_insert_module():
  98. class Neg(M.Module):
  99. def forward(self, x):
  100. return F.neg(x)
  101. traced_module, x, expect = _init_block()
  102. graph = traced_module.graph
  103. relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
  104. self = graph.inputs[0]
  105. setattr(traced_module, "neg", Neg())
  106. with graph.insert_exprs():
  107. neg_out = self.neg(relu_out)
  108. graph.replace_node({relu_out: neg_out})
  109. graph.compile()
  110. np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
  111. assert traced_module.neg.graph is not None
  112. assert len(traced_module.neg.graph._exprs) == 1
  113. def test_add_input_and_output():
  114. traced_module, x, y = _init_module()
  115. data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data")
  116. traced_module.graph.add_output_node(data_node)
  117. assert data_node.name == "data"
  118. assert traced_module.graph.inputs[-1] == data_node
  119. assert len(traced_module.graph.inputs) == 3
  120. assert len(traced_module.graph.outputs) == 2
  121. y1, y2 = traced_module(x, x)
  122. np.testing.assert_equal(y1.numpy(), y.numpy())
  123. np.testing.assert_equal(y2.numpy(), x.numpy())
  124. y1, y2 = traced_module(x, y)
  125. np.testing.assert_equal(y2.numpy(), y.numpy())
  126. traced_module.graph.reset_outputs(
  127. ({"orig_out": traced_module.graph.outputs[0]}, traced_module.graph.outputs[1])
  128. )
  129. out = traced_module(x, x)
  130. assert isinstance(out, tuple)
  131. assert isinstance(out[0], dict)
  132. np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy())
  133. np.testing.assert_equal(out[1].numpy(), x.numpy())
  134. def test_delete():
  135. traced_module, x, expect = _init_block()
  136. graph = traced_module.graph
  137. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  138. node = relu_expr.outputs
  139. repl_node = relu_expr.inputs
  140. graph.replace_node({node[0]: repl_node[0]})
  141. graph.compile()
  142. np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
  143. # clear graph
  144. graph.replace_node({graph.outputs[0]: graph.inputs[1]})
  145. graph.compile()
  146. np.testing.assert_equal(len(list(graph._exprs)), 0)
  147. np.testing.assert_equal(traced_module(x).numpy(), x.numpy())
  148. def test_flatten():
  149. traced_module, x, expect = _init_module()
  150. traced_module = traced_module.flatten()
  151. assert len(traced_module.graph._exprs) == 12
  152. np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
  153. traced_module = traced_module.flatten()
  154. assert len(traced_module.graph._exprs) == 12
  155. np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
  156. def test_id_and_name():
  157. def _check_id(traced_module):
  158. _total_ids = traced_module.graph._total_ids
  159. node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
  160. assert len(set(node_ids)) == len(node_ids)
  161. assert max(node_ids) + 1 == _total_ids[0]
  162. expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
  163. assert len(set(expr_ids)) == len(expr_ids)
  164. assert max(expr_ids) + 1 == _total_ids[1]
  165. def _check_name(flatened_module):
  166. node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
  167. assert len(set(node_names)) == len(node_names)
  168. traced_module, x, expect = _init_module()
  169. _check_id(traced_module)
  170. flattened_module = traced_module.flatten()
  171. _check_id(flattened_module)
  172. _check_name(flattened_module)
  173. # pickle check
  174. obj = pickle.dumps(traced_module)
  175. traced_module = pickle.loads(obj)
  176. Node._set_next_id(159)
  177. Expr._set_next_id(1024)
  178. graph = traced_module.graph
  179. for expr in graph.get_function_by_type(F.relu).as_list():
  180. relu_out = expr.outputs[0]
  181. cur_graph = expr.top_graph
  182. with cur_graph.insert_exprs():
  183. neg_out = F.neg(relu_out)
  184. cur_graph.replace_node({relu_out: neg_out})
  185. cur_graph.compile()
  186. _check_id(traced_module)
  187. flattened_module = traced_module.flatten()
  188. _check_id(flattened_module)
  189. _check_name(flattened_module)
  190. # check trace TracedModule
  191. obj = pickle.dumps(traced_module)
  192. traced_module = pickle.loads(obj)
  193. module = NewModule(traced_module)
  194. traced_module = trace_module(module, x)
  195. _check_id(traced_module)
  196. flattened_module = traced_module.flatten()
  197. _check_id(flattened_module)
  198. _check_name(flattened_module)
  199. def test_set_node_name():
  200. traced_module, x, expect = _init_module()
  201. graph = traced_module.graph
  202. output_node = graph.outputs[0]
  203. def rename(name):
  204. output_node.name = name
  205. np.testing.assert_raises(AssertionError, rename, "block1_out")
  206. rename("output")
  207. np.testing.assert_equal(str(graph.outputs[0]), "output")
  208. def test_set_graph_name():
  209. traced_module, x, expect = _init_module()
  210. graph = traced_module.graph
  211. output_node = graph.outputs[0]
  212. node_name = output_node.name
  213. graph.name = "Top"
  214. node = graph.get_node_by_name("{}_{}".format("Top", node_name)).as_unique()
  215. assert node is output_node
  216. def test_extra_block():
  217. class PostProcess(M.Module):
  218. def forward(self, x):
  219. return x * 2
  220. class Net(M.Module):
  221. def __init__(self, traced_module):
  222. super().__init__()
  223. self.post_process = PostProcess()
  224. self.traced_module = traced_module
  225. def forward(self, x):
  226. x = self.traced_module(x)
  227. x = self.post_process(x)
  228. return x
  229. traced_module, x, expect = _init_block()
  230. module = Net(traced_module)
  231. np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
  232. traced_module = trace_module(module, x)
  233. np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。