You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_modification.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
import pickle
from collections import defaultdict
from itertools import chain
from operator import attrgetter

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.module.qat as qat
from megengine.module.identity import Identity
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node, TensorNode
  12. class IdentityMod(M.Module):
  13. def forward(self, x):
  14. return x
  15. class MyBlock(M.Module):
  16. def __init__(self, in_channels=3, channels=3):
  17. super(MyBlock, self).__init__()
  18. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  19. self.bn1 = M.BatchNorm2d(channels)
  20. self.nothing = IdentityMod()
  21. def forward(self, x):
  22. x = self.conv1(x)
  23. x = self.bn1(x)
  24. x = F.relu(x) + 1
  25. x = self.nothing(x)
  26. return x
  27. class MyModule(M.Module):
  28. def __init__(self):
  29. super(MyModule, self).__init__()
  30. self.block0 = MyBlock()
  31. self.block1 = MyBlock()
  32. self.nothing = IdentityMod()
  33. def forward(self, x):
  34. x = self.block0(x)
  35. x = self.block1(x)
  36. x = self.nothing(x)
  37. return x
  38. class MyBlock1(M.Module):
  39. def forward(self, a):
  40. y = F.concat([a, a])
  41. return a, y
  42. class MyModule1(M.Module):
  43. def __init__(self):
  44. super().__init__()
  45. self.block0 = MyBlock1()
  46. self.block1 = MyBlock1()
  47. def forward(self, a):
  48. a, y1 = self.block0(a)
  49. a = a + 1
  50. a, y2 = self.block1(a)
  51. return a, y1 + y2
  52. class NewModule(M.Module):
  53. def __init__(self, traced_module):
  54. super(NewModule, self).__init__()
  55. self.module = traced_module
  56. def forward(self, x):
  57. x = x - 1
  58. x = self.module(x)
  59. x = x + 1
  60. return x
  61. def _check_expr_users(flattened_module):
  62. node_user = defaultdict(list)
  63. for expr in flattened_module.graph._exprs:
  64. for node in expr.inputs:
  65. node_user[node].append(expr)
  66. for node in flattened_module.graph.nodes():
  67. node.users.sort(key=lambda m: m._id)
  68. node_user[node].sort(key=lambda m: m._id)
  69. assert node.users == node_user[node]
  70. def _init_cls(cls):
  71. module = cls()
  72. x = F.ones((1, 3, 3, 3))
  73. y = module(x)
  74. traced_module = trace_module(module, x)
  75. return traced_module, x, y
  76. def _init_block():
  77. return _init_cls(MyBlock)
  78. def _init_module():
  79. return _init_cls(MyModule)
  80. def test_search():
  81. traced_module, *_ = _init_block()
  82. graph = traced_module.graph
  83. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  84. assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu
  85. conv_node = graph.get_module_by_type(M.Conv2d).as_unique()
  86. assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d
  87. add_expr = graph.get_method_by_type("__add__").as_unique()
  88. assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__"
  89. conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique()
  90. assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d
  91. def test_producer_and_users():
  92. traced_module, *_ = _init_module()
  93. def _check(exprs):
  94. for expr in exprs:
  95. for n in chain(expr.inputs, expr.outputs):
  96. if not isinstance(n.expr, Input):
  97. assert n.expr in exprs
  98. for e in n.users:
  99. assert e in exprs
  100. assert n in e.inputs
  101. for mod in traced_module.modules():
  102. if not hasattr(mod, "argdef_graph_map"):
  103. continue
  104. for g in mod.argdef_graph_map.values():
  105. _check(g._exprs)
  106. def test_insert():
  107. traced_module, x, expect = _init_block()
  108. graph = traced_module.graph
  109. relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
  110. with graph.insert_exprs():
  111. neg_out = F.neg(relu_out)
  112. graph.replace_node({relu_out: neg_out})
  113. graph.compile()
  114. np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
