You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_network.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. import io
  2. import numpy as np
  3. import megengine.core.tensor.megbrain_graph as G
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.utils.network_node as N
  7. from megengine.jit.tracing import trace
  8. from megengine.tensor import Tensor
  9. from megengine.utils.comp_graph_tools import GraphInference
  10. from megengine.utils.network import Network as Net
  11. from megengine.utils.network import as_oprnode, set_symbolic_shape
  12. from megengine.utils.network_node import Host2DeviceCopy, VarNode
  13. def test_replace_var():
  14. a = Tensor([1, 2])
  15. b = Tensor([3, 4])
  16. @trace(symbolic=True, capture_as_const=True)
  17. def fwd(a, b):
  18. return (a + b) * 2
  19. fwd(a, b)
  20. orig_model = io.BytesIO()
  21. fwd.dump(
  22. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  23. )
  24. orig_model.seek(0)
  25. graph = Net.load(orig_model)
  26. vara = graph.var_filter.name("a").as_unique()
  27. varb = graph.var_filter.name("b").as_unique()
  28. out = F.mul(vara, varb)
  29. out = F.relu(out)
  30. opnode = list(graph.opr_filter.has_input(vara))
  31. repl_dict = {opnode[0].outputs[0]: out}
  32. graph.replace_vars(repl_dict)
  33. modified_model = io.BytesIO()
  34. graph.dump(modified_model)
  35. modified_model.seek(0)
  36. load_graph = GraphInference(modified_model)
  37. out = load_graph.run(a, b)
  38. np.testing.assert_equal(out["o"], [6, 16])
  39. def test_replace_opr():
  40. a = Tensor([1, 2])
  41. b = Tensor([3, 4])
  42. @trace(symbolic=True, capture_as_const=True)
  43. def fwd(a, b):
  44. return (a + b) * 2
  45. fwd(a, b)
  46. orig_model = io.BytesIO()
  47. fwd.dump(
  48. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  49. )
  50. orig_model.seek(0)
  51. graph = Net.load(orig_model)
  52. vara = graph.var_filter.name("a").as_unique()
  53. varb = graph.var_filter.name("b").as_unique()
  54. out1 = F.sub(vara, varb)
  55. out1 = F.relu(out1)
  56. out1 = graph.add_dep_oprs(out1)
  57. orig_opr = graph.opr_filter.has_input(vara).as_unique()
  58. repl_dict = {orig_opr: out1[0].owner}
  59. graph.replace_oprs(repl_dict)
  60. modified_model1 = io.BytesIO()
  61. graph.dump(modified_model1)
  62. modified_model1.seek(0)
  63. load_graph = GraphInference(modified_model1)
  64. out = load_graph.run(a, b)
  65. np.testing.assert_equal(out["o"], [0, 0])
  66. def test_modify_params():
  67. a = Tensor([1, 2])
  68. b = Tensor([3, 4])
  69. @trace(symbolic=True, capture_as_const=True)
  70. def fwd(a, b):
  71. return (a + b) * 2
  72. fwd(a, b)
  73. orig_model = io.BytesIO()
  74. fwd.dump(
  75. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  76. )
  77. orig_model.seek(0)
  78. graph = Net.load(orig_model)
  79. param_const = graph.params_filter.as_unique()
  80. param_const.set_value(3)
  81. modified_model = io.BytesIO()
  82. graph.dump(modified_model)
  83. modified_model.seek(0)
  84. load_graph = GraphInference(modified_model)
  85. out = load_graph.run(a, b)
  86. np.testing.assert_equal(out["o"], [12, 18])
  87. def test_make_const():
  88. a = Tensor([1, 2])
  89. b = Tensor([3, 4])
  90. @trace(symbolic=True, capture_as_const=True)
  91. def fwd(a, b):
  92. return (a + b) * 2
  93. fwd(a, b)
  94. orig_model = io.BytesIO()
  95. fwd.dump(
  96. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  97. )
  98. orig_model.seek(0)
  99. graph = Net.load(orig_model)
  100. const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
  101. varb = graph.var_filter.name("b").as_unique()
  102. repl_dict = {varb: const_b}
  103. graph.replace_vars(repl_dict)
  104. modified_model = io.BytesIO()
  105. graph.dump(modified_model)
  106. modified_model.seek(0)
  107. load_graph = GraphInference(modified_model)
  108. out = load_graph.run(a)
  109. np.testing.assert_equal(out["o"], [2, 4])
  110. def test_add_input():
  111. a = Tensor([1, 2])
  112. b = Tensor([3, 4])
  113. @trace(symbolic=True, capture_as_const=True)
  114. def fwd(a, b):
  115. return (a + b) * 2
  116. fwd(a, b)
  117. orig_model = io.BytesIO()
  118. fwd.dump(
  119. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  120. )
  121. orig_model.seek(0)
  122. graph = Net.load(orig_model)
  123. inp_c = graph.make_input_node((2,), np.int32, name="c")
  124. varo = graph.var_filter.name("o").as_unique()
  125. out = F.add(varo, inp_c)
  126. out.name = "o1"
  127. graph.remove_output(varo)
  128. graph.add_output(out)
  129. modified_model = io.BytesIO()
  130. graph.dump(modified_model)
  131. modified_model.seek(0)
  132. load_graph = GraphInference(modified_model)
  133. out = load_graph.run(a, b, a)
  134. np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
  135. def test_add_remove_output():
  136. a = Tensor([1.0, 2.0])
  137. b = Tensor([3.0, 4.0])
  138. @trace(symbolic=True, capture_as_const=True)
  139. def fwd(a, b):
  140. return (a + b) * 2, (a - b)
  141. fwd(a, b)
  142. orig_model = io.BytesIO()
  143. fwd.dump(
  144. orig_model,
  145. arg_names=["a", "b"],
  146. output_names=["o1", "o2"],
  147. optimize_for_inference=False,
  148. )
  149. orig_model.seek(0)
  150. net = Net.load(orig_model)
  151. var_a = net.var_filter.name("a").as_unique()
  152. var_b = net.var_filter.name("b").as_unique()
  153. y1 = (var_a + var_b) * 3
  154. y2 = F.sigmoid(var_a + var_b)
  155. net.remove_output(*net.output_vars)
  156. y1.name = "new_o1"
  157. y2.name = "new_o2"
  158. net.add_output(y1, y2)
  159. modified_model = io.BytesIO()
  160. net.dump(modified_model)
  161. modified_model.seek(0)
  162. g = GraphInference(modified_model)
  163. out = g.run(a.numpy(), b.numpy())
  164. np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
  165. np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
  166. def test_query():
  167. class Model(M.Module):
  168. def __init__(self):
  169. super().__init__()
  170. self.conv1 = M.Conv2d(3, 32, 3)
  171. self.conv2 = M.Conv2d(32, 32, 3)
  172. self.conv3 = M.Conv2d(32, 32, 3)
  173. def forward(self, data):
  174. x = self.conv1(data)
  175. x = self.conv2(x)
  176. x = self.conv3(x)
  177. return x
  178. n = Model()
  179. @trace(symbolic=True, capture_as_const=True)
  180. def fwd(data):
  181. return n(data)
  182. fwd(Tensor(np.random.random((1, 3, 224, 224))))
  183. orig_model = io.BytesIO()
  184. fwd.dump(
  185. orig_model,
  186. arg_names=["data"],
  187. output_names="o",
  188. keep_opr_name=True,
  189. keep_var_name=True,
  190. optimize_for_inference=False,
  191. )
  192. orig_model.seek(0)
  193. graph = Net.load(orig_model)
  194. r = graph.data_providers_filter.as_count()
  195. assert r == 1
  196. opr = graph.get_opr_by_type(Host2DeviceCopy)
  197. assert isinstance(opr, Host2DeviceCopy)
  198. r1 = graph.params_filter.as_count()
  199. assert r1 == 6
  200. r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
  201. assert r2 == 3
  202. r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
  203. assert r3 == len(graph.all_oprs) - r2
  204. var = graph.var_filter.name("data").as_unique()
  205. r4 = graph.opr_filter.has_input(var).as_count()
  206. assert r4 == 1
  207. r5 = graph.opr_filter.name("data").as_count()
  208. assert r5 == 1
  209. opr = graph.get_opr_by_name("data")
  210. assert isinstance(opr, Host2DeviceCopy)
  211. var = graph.get_var_by_name("data")
  212. assert isinstance(var, VarNode)
  213. r6 = graph.var_filter.name("*bias").as_count()
  214. assert r6 == 3
  215. def test_optimize_for_inference():
  216. @trace(symbolic=True, capture_as_const=True)
  217. def f(x):
  218. return F.exp(x)
  219. orig_model = io.BytesIO()
  220. f(Tensor(5.0))
  221. f.dump(orig_model, optimize_for_inference=False)
  222. orig_model.seek(0)
  223. optimize_model = io.BytesIO()
  224. net = Net.load(orig_model)
  225. net.dump(optimize_model, enable_io16xc32=True)
  226. optimize_model.seek(0)
  227. res = G.load_graph(optimize_model)
  228. computing_input = res.output_vars_list[0].owner.inputs[0]
  229. assert computing_input.dtype == np.float16
  230. def test_reset_batchsize():
  231. @trace(symbolic=True, capture_as_const=True)
  232. def f(x):
  233. return F.exp(x)
  234. orig_model = io.BytesIO()
  235. f(Tensor(np.random.random((3, 3, 224, 224))))
  236. f.dump(orig_model, optimize_for_inference=False)
  237. orig_model.seek(0)
  238. modified_model = io.BytesIO()
  239. net = Net.load(orig_model)
  240. net.reset_batch_size(1)
  241. net.dump(modified_model, optimize_for_inference=False)
  242. modified_model.seek(0)
  243. net1 = Net.load(modified_model)
  244. assert net1.data_providers_filter.as_unique().shape[0] == 1
  245. def test_modify_opr_name():
  246. @trace(symbolic=True, capture_as_const=True)
  247. def f(x):
  248. return F.exp(x)
  249. orig_model = io.BytesIO()
  250. f(Tensor(np.random.random((3, 3, 224, 224))))
  251. f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
  252. orig_model.seek(0)
  253. modified_model = io.BytesIO()
  254. net = Net.load(orig_model)
  255. net.modify_opr_names("net")
  256. net.modify_opr_names(lambda x: "net1." + x)
  257. net.dump(modified_model, optimize_for_inference=False)
  258. modified_model.seek(0)
  259. net1 = Net.load(modified_model)
  260. assert net1.data_providers_filter.as_unique().name == "net1.net.a"
  261. def test_dump_cond_take():
  262. a = Tensor([1.0, 2.0])
  263. @trace(symbolic=True, capture_as_const=True)
  264. def fwd(a):
  265. return F.cond_take(a > 1, a)
  266. fwd(a)
  267. orig_model = io.BytesIO()
  268. fwd.dump(
  269. orig_model,
  270. arg_names=["a"],
  271. output_names=["o1", "o2"],
  272. optimize_for_inference=False,
  273. )
  274. orig_model.seek(0)
  275. net = Net.load(orig_model)
  276. var_a = net.input_vars[0]
  277. val, idx = F.cond_take(var_a > 1, var_a)
  278. net.remove_output(*net.output_vars)
  279. val.name = "value"
  280. idx.name = "index"
  281. net.add_output(val, idx)
  282. modified_model = io.BytesIO()
  283. net.dump(modified_model)
  284. modified_model.seek(0)
  285. g = GraphInference(modified_model)
  286. out = g.run(a.numpy())
  287. data = a.numpy()
  288. mask = a.numpy() > 1
  289. np.testing.assert_equal(out["index"], np.where(mask.reshape(-1))[0])
  290. np.testing.assert_equal(out["value"], data[mask])
  291. def test_set_symbolic_shape():
  292. a = Tensor([1.0, 2.0])
  293. @trace(symbolic=True, capture_as_const=True)
  294. def fwd(a):
  295. return F.relu(a * 2)
  296. fwd(a)
  297. orig_model = io.BytesIO()
  298. fwd.dump(
  299. orig_model, arg_names=["a"], output_names=["o"], optimize_for_inference=False,
  300. )
  301. orig_model.seek(0)
  302. net = Net.load(orig_model)
  303. var_a = net.input_vars[0]
  304. saved_symbolic_shape = set_symbolic_shape(True)
  305. assert isinstance(var_a.shape, VarNode)
  306. set_symbolic_shape(False)
  307. assert var_a.shape == var_a.partial_shape
  308. set_symbolic_shape(saved_symbolic_shape)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台