You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_network.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. import io
  2. import numpy as np
  3. import megengine.core.tensor.megbrain_graph as G
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.utils.network_node as N
  7. from megengine.jit.tracing import trace
  8. from megengine.tensor import Tensor
  9. from megengine.utils.comp_graph_tools import GraphInference
  10. from megengine.utils.network import Network as Net
  11. from megengine.utils.network import as_oprnode, set_symbolic_shape
  12. from megengine.utils.network_node import Host2DeviceCopy, VarNode
  13. def test_metadata():
  14. x = Tensor(0)
  15. @trace(symbolic=True, capture_as_const=True)
  16. def fwd(x):
  17. return x * 2
  18. fwd(x)
  19. orig_model = io.BytesIO()
  20. fwd.dump(orig_model, user_info="test", optimize_for_inference=False)
  21. orig_model.seek(0)
  22. graph = Net.load(orig_model)
  23. assert graph.metadata == {
  24. "user_info": "test",
  25. "graph_modified": False, # False: tracing.dump
  26. "optimized_for_inference": False,
  27. }
  28. orig_model.seek(0)
  29. graph.dump(
  30. orig_model,
  31. user_info={"str": "x", "tensor": x, "module": M.Module, "none": None},
  32. optimize_for_inference=True,
  33. enable_nchw4=True,
  34. enable_ioc16=True,
  35. )
  36. orig_model.seek(0)
  37. graph = Net.load(orig_model)
  38. assert graph.metadata == {
  39. "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
  40. "graph_modified": True, # True: Network.dump
  41. "optimized_for_inference": True,
  42. "enable_fuse_grain": True,
  43. "enable_nchw4": True,
  44. "enable_ioc16": True,
  45. }
  46. orig_model.seek(0)
  47. fwd.dump(orig_model, enable_metadata=False)
  48. orig_model.seek(0)
  49. graph = Net.load(orig_model)
  50. assert graph.metadata is None
  51. def test_replace_var():
  52. a = Tensor([1, 2])
  53. b = Tensor([3, 4])
  54. @trace(symbolic=True, capture_as_const=True)
  55. def fwd(a, b):
  56. return (a + b) * 2
  57. fwd(a, b)
  58. orig_model = io.BytesIO()
  59. fwd.dump(
  60. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  61. )
  62. orig_model.seek(0)
  63. graph = Net.load(orig_model)
  64. vara = graph.var_filter.name("a").as_unique()
  65. varb = graph.var_filter.name("b").as_unique()
  66. out = F.mul(vara, varb)
  67. out = F.relu(out)
  68. opnode = list(graph.opr_filter.has_input(vara))
  69. repl_dict = {opnode[0].outputs[0]: out}
  70. graph.replace_vars(repl_dict)
  71. modified_model = io.BytesIO()
  72. graph.dump(modified_model)
  73. modified_model.seek(0)
  74. load_graph = GraphInference(modified_model)
  75. out = load_graph.run(a, b)
  76. np.testing.assert_equal(out["o"], [6, 16])
  77. def test_replace_opr():
  78. a = Tensor([1, 2])
  79. b = Tensor([3, 4])
  80. @trace(symbolic=True, capture_as_const=True)
  81. def fwd(a, b):
  82. return (a + b) * 2
  83. fwd(a, b)
  84. orig_model = io.BytesIO()
  85. fwd.dump(
  86. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  87. )
  88. orig_model.seek(0)
  89. graph = Net.load(orig_model)
  90. vara = graph.var_filter.name("a").as_unique()
  91. varb = graph.var_filter.name("b").as_unique()
  92. out1 = F.mul(vara, varb)
  93. out1 = F.relu(out1)
  94. out1 += 2
  95. out1 *= 3
  96. out1 = graph.add_dep_oprs(out1)
  97. orig_opr = graph.opr_filter.has_input(vara).as_unique()
  98. new_opr = out1[0].owner
  99. repl_dict = {orig_opr: new_opr}
  100. graph.replace_oprs(repl_dict)
  101. var_out = orig_opr.outputs
  102. for idx, node in enumerate(var_out):
  103. assert node.owner is new_opr
  104. assert node.owner.outputs[idx] is node
  105. modified_model1 = io.BytesIO()
  106. graph.dump(modified_model1)
  107. modified_model1.seek(0)
  108. load_graph = GraphInference(modified_model1)
  109. out = load_graph.run(a, b)
  110. np.testing.assert_equal(out["o"], [30, 60])
  111. def test_splice_network():
  112. x = F.ones((2,))
  113. y = F.ones((2,))
  114. @trace(symbolic=True, capture_as_const=True)
  115. def fun1(a, b):
  116. return (a + b) * 2
  117. @trace(symbolic=True, capture_as_const=True)
  118. def fun2(a):
  119. return a * 2 - 1
  120. model = io.BytesIO()
  121. fun1(x, y)
  122. fun2(x)
  123. fun1.dump(
  124. model,
  125. arg_names=["net1_i0", "net1_i1"],
  126. output_names=["net1_o0"],
  127. optimize_for_inference=False,
  128. )
  129. model.seek(0)
  130. net1 = Net.load(model)
  131. model.seek(0)
  132. fun2.dump(
  133. model,
  134. arg_names=["net2_i0"],
  135. output_names=["net2_o0"],
  136. optimize_for_inference=False,
  137. )
  138. model.seek(0)
  139. net2 = Net.load(model)
  140. net1.add_output(*net2.output_vars)
  141. var = net1.var_filter.name("net1_i0").as_unique()
  142. repl_var = net2.var_filter.name("net2_o0").as_unique()
  143. net1.replace_vars({var: repl_var})
  144. assert "net1_i0" not in [var.name for var in net1.all_vars]
  145. assert "net2_i0" in [var.name for var in net1.all_vars]
  146. model.seek(0)
  147. net1.dump(model, keep_var_name=2, optimize_for_inference=False)
  148. model.seek(0)
  149. net = Net.load(model)
  150. assert "net1_i0" not in [var.name for var in net.all_vars]
  151. assert "net2_i0" in [var.name for var in net.all_vars]
  152. def test_modify_params():
  153. a = Tensor([1, 2])
  154. b = Tensor([3, 4])
  155. @trace(symbolic=True, capture_as_const=True)
  156. def fwd(a, b):
  157. return (a + b) * 2
  158. fwd(a, b)
  159. orig_model = io.BytesIO()
  160. fwd.dump(
  161. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  162. )
  163. orig_model.seek(0)
  164. graph = Net.load(orig_model)
  165. param_const = graph.params_filter.as_unique()
  166. param_const.set_value(3)
  167. modified_model = io.BytesIO()
  168. graph.dump(modified_model)
  169. modified_model.seek(0)
  170. load_graph = GraphInference(modified_model)
  171. out = load_graph.run(a, b)
  172. np.testing.assert_equal(out["o"], [12, 18])
  173. def test_make_const():
  174. a = Tensor([1, 2])
  175. b = Tensor([3, 4])
  176. @trace(symbolic=True, capture_as_const=True)
  177. def fwd(a, b):
  178. return (a + b) * 2
  179. fwd(a, b)
  180. orig_model = io.BytesIO()
  181. fwd.dump(
  182. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  183. )
  184. orig_model.seek(0)
  185. graph = Net.load(orig_model)
  186. const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
  187. varb = graph.var_filter.name("b").as_unique()
  188. repl_dict = {varb: const_b}
  189. graph.replace_vars(repl_dict)
  190. modified_model = io.BytesIO()
  191. graph.dump(modified_model)
  192. modified_model.seek(0)
  193. load_graph = GraphInference(modified_model)
  194. out = load_graph.run(a)
  195. np.testing.assert_equal(out["o"], [2, 4])
  196. def test_add_input():
  197. a = Tensor([1, 2])
  198. b = Tensor([3, 4])
  199. @trace(symbolic=True, capture_as_const=True)
  200. def fwd(a, b):
  201. return (a + b) * 2
  202. fwd(a, b)
  203. orig_model = io.BytesIO()
  204. fwd.dump(
  205. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  206. )
  207. orig_model.seek(0)
  208. graph = Net.load(orig_model)
  209. inp_c = graph.make_input_node((2,), np.int32, name="c")
  210. varo = graph.var_filter.name("o").as_unique()
  211. out = F.add(varo, inp_c)
  212. out.name = "o1"
  213. graph.remove_output(varo)
  214. graph.add_output(out)
  215. modified_model = io.BytesIO()
  216. graph.dump(modified_model)
  217. modified_model.seek(0)
  218. load_graph = GraphInference(modified_model)
  219. out = load_graph.run(a, b, a)
  220. np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
  221. def test_add_remove_output():
  222. a = Tensor([1.0, 2.0])
  223. b = Tensor([3.0, 4.0])
  224. @trace(symbolic=True, capture_as_const=True)
  225. def fwd(a, b):
  226. return (a + b) * 2, (a - b)
  227. fwd(a, b)
  228. orig_model = io.BytesIO()
  229. fwd.dump(
  230. orig_model,
  231. arg_names=["a", "b"],
  232. output_names=["o1", "o2"],
  233. optimize_for_inference=False,
  234. )
  235. orig_model.seek(0)
  236. net = Net.load(orig_model)
  237. var_a = net.var_filter.name("a").as_unique()
  238. var_b = net.var_filter.name("b").as_unique()
  239. y1 = (var_a + var_b) * 3
  240. y2 = F.sigmoid(var_a + var_b)
  241. net.remove_output(*net.output_vars)
  242. y1.name = "new_o1"
  243. y2.name = "new_o2"
  244. net.add_output(y1, y2)
  245. modified_model = io.BytesIO()
  246. net.dump(modified_model)
  247. modified_model.seek(0)
  248. g = GraphInference(modified_model)
  249. out = g.run(a.numpy(), b.numpy())
  250. np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
  251. np.testing.assert_almost_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
  252. def test_query():
  253. class Model(M.Module):
  254. def __init__(self):
  255. super().__init__()
  256. self.conv1 = M.Conv2d(3, 32, 3)
  257. self.conv2 = M.Conv2d(32, 32, 3)
  258. self.conv3 = M.Conv2d(32, 32, 3)
  259. def forward(self, data):
  260. x = self.conv1(data)
  261. x = self.conv2(x)
  262. x = self.conv3(x)
  263. return x
  264. n = Model()
  265. @trace(symbolic=True, capture_as_const=True)
  266. def fwd(data):
  267. return n(data)
  268. fwd(Tensor(np.random.random((1, 3, 224, 224))))
  269. orig_model = io.BytesIO()
  270. fwd.dump(
  271. orig_model,
  272. arg_names=["data"],
  273. output_names="o",
  274. keep_opr_name=True,
  275. keep_var_name=True,
  276. optimize_for_inference=False,
  277. )
  278. orig_model.seek(0)
  279. graph = Net.load(orig_model)
  280. r = graph.data_providers_filter.as_count()
  281. assert r == 1
  282. opr = graph.get_opr_by_type(Host2DeviceCopy)
  283. assert isinstance(opr, Host2DeviceCopy)
  284. r1 = graph.params_filter.as_count()
  285. assert r1 == 6
  286. r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
  287. assert r2 == 3
  288. r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
  289. assert r3 == len(graph.all_oprs) - r2
  290. var = graph.var_filter.name("data").as_unique()
  291. r4 = graph.opr_filter.has_input(var).as_count()
  292. assert r4 == 1
  293. r5 = graph.opr_filter.name("data").as_count()
  294. assert r5 == 1
  295. opr = graph.get_opr_by_name("data")
  296. assert isinstance(opr, Host2DeviceCopy)
  297. var = graph.get_var_by_name("data")
  298. assert isinstance(var, VarNode)
  299. r6 = graph.var_filter.name("*bias").as_count()
  300. assert r6 == 3
  301. def test_optimize_for_inference():
  302. @trace(symbolic=True, capture_as_const=True)
  303. def f(x):
  304. return F.exp(x)
  305. orig_model = io.BytesIO()
  306. f(Tensor(5.0))
  307. f.dump(orig_model, optimize_for_inference=False)
  308. orig_model.seek(0)
  309. optimize_model = io.BytesIO()
  310. net = Net.load(orig_model)
  311. net.dump(optimize_model, enable_io16xc32=True)
  312. optimize_model.seek(0)
  313. res = G.load_graph(optimize_model)
  314. computing_input = res.output_vars_list[0].owner.inputs[0]
  315. assert computing_input.dtype == np.float16
  316. def test_reset_batchsize():
  317. @trace(symbolic=True, capture_as_const=True)
  318. def f(x):
  319. return F.exp(x)
  320. orig_model = io.BytesIO()
  321. f(Tensor(np.random.random((3, 3, 224, 224))))
  322. f.dump(orig_model, optimize_for_inference=False)
  323. orig_model.seek(0)
  324. modified_model = io.BytesIO()
  325. net = Net.load(orig_model)
  326. net.reset_batch_size(1)
  327. net.dump(modified_model, optimize_for_inference=False)
  328. modified_model.seek(0)
  329. net1 = Net.load(modified_model)
  330. assert net1.data_providers_filter.as_unique().shape[0] == 1
  331. def test_modify_opr_name():
  332. @trace(symbolic=True, capture_as_const=True)
  333. def f(x):
  334. return F.exp(x)
  335. orig_model = io.BytesIO()
  336. f(Tensor(np.random.random((3, 3, 224, 224))))
  337. f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
  338. orig_model.seek(0)
  339. modified_model = io.BytesIO()
  340. net = Net.load(orig_model)
  341. net.modify_opr_names("net")
  342. net.modify_opr_names(lambda x: "net1." + x)
  343. net.dump(modified_model, optimize_for_inference=False)
  344. modified_model.seek(0)
  345. net1 = Net.load(modified_model)
  346. assert net1.data_providers_filter.as_unique().name == "net1.net.a"
  347. def test_dump_cond_take():
  348. a = Tensor([1.0, 2.0])
  349. @trace(symbolic=True, capture_as_const=True)
  350. def fwd(a):
  351. return F.cond_take(a > 1, a)
  352. fwd(a)
  353. orig_model = io.BytesIO()
  354. fwd.dump(
  355. orig_model,
  356. arg_names=["a"],
  357. output_names=["o1", "o2"],
  358. optimize_for_inference=False,
  359. )
  360. orig_model.seek(0)
  361. net = Net.load(orig_model)
  362. var_a = net.input_vars[0]
  363. val, idx = F.cond_take(var_a > 1, var_a)
  364. net.remove_output(*net.output_vars)
  365. val.name = "value"
  366. idx.name = "index"
  367. net.add_output(val, idx)
  368. modified_model = io.BytesIO()
  369. net.dump(modified_model)
  370. modified_model.seek(0)
  371. g = GraphInference(modified_model)
  372. out = g.run(a.numpy())
  373. data = a.numpy()
  374. mask = a.numpy() > 1
  375. np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0])
  376. np.testing.assert_equal(out["value"], data[mask])
  377. def test_set_symbolic_shape():
  378. a = Tensor([1.0, 2.0])
  379. @trace(symbolic=True, capture_as_const=True)
  380. def fwd(a):
  381. return F.relu(a * 2)
  382. fwd(a)
  383. orig_model = io.BytesIO()
  384. fwd.dump(
  385. orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
  386. )
  387. orig_model.seek(0)
  388. net = Net.load(orig_model)
  389. var_a = net.input_vars[0]
  390. saved_symbolic_shape = set_symbolic_shape(True)
  391. assert isinstance(var_a.shape, VarNode)
  392. set_symbolic_shape(False)
  393. assert var_a.shape == var_a.partial_shape
  394. set_symbolic_shape(saved_symbolic_shape)
  395. def test_replace_var_in_different_network():
  396. a = Tensor([1, 2])
  397. b = Tensor([3, 4])
  398. @trace(symbolic=True, capture_as_const=True)
  399. def fwd(a, b):
  400. return (a + b) * 2
  401. @trace(symbolic=True, capture_as_const=True)
  402. def fwd1(c, d):
  403. return c + d
  404. fwd(a, b)
  405. orig_model = io.BytesIO()
  406. fwd.dump(
  407. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  408. )
  409. orig_model.seek(0)
  410. fwd1(a, b)
  411. orig_model1 = io.BytesIO()
  412. fwd1.dump(
  413. orig_model1,
  414. arg_names=["c", "d"],
  415. output_names="o",
  416. optimize_for_inference=False,
  417. )
  418. orig_model1.seek(0)
  419. graph = Net.load(orig_model)
  420. graph1 = Net.load(orig_model1)
  421. vara = graph.var_filter.name("a").as_unique()
  422. varb = graph.var_filter.name("b").as_unique()
  423. varo = graph1.var_filter.name("o").as_unique()
  424. graph.replace_vars({vara: varo, varb: varo})
  425. modified_model = io.BytesIO()
  426. graph.dump(modified_model)
  427. modified_model.seek(0)
  428. load_graph = GraphInference(modified_model)
  429. out = load_graph.run(a, b)
  430. np.testing.assert_equal(out["o"], [16, 24])