You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialization.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. import pickle
  2. from collections import defaultdict
  3. from tempfile import TemporaryFile
  4. import numpy as np
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.traced_module.serialization as S
  8. from megengine import Tensor
  9. from megengine.core._imperative_rt.core2 import apply
  10. from megengine.core.ops import builtin
  11. from megengine.core.ops.builtin import Elemwise
  12. from megengine.module import Module
  13. from megengine.traced_module import trace_module
  14. from megengine.traced_module.expr import CallMethod, Constant
  15. from megengine.traced_module.node import TensorNode
  16. from megengine.traced_module.serialization import (
  17. register_functional_loader,
  18. register_module_loader,
  19. register_opdef_loader,
  20. register_tensor_method_loader,
  21. )
  22. from megengine.traced_module.utils import _convert_kwargs_to_args
  def _check_id(traced_module):
      """Assert that node and expr ids in ``traced_module.graph`` are consistent.

      Node ids and expr ids are checked as two independent families. Each family
      must contain no duplicate ``_id``, and the matching counter in
      ``graph._total_ids`` (index 0 for nodes, index 1 for exprs) must equal the
      largest id in that family plus one. ``max`` raises ``ValueError`` if a
      family is empty.
      """
      graph = traced_module.graph
      counters = graph._total_ids
      # Bound methods, not results: each family is only queried when its turn comes.
      for slot, query in enumerate((graph.nodes, graph.exprs)):
          ids = [item._id for item in query().as_list()]
          assert len(set(ids)) == len(ids)  # no id handed out twice
          assert max(ids) + 1 == counters[slot]  # counter sits one past the largest id
  def _check_name(flatened_module):
      """Assert that no two nodes in ``flatened_module.graph`` share a ``_name``.

      Callers in this file pass the result of ``TracedModule.flatten()``.
      """
      seen = set()
      for node in flatened_module.graph.nodes().as_list():
          assert node._name not in seen  # duplicate node name
          seen.add(node._name)
  def _check_expr_users(traced_module):
      """Assert that each node's ``users`` list matches the exprs consuming it.

      Every expr in ``traced_module.graph._exprs`` is recorded as a user of each
      of its inputs. For every node yielded by ``graph.nodes(False)``, the
      recorded users must equal ``node.users`` once both are ordered by expr
      ``_id``. A ``CallMethod`` whose ``graph`` is truthy is checked recursively
      on ``expr.inputs[0].owner`` (presumably the traced submodule being called).

      Both sides are compared as sorted copies, so this check never reorders
      ``node.users`` on the graph under test.
      """
      node_user = defaultdict(list)
      for expr in traced_module.graph._exprs:
          for node in expr.inputs:
              node_user[node].append(expr)
          if isinstance(expr, CallMethod) and expr.graph:
              _check_expr_users(expr.inputs[0].owner)

      def by_id(e):
          return e._id

      # NOTE(review): ``nodes(False)`` presumably restricts iteration to this
      # graph's own nodes (non-recursive) — confirm against Graph.nodes.
      for node in traced_module.graph.nodes(False):
          actual = sorted(node.users, key=by_id)
          expected = sorted(node_user[node], key=by_id)
          assert actual == expected
  class MyBlock(Module):
      """Conv2d -> BatchNorm2d -> ReLU, then add the constant 1."""

      def __init__(self, in_channels, channels):
          super().__init__()
          # Positional 3, 1 are the kernel size and stride; no bias since BN follows.
          self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
          self.bn1 = M.BatchNorm2d(channels)

      def forward(self, x):
          normed = self.bn1(self.conv1(x))
          return F.relu(normed) + 1
  50. def forward(self, x):
  51. x = self.conv1(x)
  52. x = self.bn1(x)
  53. x = F.relu(x) + 1
  54. return x
  55. class MyModule(Module):
  56. def __init__(self):
  57. super(MyModule, self).__init__()
  58. self.block0 = MyBlock(8, 4)
  59. self.block1 = MyBlock(4, 2)
  60. def forward(self, x):
  61. x = self.block0(x)
  62. x = self.block1(x)
  63. return x
  64. def test_dump_and_load():
  65. module = MyModule()
  66. x = Tensor(np.ones((1, 8, 14, 14)))
  67. expect = module(x)
  68. traced_module = trace_module(module, x)
  69. np.testing.assert_array_equal(expect, traced_module(x))
  70. obj = pickle.dumps(traced_module)
  71. new_tm = pickle.loads(obj)
  72. _check_id(new_tm)
  73. _check_expr_users(new_tm)
  74. traced_module.graph._reset_ids()
  75. old_nodes = traced_module.graph.nodes().as_list()
  76. new_nodes = new_tm.graph.nodes().as_list()
  77. old_exprs = traced_module.graph.exprs().as_list()
  78. new_exprs = new_tm.graph.exprs().as_list()
  79. assert len(old_nodes) == len(new_nodes)
  80. for i, j in zip(old_nodes, new_nodes):
  81. assert i._name == j._name
  82. assert i._qualname == j._qualname
  83. assert i._id == j._id
  84. assert len(old_exprs) == len(new_exprs)
  85. for i, j in zip(old_exprs, new_exprs):
  86. assert i._id == j._id
  87. np.testing.assert_array_equal(expect, traced_module(x))
  88. def test_opdef_loader():
  89. class MyModule1(Module):
  90. def forward(self, x, y):
  91. op = Elemwise("ADD")
  92. return apply(op, x, y)[0]
  93. m = MyModule1()
  94. x = Tensor(np.ones((20)))
  95. y = Tensor(np.ones((20)))
  96. traced_module = trace_module(m, x, y)
  97. orig_loader_dict = S.OPDEF_LOADER
  98. S.OPDEF_LOADER = {}
  99. @register_opdef_loader(Elemwise)
  100. def add_opdef_loader(expr):
  101. if expr.opdef_state["mode"] == "ADD":
  102. expr.opdef_state["mode"] = "MUL"
  103. node = expr.inputs[1]
  104. astype_expr = CallMethod(node, "astype")
  105. oup = TensorNode(
  106. astype_expr,
  107. shape=node.shape,
  108. dtype=expr.inputs[0].dtype,
  109. qparams=node.qparams,
  110. )
  111. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  112. astype_expr.return_val = (oup,)
  113. expr.inputs[1] = oup
  114. obj = pickle.dumps(traced_module)
  115. new_module = pickle.loads(obj)
  116. _check_id(new_module)
  117. _check_expr_users(new_module)
  118. _check_name(new_module.flatten())
  119. assert (
  120. isinstance(new_module.graph._exprs[0], CallMethod)
  121. and new_module.graph._exprs[1].opdef.mode == "MUL"
  122. and len(new_module.graph._exprs) == 2
  123. )
  124. result = new_module(x, y)
  125. np.testing.assert_equal(result.numpy(), x.numpy())
  126. S.OPDEF_LOADER = orig_loader_dict
  127. def test_functional_loader():
  128. class MyModule2(Module):
  129. def forward(self, x, y):
  130. return F.conv2d(x, y)
  131. m = MyModule2()
  132. x = Tensor(np.random.random((1, 3, 32, 32)))
  133. y = Tensor(np.random.random((3, 3, 3, 3)))
  134. traced_module = trace_module(m, x, y)
  135. orig_loader_dict = S.FUNCTIONAL_LOADER
  136. S.FUNCTIONAL_LOADER = {}
  137. @register_functional_loader(("megengine.functional.nn", "conv2d"))
  138. def conv2df_loader(expr):
  139. # expr.func = ("megengine.functional.nn","conv2d")
  140. kwargs = expr.kwargs
  141. orig_weight = expr.named_args["weight"]
  142. astype_expr = CallMethod(orig_weight, "astype")
  143. oup = TensorNode(
  144. astype_expr,
  145. shape=orig_weight.shape,
  146. dtype=orig_weight.dtype,
  147. qparams=orig_weight.qparams,
  148. )
  149. astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
  150. astype_expr.return_val = (oup,)
  151. expr.set_arg("weight", oup)
  152. obj = pickle.dumps(traced_module)
  153. new_module = pickle.loads(obj)
  154. _check_expr_users(new_module)
  155. _check_id(new_module)
  156. result = new_module(x, y)
  157. gt = m(x, y)
  158. assert (
  159. isinstance(new_module.graph._exprs[0], CallMethod)
  160. and len(new_module.graph._exprs) == 2
  161. )
  162. np.testing.assert_equal(result.numpy(), gt.numpy())
  163. S.FUNCTIONAL_LOADER = orig_loader_dict
  164. def test_tensor_method_loader():
  165. class MyModule3(Module):
  166. def forward(self, x):
  167. return x + 1
  168. m = MyModule3()
  169. x = Tensor(np.ones((20)))
  170. traced_module = trace_module(m, x)
  171. orig_loader_dict = S.TENSORMETHOD_LOADER
  172. S.TENSORMETHOD_LOADER = {}
  173. @register_tensor_method_loader("__add__")
  174. def add_loader(expr):
  175. args = list(expr.args)
  176. if not isinstance(args[1], TensorNode):
  177. args[1] = Tensor(args[1])
  178. node = Constant(args[1], "const").outputs[0]
  179. astype_expr = CallMethod(node, "astype")
  180. oup = TensorNode(
  181. astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
  182. )
  183. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  184. astype_expr.return_val = (oup,)
  185. add_expr = CallMethod(oup, "__add__")
  186. add_expr.set_args_kwargs(oup, oup)
  187. oup1 = TensorNode(
  188. add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
  189. )
  190. add_expr.return_val = oup1
  191. args[1] = oup1
  192. expr.set_args_kwargs(*args)
  193. obj = pickle.dumps(traced_module)
  194. new_module = pickle.loads(obj)
  195. _check_expr_users(new_module)
  196. _check_id(new_module)
  197. result = new_module(x)
  198. gt = m(x)
  199. assert (
  200. isinstance(new_module.graph._exprs[0], Constant)
  201. and len(new_module.graph._exprs) == 4
  202. )
  203. np.testing.assert_equal(result.numpy(), (x + 2).numpy())
  204. S.TENSORMETHOD_LOADER = orig_loader_dict
  205. def test_module_loader():
  206. class MyModule4(Module):
  207. def __init__(self):
  208. super().__init__()
  209. self.conv = M.Conv2d(3, 3, 3)
  210. def forward(self, x):
  211. return self.conv(x)
  212. m = MyModule4()
  213. x = Tensor(np.random.random((1, 3, 32, 32)))
  214. traced_module = trace_module(m, x)
  215. orig_loader_dict = S.MODULE_LOADER
  216. S.MODULE_LOADER = {}
  217. @register_module_loader(("megengine.module.conv", "Conv2d"))
  218. def conv2dm_loader(expr):
  219. module = expr.inputs[0].owner
  220. args = list(expr.args)
  221. orig_inp = args[1]
  222. astype_expr = CallMethod(orig_inp, "astype")
  223. oup = TensorNode(
  224. astype_expr,
  225. shape=orig_inp.shape,
  226. dtype=orig_inp.dtype,
  227. qparams=orig_inp.qparams,
  228. )
  229. astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
  230. astype_expr.return_val = (oup,)
  231. args[1] = oup
  232. expr.set_args_kwargs(*args)
  233. obj = pickle.dumps(traced_module)
  234. new_module = pickle.loads(obj)
  235. result = new_module(x)
  236. gt = m(x)
  237. assert (
  238. isinstance(new_module.graph._exprs[1], CallMethod)
  239. and len(new_module.graph._exprs) == 3
  240. )
  241. np.testing.assert_equal(result.numpy(), gt.numpy())
  242. S.MODULE_LOADER = orig_loader_dict
  243. def test_shared_module():
  244. class MyModule(M.Module):
  245. def __init__(self):
  246. super().__init__()
  247. self.a = M.Elemwise("ADD")
  248. self.b = self.a
  249. def forward(self, x, y):
  250. z = self.a(x, y)
  251. z = self.b(z, y)
  252. return z
  253. x = Tensor(1)
  254. y = Tensor(2)
  255. m = MyModule()
  256. tm = trace_module(m, x, y)
  257. obj = pickle.dumps(tm)
  258. load_tm = pickle.loads(obj)
  259. _check_expr_users(load_tm)
  260. _check_name(load_tm.flatten())
  261. _check_id(load_tm)
  262. assert load_tm.a is load_tm.b
  263. def test_convert_kwargs_to_args():
  264. def func(a, b, c=4, *, d, e=3, f=4):
  265. pass
  266. args = (1,)
  267. kwargs = {"b": 1, "d": 6}
  268. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
  269. assert new_args == (1, 1, 4)
  270. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  271. args = (1,)
  272. kwargs = {"d": 6}
  273. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
  274. assert new_args == (1, 4)
  275. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  276. def func1(a, b, c, d, e, *, f):
  277. pass
  278. args = ()
  279. kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
  280. new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
  281. assert new_args == (1, 2, 3, 4, 5)
  282. assert new_kwargs == {"f": 6}
  283. def test_opdef_serialization():
  284. with TemporaryFile() as f:
  285. x = builtin.Elemwise(mode="Add")
  286. pickle.dump(x, f)
  287. f.seek(0)
  288. load_x = pickle.load(f)
  289. assert x == load_x
  290. with TemporaryFile() as f:
  291. x = builtin.Convolution(stride_h=9, compute_mode="float32")
  292. x.strategy = (
  293. builtin.Convolution.Strategy.PROFILE
  294. | builtin.Convolution.Strategy.HEURISTIC
  295. | builtin.Convolution.Strategy.REPRODUCIBLE
  296. )
  297. pickle.dump(x, f)
  298. f.seek(0)
  299. load_x = pickle.load(f)
  300. assert x.strategy == load_x.strategy
  301. assert x == load_x