You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_network.py 9.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import io
  2. import numpy as np
  3. import megengine.core.tensor.megbrain_graph as G
  4. import megengine.functional as F
  5. import megengine.module as M
  6. import megengine.utils.network_node as N
  7. from megengine.jit.tracing import trace
  8. from megengine.tensor import Tensor
  9. from megengine.utils.comp_graph_tools import GraphInference
  10. from megengine.utils.network import Network as Net
  11. from megengine.utils.network import as_oprnode
  12. from megengine.utils.network_node import Host2DeviceCopy, VarNode
  13. def test_replace_var():
  14. a = Tensor([1, 2])
  15. b = Tensor([3, 4])
  16. @trace(symbolic=True, capture_as_const=True)
  17. def fwd(a, b):
  18. return (a + b) * 2
  19. fwd(a, b)
  20. orig_model = io.BytesIO()
  21. fwd.dump(
  22. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  23. )
  24. orig_model.seek(0)
  25. graph = Net.load(orig_model)
  26. vara = graph.var_filter.name("a").as_unique()
  27. varb = graph.var_filter.name("b").as_unique()
  28. out = F.mul(vara.var, varb.var)
  29. out = F.relu(out)
  30. var_list = graph.add_dep_oprs(out)
  31. opnode = list(graph.opr_filter.has_input(vara))
  32. repl_dict = {opnode[0].outputs[0]: var_list[0]}
  33. graph.replace_vars(repl_dict)
  34. modified_model = io.BytesIO()
  35. graph.dump(modified_model)
  36. modified_model.seek(0)
  37. load_graph = GraphInference(modified_model)
  38. out = load_graph.run(a, b)
  39. np.testing.assert_equal(out["o"], [6, 16])
  40. def test_replace_opr():
  41. a = Tensor([1, 2])
  42. b = Tensor([3, 4])
  43. @trace(symbolic=True, capture_as_const=True)
  44. def fwd(a, b):
  45. return (a + b) * 2
  46. fwd(a, b)
  47. orig_model = io.BytesIO()
  48. fwd.dump(
  49. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  50. )
  51. orig_model.seek(0)
  52. graph = Net.load(orig_model)
  53. vara = graph.var_filter.name("a").as_unique()
  54. varb = graph.var_filter.name("b").as_unique()
  55. out1 = F.sub(vara.var, varb.var)
  56. out1 = F.relu(out1)
  57. var_list = graph.add_dep_oprs(out1)
  58. repl_opr = as_oprnode(var_list)
  59. orig_opr = graph.opr_filter.has_input(vara).as_unique()
  60. repl_dict = {orig_opr: repl_opr}
  61. graph.replace_oprs(repl_dict)
  62. modified_model1 = io.BytesIO()
  63. graph.dump(modified_model1)
  64. modified_model1.seek(0)
  65. load_graph = GraphInference(modified_model1)
  66. out = load_graph.run(a, b)
  67. np.testing.assert_equal(out["o"], [0, 0])
  68. def test_modify_params():
  69. a = Tensor([1, 2])
  70. b = Tensor([3, 4])
  71. @trace(symbolic=True, capture_as_const=True)
  72. def fwd(a, b):
  73. return (a + b) * 2
  74. fwd(a, b)
  75. orig_model = io.BytesIO()
  76. fwd.dump(
  77. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  78. )
  79. orig_model.seek(0)
  80. graph = Net.load(orig_model)
  81. param_const = graph.params_filter.as_unique()
  82. param_const.set_value(3)
  83. modified_model = io.BytesIO()
  84. graph.dump(modified_model)
  85. modified_model.seek(0)
  86. load_graph = GraphInference(modified_model)
  87. out = load_graph.run(a, b)
  88. np.testing.assert_equal(out["o"], [12, 18])
  89. def test_make_const():
  90. a = Tensor([1, 2])
  91. b = Tensor([3, 4])
  92. @trace(symbolic=True, capture_as_const=True)
  93. def fwd(a, b):
  94. return (a + b) * 2
  95. fwd(a, b)
  96. orig_model = io.BytesIO()
  97. fwd.dump(
  98. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  99. )
  100. orig_model.seek(0)
  101. graph = Net.load(orig_model)
  102. const_b = graph.make_const(np.array([0.0, 0.0]), name="b")
  103. varb = graph.var_filter.name("b").as_unique()
  104. repl_dict = {varb: const_b}
  105. graph.replace_vars(repl_dict)
  106. modified_model = io.BytesIO()
  107. graph.dump(modified_model)
  108. modified_model.seek(0)
  109. load_graph = GraphInference(modified_model)
  110. out = load_graph.run(a)
  111. np.testing.assert_equal(out["o"], [2, 4])
  112. def test_add_input():
  113. a = Tensor([1, 2])
  114. b = Tensor([3, 4])
  115. @trace(symbolic=True, capture_as_const=True)
  116. def fwd(a, b):
  117. return (a + b) * 2
  118. fwd(a, b)
  119. orig_model = io.BytesIO()
  120. fwd.dump(
  121. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  122. )
  123. orig_model.seek(0)
  124. graph = Net.load(orig_model)
  125. inp_c = graph.make_input_node((2,), np.int32, name="c")
  126. varo = graph.var_filter.name("o").as_unique()
  127. out = F.add(varo.var, inp_c.var)
  128. out = graph.add_dep_oprs(out)[0]
  129. out.name = "o1"
  130. graph.remove_output(varo)
  131. graph.add_output(out)
  132. modified_model = io.BytesIO()
  133. graph.dump(modified_model)
  134. modified_model.seek(0)
  135. load_graph = GraphInference(modified_model)
  136. out = load_graph.run(a, b, a)
  137. np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy())
  138. def test_add_output():
  139. a = Tensor([1.0, 2.0])
  140. b = Tensor([3.0, 4.0])
  141. @trace(symbolic=True, capture_as_const=True)
  142. def fwd(a, b):
  143. return (a + b) * 2
  144. fwd(a, b)
  145. orig_model = io.BytesIO()
  146. fwd.dump(
  147. orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
  148. )
  149. orig_model.seek(0)
  150. net = Net.load(orig_model)
  151. var_a = net.var_filter.name("a").as_unique()
  152. var_b = net.var_filter.name("b").as_unique()
  153. y = F.add(var_a.var, var_b.var)
  154. y = F.sigmoid(y)
  155. new_vars = net.add_dep_oprs(y)[0]
  156. new_vars.name = "o1"
  157. net.add_output(new_vars)
  158. modified_model = io.BytesIO()
  159. net.dump(modified_model)
  160. modified_model.seek(0)
  161. g = GraphInference(modified_model)
  162. out = g.run(a.numpy(), b.numpy())
  163. np.testing.assert_equal(out["o"], ((a + b) * 2).numpy())
  164. np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy())
  165. def test_query():
  166. class Model(M.Module):
  167. def __init__(self):
  168. super().__init__()
  169. self.conv1 = M.Conv2d(3, 32, 3)
  170. self.conv2 = M.Conv2d(32, 32, 3)
  171. self.conv3 = M.Conv2d(32, 32, 3)
  172. def forward(self, data):
  173. x = self.conv1(data)
  174. x = self.conv2(x)
  175. x = self.conv3(x)
  176. return x
  177. n = Model()
  178. @trace(symbolic=True, capture_as_const=True)
  179. def fwd(data):
  180. return n(data)
  181. fwd(Tensor(np.random.random((1, 3, 224, 224))))
  182. orig_model = io.BytesIO()
  183. fwd.dump(
  184. orig_model,
  185. arg_names=["data"],
  186. output_names="o",
  187. keep_opr_name=True,
  188. keep_var_name=True,
  189. optimize_for_inference=False,
  190. )
  191. orig_model.seek(0)
  192. graph = Net.load(orig_model)
  193. r = graph.data_providers_filter.as_count()
  194. assert r == 1
  195. opr = graph.get_opr_by_type(Host2DeviceCopy)
  196. assert isinstance(opr, Host2DeviceCopy)
  197. r1 = graph.params_filter.as_count()
  198. assert r1 == 6
  199. r2 = graph.opr_filter.type(N.ConvolutionForward).as_count()
  200. assert r2 == 3
  201. r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count()
  202. assert r3 == len(graph.all_oprs) - r2
  203. var = graph.var_filter.name("data").as_unique()
  204. r4 = graph.opr_filter.has_input(var).as_count()
  205. assert r4 == 1
  206. r5 = graph.opr_filter.name("data").as_count()
  207. assert r5 == 1
  208. opr = graph.get_opr_by_name("data")
  209. assert isinstance(opr, Host2DeviceCopy)
  210. var = graph.get_var_by_name("data")
  211. assert isinstance(var, VarNode)
  212. r6 = graph.var_filter.name("*bias").as_count()
  213. assert r6 == 3
  214. def test_optimize_for_inference():
  215. @trace(symbolic=True, capture_as_const=True)
  216. def f(x):
  217. return F.exp(x)
  218. orig_model = io.BytesIO()
  219. f(Tensor(5.0))
  220. f.dump(orig_model, optimize_for_inference=False)
  221. orig_model.seek(0)
  222. optimize_model = io.BytesIO()
  223. net = Net.load(orig_model)
  224. net.dump(optimize_model, enable_io16xc32=True)
  225. optimize_model.seek(0)
  226. res = G.load_graph(optimize_model)
  227. computing_input = res.output_vars_list[0].owner.inputs[0]
  228. assert computing_input.dtype == np.float16
  229. def test_reset_batchsize():
  230. @trace(symbolic=True, capture_as_const=True)
  231. def f(x):
  232. return F.exp(x)
  233. orig_model = io.BytesIO()
  234. f(Tensor(np.random.random((3, 3, 224, 224))))
  235. f.dump(orig_model, optimize_for_inference=False)
  236. orig_model.seek(0)
  237. modified_model = io.BytesIO()
  238. net = Net.load(orig_model)
  239. net.reset_batch_size(1)
  240. net.dump(modified_model, optimize_for_inference=False)
  241. modified_model.seek(0)
  242. net1 = Net.load(modified_model)
  243. assert net1.data_providers_filter.as_unique().shape[0] == 1
  244. def test_modify_opr_name():
  245. @trace(symbolic=True, capture_as_const=True)
  246. def f(x):
  247. return F.exp(x)
  248. orig_model = io.BytesIO()
  249. f(Tensor(np.random.random((3, 3, 224, 224))))
  250. f.dump(orig_model, arg_names=["a"], optimize_for_inference=False)
  251. orig_model.seek(0)
  252. modified_model = io.BytesIO()
  253. net = Net.load(orig_model)
  254. net.modify_opr_names("net")
  255. net.modify_opr_names(lambda x: "net1." + x)
  256. net.dump(modified_model, optimize_for_inference=False)
  257. modified_model.seek(0)
  258. net1 = Net.load(modified_model)
  259. assert net1.data_providers_filter.as_unique().name == "net1.net.a"

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台