def test_insert_module():
    """Insert calls to plain Modules into traced graphs and check the results.

    Modules attached with ``setattr`` and called inside
    ``graph.insert_exprs()`` must end up with their own ``graph``. This is
    checked at the top level, inside a submodule, and nested two levels
    deeper, where the inserted module must be a ``TracedModule``.
    """

    class Neg(M.Module):
        """Negate the input after routing it through identity submodules.

        The identities are held directly, in a list and in a dict. A zero
        tensor attribute ``param`` is added to the result.
        """

        def __init__(self, name):
            super().__init__(name)
            self.identity = M.Identity()
            self.identity_list = [M.Identity(), M.Identity()]
            self.identity_dict = {"0": M.Identity(), "1": M.Identity()}
            self.param = F.zeros((1,))

        def forward(self, x):
            x = self.identity(x)
            for m in self.identity_dict:
                x = self.identity_dict[m](x)
            for m in self.identity_list:
                x = m(x)
            return F.neg(x) + self.param

    # Case 1: insert two Neg modules and a tensor attribute into a traced MyBlock.
    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
    # graph.inputs[0] stands for the module itself; attribute access on it
    # inside insert_exprs() is how the new submodules are referenced below.
    self = graph.inputs[0]
    setattr(traced_module, "neg", Neg(name="neg"))
    # NOTE(review): both Neg instances are given name="neg"; presumably
    # intentional to exercise duplicate names — confirm.
    setattr(traced_module, "neg2", Neg(name="neg"))
    setattr(traced_module, "param", F.zeros((1,)))
    with graph.insert_exprs():
        # The result of self.neg is overwritten and never used; only neg2's
        # output feeds the graph. Both modules are still called and are
        # asserted below to have graphs.
        neg_out = self.neg(relu_out)
        neg_out = self.neg2(relu_out)
        neg_out = neg_out + self.param
    graph.replace_node({relu_out: neg_out})
    graph.compile()
    # relu(...) was replaced by -relu(...) + 0 + 0, so 1 - output == expect - 1.
    np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
    assert traced_module.neg.graph is not None
    assert traced_module.neg2.graph is not None
    assert traced_module.neg2.param is not None
    # NOTE(review): 13 is the expected expr count of Neg's traced graph;
    # re-derive it if Neg.forward changes.
    assert len(traced_module.neg.graph._exprs) == 13
    # After compile, no TensorNode of the top graph may hold a concrete value.
    for n in traced_module.graph.nodes():
        if isinstance(n, TensorNode):
            assert n.value is None

    # Case 2: insert a Neg owned by submodule block0 at MyModule's output.
    traced_module, x, expect = _init_module()
    setattr(traced_module.block0, "neg", Neg(name=None))
    graph = traced_module.graph
    self = graph.inputs[0]
    out_node = graph.outputs[0]
    with graph.insert_exprs():
        neg_out = self.block0.neg(out_node)
    graph.replace_node({out_node: neg_out})
    graph.compile()
    np.testing.assert_allclose(expect, -traced_module(x), atol=1e-6)
    assert isinstance(traced_module.block0.neg, TracedModule)
    assert traced_module.block0.neg.graph is not None

    # Case 3: nest further, with a Neg inside block0.neg and a ReLU inside that.
    setattr(traced_module.block0.neg, "neg", Neg(name=None))
    setattr(traced_module.block0.neg.neg, "relu", M.ReLU())
    out_node = graph.outputs[0]
    with graph.insert_exprs():
        neg_out = self.block0.neg.neg(out_node)
        neg_out = self.block0.neg.neg(neg_out)
        relu_out = self.block0.neg.neg.relu(neg_out)
    graph.replace_node({out_node: relu_out})
    graph.compile()
    # The current output is -expect. Negating twice keeps it at -expect, then relu is applied.
    np.testing.assert_allclose(F.relu(-expect), traced_module(x), atol=1e-6)
    assert isinstance(traced_module.block0.neg.neg, TracedModule)
    assert traced_module.block0.neg.neg.graph is not None
