You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_serialization.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import pickle
  9. from collections import defaultdict
  10. from tempfile import TemporaryFile
  11. import numpy as np
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.traced_module.serialization as S
  15. from megengine import Tensor
  16. from megengine.core._imperative_rt.core2 import apply
  17. from megengine.core.ops import builtin
  18. from megengine.core.ops.builtin import Elemwise
  19. from megengine.module import Module
  20. from megengine.traced_module import trace_module
  21. from megengine.traced_module.expr import CallMethod, Constant
  22. from megengine.traced_module.node import TensorNode
  23. from megengine.traced_module.serialization import (
  24. register_functional_loader,
  25. register_module_loader,
  26. register_opdef_loader,
  27. register_tensor_method_loader,
  28. )
  29. from megengine.traced_module.utils import _convert_kwargs_to_args
  30. def _check_id(traced_module):
  31. _total_ids = traced_module.graph._total_ids
  32. node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
  33. assert len(set(node_ids)) == len(node_ids)
  34. assert max(node_ids) + 1 == _total_ids[0]
  35. expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
  36. assert len(set(expr_ids)) == len(expr_ids)
  37. assert max(expr_ids) + 1 == _total_ids[1]
  38. def _check_name(flatened_module):
  39. node_names = [n._name for n in flatened_module.graph.nodes().as_list()]
  40. assert len(set(node_names)) == len(node_names)
  41. def _check_expr_users(traced_module):
  42. node_user = defaultdict(list)
  43. for expr in traced_module.graph._exprs:
  44. for node in expr.inputs:
  45. node_user[node].append(expr)
  46. if isinstance(expr, CallMethod) and expr.graph:
  47. _check_expr_users(expr.inputs[0].owner)
  48. for node in traced_module.graph.nodes(False):
  49. node.users.sort(key=lambda m: m._id)
  50. node_user[node].sort(key=lambda m: m._id)
  51. assert node.users == node_user[node]
  52. class MyBlock(Module):
  53. def __init__(self, in_channels, channels):
  54. super(MyBlock, self).__init__()
  55. self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
  56. self.bn1 = M.BatchNorm2d(channels)
  57. def forward(self, x):
  58. x = self.conv1(x)
  59. x = self.bn1(x)
  60. x = F.relu(x) + 1
  61. return x
  62. class MyModule(Module):
  63. def __init__(self):
  64. super(MyModule, self).__init__()
  65. self.block0 = MyBlock(8, 4)
  66. self.block1 = MyBlock(4, 2)
  67. def forward(self, x):
  68. x = self.block0(x)
  69. x = self.block1(x)
  70. return x
  71. def test_dump_and_load():
  72. module = MyModule()
  73. x = Tensor(np.ones((1, 8, 14, 14)))
  74. expect = module(x)
  75. traced_module = trace_module(module, x)
  76. np.testing.assert_array_equal(expect, traced_module(x))
  77. obj = pickle.dumps(traced_module)
  78. new_tm = pickle.loads(obj)
  79. _check_id(new_tm)
  80. _check_expr_users(new_tm)
  81. traced_module.graph._reset_ids()
  82. old_nodes = traced_module.graph.nodes().as_list()
  83. new_nodes = new_tm.graph.nodes().as_list()
  84. old_exprs = traced_module.graph.exprs().as_list()
  85. new_exprs = new_tm.graph.exprs().as_list()
  86. assert len(old_nodes) == len(new_nodes)
  87. for i, j in zip(old_nodes, new_nodes):
  88. assert i._name == j._name
  89. assert i._qualname == j._qualname
  90. assert i._id == j._id
  91. assert len(old_exprs) == len(new_exprs)
  92. for i, j in zip(old_exprs, new_exprs):
  93. assert i._id == j._id
  94. np.testing.assert_array_equal(expect, traced_module(x))
  95. def test_opdef_loader():
  96. class MyModule1(Module):
  97. def forward(self, x, y):
  98. op = Elemwise("ADD")
  99. return apply(op, x, y)[0]
  100. m = MyModule1()
  101. x = Tensor(np.ones((20)))
  102. y = Tensor(np.ones((20)))
  103. traced_module = trace_module(m, x, y)
  104. orig_loader_dict = S.OPDEF_LOADER
  105. S.OPDEF_LOADER = {}
  106. @register_opdef_loader(Elemwise)
  107. def add_opdef_loader(expr):
  108. if expr.opdef_state["mode"] == "ADD":
  109. expr.opdef_state["mode"] = "MUL"
  110. node = expr.inputs[1]
  111. astype_expr = CallMethod(node, "astype")
  112. oup = TensorNode(
  113. astype_expr,
  114. shape=node.shape,
  115. dtype=expr.inputs[0].dtype,
  116. qparams=node.qparams,
  117. )
  118. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  119. astype_expr.return_val = (oup,)
  120. expr.inputs[1] = oup
  121. obj = pickle.dumps(traced_module)
  122. new_module = pickle.loads(obj)
  123. _check_id(new_module)
  124. _check_expr_users(new_module)
  125. _check_name(new_module.flatten())
  126. assert (
  127. isinstance(new_module.graph._exprs[0], CallMethod)
  128. and new_module.graph._exprs[1].opdef.mode == "MUL"
  129. and len(new_module.graph._exprs) == 2
  130. )
  131. result = new_module(x, y)
  132. np.testing.assert_equal(result.numpy(), x.numpy())
  133. S.OPDEF_LOADER = orig_loader_dict
  134. def test_functional_loader():
  135. class MyModule2(Module):
  136. def forward(self, x, y):
  137. return F.conv2d(x, y)
  138. m = MyModule2()
  139. x = Tensor(np.random.random((1, 3, 32, 32)))
  140. y = Tensor(np.random.random((3, 3, 3, 3)))
  141. traced_module = trace_module(m, x, y)
  142. orig_loader_dict = S.FUNCTIONAL_LOADER
  143. S.FUNCTIONAL_LOADER = {}
  144. @register_functional_loader(("megengine.functional.nn", "conv2d"))
  145. def conv2df_loader(expr):
  146. # expr.func = ("megengine.functional.nn","conv2d")
  147. kwargs = expr.kwargs
  148. orig_weight = expr.named_args["weight"]
  149. astype_expr = CallMethod(orig_weight, "astype")
  150. oup = TensorNode(
  151. astype_expr,
  152. shape=orig_weight.shape,
  153. dtype=orig_weight.dtype,
  154. qparams=orig_weight.qparams,
  155. )
  156. astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
  157. astype_expr.return_val = (oup,)
  158. expr.set_arg("weight", oup)
  159. obj = pickle.dumps(traced_module)
  160. new_module = pickle.loads(obj)
  161. _check_expr_users(new_module)
  162. _check_id(new_module)
  163. result = new_module(x, y)
  164. gt = m(x, y)
  165. assert (
  166. isinstance(new_module.graph._exprs[0], CallMethod)
  167. and len(new_module.graph._exprs) == 2
  168. )
  169. np.testing.assert_equal(result.numpy(), gt.numpy())
  170. S.FUNCTIONAL_LOADER = orig_loader_dict
  171. def test_tensor_method_loader():
  172. class MyModule3(Module):
  173. def forward(self, x):
  174. return x + 1
  175. m = MyModule3()
  176. x = Tensor(np.ones((20)))
  177. traced_module = trace_module(m, x)
  178. orig_loader_dict = S.TENSORMETHOD_LOADER
  179. S.TENSORMETHOD_LOADER = {}
  180. @register_tensor_method_loader("__add__")
  181. def add_loader(expr):
  182. args = list(expr.args)
  183. if not isinstance(args[1], TensorNode):
  184. args[1] = Tensor(args[1])
  185. node = Constant(args[1], "const").outputs[0]
  186. astype_expr = CallMethod(node, "astype")
  187. oup = TensorNode(
  188. astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
  189. )
  190. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  191. astype_expr.return_val = (oup,)
  192. add_expr = CallMethod(oup, "__add__")
  193. add_expr.set_args_kwargs(oup, oup)
  194. oup1 = TensorNode(
  195. add_expr, shape=oup.shape, dtype=oup.dtype, qparams=node.qparams,
  196. )
  197. add_expr.return_val = oup1
  198. args[1] = oup1
  199. expr.set_args_kwargs(*args)
  200. obj = pickle.dumps(traced_module)
  201. new_module = pickle.loads(obj)
  202. _check_expr_users(new_module)
  203. _check_id(new_module)
  204. result = new_module(x)
  205. gt = m(x)
  206. assert (
  207. isinstance(new_module.graph._exprs[0], Constant)
  208. and len(new_module.graph._exprs) == 4
  209. )
  210. np.testing.assert_equal(result.numpy(), (x + 2).numpy())
  211. S.TENSORMETHOD_LOADER = orig_loader_dict
  212. def test_module_loader():
  213. class MyModule4(Module):
  214. def __init__(self):
  215. super().__init__()
  216. self.conv = M.Conv2d(3, 3, 3)
  217. def forward(self, x):
  218. return self.conv(x)
  219. m = MyModule4()
  220. x = Tensor(np.random.random((1, 3, 32, 32)))
  221. traced_module = trace_module(m, x)
  222. orig_loader_dict = S.MODULE_LOADER
  223. S.MODULE_LOADER = {}
  224. @register_module_loader(("megengine.module.conv", "Conv2d"))
  225. def conv2dm_loader(expr):
  226. module = expr.inputs[0].owner
  227. args = list(expr.args)
  228. orig_inp = args[1]
  229. astype_expr = CallMethod(orig_inp, "astype")
  230. oup = TensorNode(
  231. astype_expr,
  232. shape=orig_inp.shape,
  233. dtype=orig_inp.dtype,
  234. qparams=orig_inp.qparams,
  235. )
  236. astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
  237. astype_expr.return_val = (oup,)
  238. args[1] = oup
  239. expr.set_args_kwargs(*args)
  240. obj = pickle.dumps(traced_module)
  241. new_module = pickle.loads(obj)
  242. result = new_module(x)
  243. gt = m(x)
  244. assert (
  245. isinstance(new_module.graph._exprs[1], CallMethod)
  246. and len(new_module.graph._exprs) == 3
  247. )
  248. np.testing.assert_equal(result.numpy(), gt.numpy())
  249. S.MODULE_LOADER = orig_loader_dict
  250. def test_shared_module():
  251. class MyModule(M.Module):
  252. def __init__(self):
  253. super().__init__()
  254. self.a = M.Elemwise("ADD")
  255. self.b = self.a
  256. def forward(self, x, y):
  257. z = self.a(x, y)
  258. z = self.b(z, y)
  259. return z
  260. x = Tensor(1)
  261. y = Tensor(2)
  262. m = MyModule()
  263. tm = trace_module(m, x, y)
  264. obj = pickle.dumps(tm)
  265. load_tm = pickle.loads(obj)
  266. _check_expr_users(load_tm)
  267. _check_name(load_tm.flatten())
  268. _check_id(load_tm)
  269. assert load_tm.a is load_tm.b
  270. def test_convert_kwargs_to_args():
  271. def func(a, b, c=4, *, d, e=3, f=4):
  272. pass
  273. args = (1,)
  274. kwargs = {"b": 1, "d": 6}
  275. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs)
  276. assert new_args == (1, 1, 4)
  277. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  278. args = (1,)
  279. kwargs = {"d": 6}
  280. new_args, new_kwargs = _convert_kwargs_to_args(func, args, kwargs, is_bounded=True)
  281. assert new_args == (1, 4)
  282. assert new_kwargs == {"d": 6, "e": 3, "f": 4}
  283. def func1(a, b, c, d, e, *, f):
  284. pass
  285. args = ()
  286. kwargs = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6}
  287. new_args, new_kwargs = _convert_kwargs_to_args(func1, args, kwargs)
  288. assert new_args == (1, 2, 3, 4, 5)
  289. assert new_kwargs == {"f": 6}
  290. def test_opdef_serialization():
  291. with TemporaryFile() as f:
  292. x = builtin.Elemwise(mode="Add")
  293. pickle.dump(x, f)
  294. f.seek(0)
  295. load_x = pickle.load(f)
  296. assert x == load_x
  297. with TemporaryFile() as f:
  298. x = builtin.Convolution(stride_h=9, compute_mode="float32")
  299. x.strategy = (
  300. builtin.Convolution.Strategy.PROFILE
  301. | builtin.Convolution.Strategy.HEURISTIC
  302. | builtin.Convolution.Strategy.REPRODUCIBLE
  303. )
  304. pickle.dump(x, f)
  305. f.seek(0)
  306. load_x = pickle.load(f)
  307. assert x.strategy == load_x.strategy
  308. assert x == load_x

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台