You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_network.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. import io
  2. import numpy as np
  3. import megengine.core.tensor.megbrain_graph as G
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.utils.network_node as N
  7. from megengine.jit.tracing import trace
  8. from megengine.tensor import Tensor
  9. from megengine.utils.comp_graph_tools import GraphInference
  10. from megengine.utils.network import Network as Net
  11. from megengine.utils.network import as_oprnode, set_symbolic_shape
  12. from megengine.utils.network_node import Host2DeviceCopy, VarNode
  13. def test_metadata():
  14. x = Tensor(0)
  15. @trace(symbolic=True, capture_as_const=True)
  16. def fwd(x):
  17. return x * 2
  18. fwd(x)
  19. orig_model = io.BytesIO()
  20. fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
  21. orig_model.seek(0)
  22. graph = Net.load(orig_model)
  23. assert graph.metadata == {
  24. "user_info": "test",
  25. "graph_modified": False, # False: tracing.dump
  26. "optimized_for_inference": False,
  27. }
  28. orig_model.seek(0)
  29. graph.dump(
  30. orig_model,
  31. user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
  32. optimize_for_inference=True,
  33. enable_nchw4=True,
  34. enable_ioc16=True,
  35. )
  36. orig_model.seek(0)
  37. graph = Net.load(orig_model)
  38. assert graph.metadata == {
  39. "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
  40. "graph_modified": True, # True: Network.dump
  41. "optimized_for_inference": True,
  42. "enable_nchw4": True,
  43. "enable_ioc16": True,
  44. }
  45. orig_model.seek(0)
  46. fwd.dump(orig_model, enable_metadata=False)
  47. orig_model.seek(0)
  48. graph = Net.load(orig_model)
  49. assert graph.metadata is None
  50. def test_replace_var():
  51. a = Tensor([1, 2])
  52. b = Tensor([3, 4])
  53. @trace(symbolic=True, capture_as_const=True)
  54. def fwd(a, b):
  55. return (a + b) * 2
  56. fwd(a, b)
  57. orig_model = io.BytesIO()
  58. fwd.dump(
  59. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  60. )
  61. orig_model.seek(0)
  62. graph = Net.load(orig_model)
  63. vara = graph.var_filter.name("a").as_unique()
  64. varb = graph.var_filter.name("b").as_unique()
  65. out = F.mul(vara, varb)
  66. out = F.relu(out)
  67. opnode = list(graph.opr_filter.has_input(vara))
  68. repl_dict = {opnode[0].outputs[0]: out}
  69. graph.replace_vars(repl_dict)
  70. modified_model = io.BytesIO()
  71. graph.dump(modified_model)
  72. modified_model.seek(0)
  73. load_graph = GraphInference(modified_model)
  74. out = load_graph.run(a, b)
  75. np.testing.assert_equal(out["o"], [6, 16])
  76. def test_replace_opr():
  77. a = Tensor([1, 2])
  78. b = Tensor([3, 4])
  79. @trace(symbolic=True, capture_as_const=True)
  80. def fwd(a, b):
  81. return (a + b) * 2
  82. fwd(a, b)
  83. orig_model = io.BytesIO()
  84. fwd.dump(
  85. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  86. )
  87. orig_model.seek(0)
  88. graph = Net.load(orig_model)
  89. vara = graph.var_filter.name("a").as_unique()
  90. varb = graph.var_filter.name("b").as_unique()
  91. out1 = F.sub(vara, varb)
  92. out1 = F.relu(out1)
  93. out1 = graph.add_dep_oprs(out1)
  94. orig_opr = graph.opr_filter.has_input(vara).as_unique()
  95. repl_dict = {orig_opr: out1[0].owner}
  96. graph.replace_oprs(repl_dict)
  97. modified_model1 = io.BytesIO()
  98. graph.dump(modified_model1)
  99. modified_model1.seek(0)
  100. load_graph = GraphInference(modified_model1)
  101. out = load_graph.run(a, b)
  102. np.testing.assert_equal(out["o"], [0, 0])
  103. def test_splice_network():
  104. x = F.ones((2,))
  105. y = F.ones((2,))
  106. @trace(symbolic=True, capture_as_const=True)
  107. def fun1(a, b):
  108. return (a + b) * 2
  109. @trace(symbolic=True, capture_as_const=True)
  110. def fun2(a):
  111. return a * 2 - 1
  112. model = io.BytesIO()
  113. fun1(x, y)
  114. fun2(x)
  115. fun1.dump(
  116. model,
  117. arg_names=["net1_i0", "net1_i1"],
  118. output_names=["net1_o0"],
  119. optimize_for_inference=False,
  120. )
  121. model.seek(0)
  122. net1 = Net.load(model)
  123. model.seek(0)
  124. fun2.dump(
  125. model,
  126. arg_names=["net2_i0"],
  127. output_names=["net2_o0"],
  128. optimize_for_inference=False,
  129. )
  130. model.seek(0)
  131. net2 = Net.load(model)
  132. net1.add_output(*net2.output_vars)
  133. var = net1.var_filter.name("net1_i0").as_unique()
  134. repl_var = net2.var_filter.name("net2_o0").as_unique()
  135. net1.replace_vars({var: repl_var})
  136. assert "net1_i0" not in [var.name for var in net1.all_vars]
  137. assert "net2_i0" in [var.name for var in net1.all_vars]
  138. model.seek(0)
  139. net1.dump(model, keep_var_name=2, optimize_for_inference=False)
  140. model.seek(0)
  141. net = Net.load(model)
  142. assert "net1_i0" not in [var.name for var in net.all_vars]
  143. assert "net2_i0" in [var.name for var in net.all_vars]
  144. def test_modify_params():
  145. a = Tensor([1, 2])
  146. b = Tensor([3, 4])
  147. @trace(symbolic=True, capture_as_const=True)
  148. def fwd(a, b):
  149. return (a + b) * 2
  150. fwd(a, b)
  151. orig_model = io.BytesIO()
  152. fwd.dump(
  153. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  154. )
  155. orig_model.seek(0)
  156. graph = Net.load(orig_model)
  157. param_const = graph.params_filter.as_unique()
  158. param_const.set_value(3)
  159. modified_model = io.BytesIO()
  160. graph.dump(modified_model)
  161. modified_model.seek(0)
  162. load_graph = GraphInference(modified_model)
  163. out = load_graph.run(a, b)
  164. np.testing.assert_equal(out["o"], [12, 18])
  165. def test_make_const():
  166. a = Tensor([1, 2])
  167. b = Tensor([3, 4])
  168. @trace(symbolic=True, capture_as_const=True)
  169. def fwd(a, b):
  170. return (a + b) * 2
  171. fwd(a, b)
  172. orig_model = io.BytesIO()
  173. fwd.dump(
  174. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  175. )
  176. orig_model.seek(0)
  177. graph = Net.load(orig_model)
  178. const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
  179. varb = graph.var_filter.name("b").as_unique()
  180. repl_dict = {varb: const_b}
  181. graph.replace_vars(repl_dict)
  182. modified_model = io.BytesIO()
  183. graph.dump(modified_model)
  184. modified_model.seek(0)
  185. load_graph = GraphInference(modified_model)
  186. out = load_graph.run(a)
  187. np.testing.assert_equal(out["o"], [2, 4])
  188. def test_add_input():
  189. a = Tensor([1, 2])
  190. b = Tensor([3, 4])
  191. @trace(symbolic=True, capture_as_const=True)
  192. def fwd(a, b):
  193. return (a + b) * 2
  194. fwd(a, b)
  195. orig_model = io.BytesIO()
  196. fwd.dump(
  197. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  198. )
  199. orig_model.seek(0)
  200. graph = Net.load(orig_model)
  201. inp_c = graph.make_input_node((2,), np.int32, name="c")
  202. varo = graph.var_filter.name("o").as_unique()
  203. out = F.add(varo, inp_c)
  204. out.name = "o1"
  205. graph.remove_output(varo)
  206. graph.add_output(out)
  207. modified_model = io.BytesIO()
  208. graph.dump(modified_model)
  209. modified_model.seek(0)
  210. load_graph = GraphInference(modified_model)
  211. out = load_graph.run(a, b, a)
  212. np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
  213. def test_add_remove_output():
  214. a = Tensor([1.0, 2.0])
  215. b = Tensor([3.0, 4.0])
  216. @trace(symbolic=True, capture_as_const=True)
  217. def fwd(a, b):
  218. return (a + b) * 2, (a - b)
  219. fwd(a, b)
  220. orig_model = io.BytesIO()
  221. fwd.dump(
  222. orig_model,
  223. arg_names=["a", "b"],
  224. output_names=["o1", "o2"],
  225. optimize_for_inference=False,
  226. )
  227. orig_model.seek(0)
  228. net = Net.load(orig_model)
  229. var_a = net.var_filter.name("a").as_unique()
  230. var_b = net.var_filter.name("b").as_unique()
  231. y1 = (var_a + var_b) * 3
  232. y2 = F.sigmoid(var_a + var_b)
  233. net.remove_output(*net.output_vars)
  234. y1.name = "new_o1"
  235. y2.name = "new_o2"
  236. net.add_output(y1, y2)
  237. modified_model = io.BytesIO()
  238. net.dump(modified_model)
  239. modified_model.seek(0)
  240. g = GraphInference(modified_model)
  241. out = g.run(a.numpy(), b.numpy())
  242. np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
  243. np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
  244. def test_query():
  245. class Model(M.Module):
  246. def __init__(self):
  247. super().__init__()
  248. self.conv1 = M.Conv2d(3, 32, 3)
  249. self.conv2 = M.Conv2d(32, 32, 3)
  250. self.conv3 = M.Conv2d(32, 32, 3)
  251. def forward(self, data):
  252. x = self.conv1(data)
  253. x = self.conv2(x)
  254. x = self.conv3(x)
  255. return x
  256. n = Model()
  257. @trace(symbolic=True, capture_as_const=True)
  258. def fwd(data):
  259. return n(data)
  260. fwd(Tensor(np.random.random((1, 3, 224, 224))))
  261. orig_model = io.BytesIO()
  262. fwd.dump(
  263. orig_model,
  264. arg_names=["data"],
  265. output_names="o",
  266. keep_opr_name=True,
  267. keep_var_name=True,
  268. optimize_for_inference=False,
  269. )
  270. orig_model.seek(0)
  271. graph = Net.load(orig_model)
  272. r = graph.data_providers_filter.as_count()
  273. assert r == 1
  274. opr = graph.get_opr_by_type(Host2DeviceCopy)
  275. assert isinstance(opr, Host2DeviceCopy)
  276. r1 = graph.params_filter.as_count()
  277. assert r1 == 6
  278. r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
  279. assert r2 == 3
  280. r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
  281. assert r3 == len(graph.all_oprs) - r2
  282. var = graph.var_filter.name("data").as_unique()
  283. r4 = graph.opr_filter.has_input(var).as_count()
  284. assert r4 == 1
  285. r5 = graph.opr_filter.name("data").as_count()
  286. assert r5 == 1
  287. opr = graph.get_opr_by_name("data")
  288. assert isinstance(opr, Host2DeviceCopy)
  289. var = graph.get_var_by_name("data")
  290. assert isinstance(var, VarNode)
  291. r6 = graph.var_filter.name("*bias").as_count()
  292. assert r6 == 3
  293. def test_optimize_for_inference():
  294. @trace(symbolic=True, capture_as_const=True)
  295. def f(x):
  296. return F.exp(x)
  297. orig_model = io.BytesIO()
  298. f(Tensor(5.0))
  299. f.dump(orig_model, optimize_for_inference=False)
  300. orig_model.seek(0)
  301. optimize_model = io.BytesIO()
  302. net = Net.load(orig_model)
  303. net.dump(optimize_model, enable_io16xc32=True)
  304. optimize_model.seek(0)
  305. res = G.load_graph(optimize_model)
  306. computing_input = res.output_vars_list[0].owner.inputs[0]
  307. assert computing_input.dtype == np.float16
  308. def test_reset_batchsize():
  309. @trace(symbolic=True, capture_as_const=True)
  310. def f(x):
  311. return F.exp(x)
  312. orig_model = io.BytesIO()
  313. f(Tensor(np.random.random((3, 3, 224, 224))))
  314. f.dump(orig_model, optimize_for_inference=False)
  315. orig_model.seek(0)
  316. modified_model = io.BytesIO()
  317. net = Net.load(orig_model)
  318. net.reset_batch_size(1)
  319. net.dump(modified_model, optimize_for_inference=False)
  320. modified_model.seek(0)
  321. net1 = Net.load(modified_model)
  322. assert net1.data_providers_filter.as_unique().shape[0] == 1
  323. def test_modify_opr_name():
  324. @trace(symbolic=True, capture_as_const=True)
  325. def f(x):
  326. return F.exp(x)
  327. orig_model = io.BytesIO()
  328. f(Tensor(np.random.random((3, 3, 224, 224))))
  329. f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
  330. orig_model.seek(0)
  331. modified_model = io.BytesIO()
  332. net = Net.load(orig_model)
  333. net.modify_opr_names("net")
  334. net.modify_opr_names(lambda x: "net1." + x)
  335. net.dump(modified_model, optimize_for_inference=False)
  336. modified_model.seek(0)
  337. net1 = Net.load(modified_model)
  338. assert net1.data_providers_filter.as_unique().name == "net1.net.a"
  339. def test_dump_cond_take():
  340. a = Tensor([1.0, 2.0])
  341. @trace(symbolic=True, capture_as_const=True)
  342. def fwd(a):
  343. return F.cond_take(a > 1, a)
  344. fwd(a)
  345. orig_model = io.BytesIO()
  346. fwd.dump(
  347. orig_model,
  348. arg_names=["a"],
  349. output_names=["o1", "o2"],
  350. optimize_for_inference=False,
  351. )
  352. orig_model.seek(0)
  353. net = Net.load(orig_model)
  354. var_a = net.input_vars[0]
  355. val, idx = F.cond_take(var_a > 1, var_a)
  356. net.remove_output(*net.output_vars)
  357. val.name = "value"
  358. idx.name = "index"
  359. net.add_output(val, idx)
  360. modified_model = io.BytesIO()
  361. net.dump(modified_model)
  362. modified_model.seek(0)
  363. g = GraphInference(modified_model)
  364. out = g.run(a.numpy())
  365. data = a.numpy()
  366. mask = a.numpy() > 1
  367. np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0])
  368. np.testing.assert_equal(out["value"], data[mask])
  369. def test_set_symbolic_shape():
  370. a = Tensor([1.0, 2.0])
  371. @trace(symbolic=True, capture_as_const=True)
  372. def fwd(a):
  373. return F.relu(a * 2)
  374. fwd(a)
  375. orig_model = io.BytesIO()
  376. fwd.dump(
  377. orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
  378. )
  379. orig_model.seek(0)
  380. net = Net.load(orig_model)
  381. var_a = net.input_vars[0]
  382. saved_symbolic_shape = set_symbolic_shape(True)
  383. assert isinstance(var_a.shape, VarNode)
  384. set_symbolic_shape(False)
  385. assert var_a.shape == var_a.partial_shape
  386. set_symbolic_shape(saved_symbolic_shape)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台