def test_insert_qat_module():
    """Insert ``qat.Concat`` modules into a traced MyBlock.

    A plain ``qat.Concat`` must not gain a ``graph`` attribute. A user-defined
    subclass of it must end up with a non-None ``graph``.
    """

    # NOTE(review): lowercase class name deviates from PEP 8 class naming.
    class concat(qat.Concat):
        pass

    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    self = graph.inputs[0]
    out = graph.outputs[0]
    setattr(traced_module, "cat_0", qat.Concat())
    setattr(traced_module, "cat_1", concat())
    with graph.insert_exprs():
        x_0 = self.cat_0([out, out])
        x_1 = self.cat_1([out, x_0])
    graph.replace_node({out: x_1})
    graph.compile()
    # NOTE(review): presumably a fresh copy so the compiled module is called
    # with a tensor other than the one used for tracing — confirm intent.
    x = F.copy(x)
    # The output is concat([out, concat([out, out])]), i.e. three copies of expect.
    np.testing.assert_allclose(
        F.concat([expect, expect, expect]), traced_module(x), atol=1e-6
    )
    assert not hasattr(traced_module.cat_0, "graph")
    assert traced_module.cat_1.graph is not None
def test_add_input_and_output():
    """Add a graph input and output, then regroup the outputs.

    Checks ``add_input_node`` / ``add_output_node``. Also checks that
    ``reset_outputs`` accepts a nested ``(dict, node)`` structure, which the
    module then returns.
    """
    traced_module, x, y = _init_module()
    # NOTE(review): the declared shape (1, 3, 224, 224) differs from the
    # (1, 3, 3, 3) tensors passed below and the calls still succeed, so the
    # shape appears not to be enforced at call time — confirm.
    data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data")
    # Expose the new input directly as an extra graph output.
    traced_module.graph.add_output_node(data_node)
    assert data_node.name == "data"
    assert traced_module.graph.inputs[-1] == data_node
    # Presumably the module itself, the original data input, and the new node.
    assert len(traced_module.graph.inputs) == 3
    assert len(traced_module.graph.outputs) == 2
    y1, y2 = traced_module(x, x)
    np.testing.assert_equal(y1.numpy(), y.numpy())
    # The second output is the new input passed straight through.
    np.testing.assert_equal(y2.numpy(), x.numpy())
    y1, y2 = traced_module(x, y)
    np.testing.assert_equal(y2.numpy(), y.numpy())
    # Regroup the outputs as ({"orig_out": out0}, out1).
    traced_module.graph.reset_outputs(
        ({"orig_out": traced_module.graph.outputs[0]}, traced_module.graph.outputs[1])
    )
    out = traced_module(x, x)
    assert isinstance(out, tuple)
    assert isinstance(out[0], dict)
    np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy())
    np.testing.assert_equal(out[1].numpy(), x.numpy())
  216. def test_delete():
  217. traced_module, x, expect = _init_block()
  218. graph = traced_module.graph
  219. relu_expr = graph.get_function_by_type(F.relu).as_unique()
  220. node = relu_expr.outputs
  221. repl_node = relu_expr.inputs
  222. graph.replace_node({node[0]: repl_node[0]})
  223. graph.compile()
  224. np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6)
  225. # clear graph
  226. graph.replace_node({graph.outputs[0]: graph.inputs[1]})
  227. graph.compile()
  228. np.testing.assert_equal(len(list(graph._exprs)), 0)
  229. np.testing.assert_equal(traced_module(x).numpy(), x.numpy())
  230. def test_flatten():
  231. traced_module, x, expect = _init_module()
  232. traced_module = traced_module.flatten()
  233. assert len(traced_module.graph._exprs) == 12
  234. np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
  235. traced_module = traced_module.flatten()
  236. assert len(traced_module.graph._exprs) == 12
  237. np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
  238. traced_module, x, expect = _init_cls(MyModule1)
  239. traced_module = traced_module.flatten()
  240. _check_expr_users(traced_module)
def test_id_and_name():
    """Check id and name invariants across flattening, pickling, edits and re-tracing.

    Node/expr ids must be unique and match the graph's ``_total_ids``
    counters. Node names must be unique after flattening.
    """

    def _check_id(traced_module):
        # Ids must be unique. _total_ids[0] / _total_ids[1] must each equal the
        # largest node / expr id plus one.
        _total_ids = traced_module.graph._total_ids
        node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
        assert len(set(node_ids)) == len(node_ids)
        assert max(node_ids) + 1 == _total_ids[0]
        expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
        assert len(set(expr_ids)) == len(expr_ids)
        assert max(expr_ids) + 1 == _total_ids[1]

    def _check_name(flatened_module):
        # All node names in the flattened graph must be distinct.
        node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
        assert len(set(node_names)) == len(node_names)

    traced_module, x, expect = _init_module()
    _check_id(traced_module)
    flattened_module = traced_module.flatten()
    _check_id(flattened_module)
    _check_name(flattened_module)

    # pickle check
    obj = pickle.dumps(traced_module)
    traced_module = pickle.loads(obj)
    # Set the class-level next-id counters to arbitrary values. Presumably this
    # checks that the per-graph id bookkeeping stays consistent independently
    # of them — confirm.
    Node._set_next_id(159)
    Expr._set_next_id(1024)
    graph = traced_module.graph
    # The search may return relu exprs living in sub-graphs, so each insertion
    # is done in the expr's own top_graph.
    for expr in graph.get_function_by_type(F.relu).as_list():
        relu_out = expr.outputs[0]
        cur_graph = expr.top_graph
        with cur_graph.insert_exprs():
            neg_out = F.neg(relu_out)
        cur_graph.replace_node({relu_out: neg_out})
        cur_graph.compile()
    _check_id(traced_module)
    flattened_module = traced_module.flatten()
    _check_id(flattened_module)
    _check_name(flattened_module)

    # check trace TracedModule
    obj = pickle.dumps(traced_module)
    traced_module = pickle.loads(obj)
    # Wrap the (unpickled) TracedModule in a plain Module and trace again.
    module = NewModule(traced_module)
    traced_module = trace_module(module, x)
    _check_id(traced_module)
    flattened_module = traced_module.flatten()
    _check_id(flattened_module)
    _check_name(flattened_module)
def test_set_node_name():
    """Check node renaming, including names set while inserting exprs.

    Names set inside ``insert_exprs()`` must be registered in the namespace of
    the graph they end up in. This is checked for names set directly, inside a
    plain helper function, and inside a newly attached module's forward.
    """
    traced_module, x, expect = _init_module()
    graph = traced_module.graph
    output_node = graph.outputs[0]

    def rename(name):
        output_node.name = name

    # Renaming to "block1_out" must raise AssertionError (presumably because
    # that name is already used in the graph).
    np.testing.assert_raises(AssertionError, rename, "block1_out")
    rename("output")
    np.testing.assert_equal(str(graph.outputs[0]), "output")

    def add_1(x):
        x = x + 1
        x.name = "func_add_1"
        return x

    class ModuleAdd_3(M.Module):
        def forward(self, x):
            x = x + 1
            x.name = "module_add_1"
            x = x + 2
            return x

    setattr(traced_module, "add_3", ModuleAdd_3())
    self = graph.inputs[0]
    with graph.insert_exprs():
        x = output_node + 1
        x.name = "_add_1"
        x = add_1(x)
        x = self.add_3(x)
    graph.replace_node({output_node: x})
    graph.compile()
    # Names from the direct insertion and from the helper function land in the
    # top graph. The name set inside ModuleAdd_3.forward lands in add_3's graph.
    assert "_add_1" in graph._namespace.used_names
    assert "func_add_1" in graph._namespace.used_names
    assert "module_add_1" in traced_module.add_3.graph._namespace.used_names
  315. def test_set_graph_name():
  316. traced_module, x, expect = _init_module()
  317. graph = traced_module.graph
  318. output_node = graph.outputs[0]
  319. node_name = output_node.name
  320. graph.name = "Top"
  321. node = graph.get_node_by_name("{}_{}".format("Top", node_name)).as_unique()
  322. assert node is output_node
  323. def test_extra_block():
  324. class PostProcess(M.Module):
  325. def forward(self, x):
  326. return x * 2
  327. class Net(M.Module):
  328. def __init__(self, traced_module):
  329. super().__init__()
  330. self.post_process = PostProcess()
  331. self.traced_module = traced_module
  332. def forward(self, x):
  333. x = self.traced_module(x)
  334. x = self.post_process(x)
  335. return x
  336. traced_module, x, expect = _init_block()
  337. module = Net(traced_module)
  338. np.testing.assert_allclose(2 * expect, module(x), atol=1e-6)
  339. traced_module = trace_module(module, x)
  340. np.testing.assert_allclose(2 * expect, traced_module(x), atol=1e-6)