You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cgtools.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine.core.ops.builtin import Elemwise
  16. from megengine.core.tensor import megbrain_graph as mgb_graph
  17. from megengine.core.tensor.megbrain_graph import apply_normal_varnode
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. from megengine.utils.network import Network
  def make_dev_tensor(value, dtype=None, device=None):
      """Wrap ``value`` in a MegEngine tensor and return its device tensor handle.

      ``dtype`` and ``device`` are forwarded unchanged to ``megengine.tensor``;
      the result is whatever ``Tensor._dev_tensor()`` returns.
      """
      wrapped = megengine.tensor(value, dtype=dtype, device=device)
      return wrapped._dev_tensor()
  def test_replace_vars():
      """``replace_vars`` should rewire consumers of a var onto its replacement.

      Builds ``(a + a) * c + a`` and substitutes the constant ``c`` with the
      ``a + a`` var, so for ``a = 5`` the rewritten graph yields
      ``(5 + 5) * (5 + 5) + 5``.
      """
      device, dtype = "xpux", np.float32
      graph = mgb_graph.Graph()
      graph.options.async_exec_level = 0b100
      inp = mgb_graph.InputNode(device=device, dtype=dtype, graph=graph)
      scale = graph.make_const(1.234, device=device)

      add = Elemwise(Elemwise.Mode.ADD)
      mul = Elemwise(Elemwise.Mode.MUL)
      doubled = apply_normal_varnode(add, inp.outputs[0], inp.outputs[0])[0]
      scaled = apply_normal_varnode(mul, doubled, scale)[0]
      total = apply_normal_varnode(add, scaled, inp.outputs[0])[0]

      # Swap the constant for ``a + a`` in the subgraph rooted at ``total``.
      (rewired,) = cgtools.replace_vars([total._node], {scale._node: doubled._node})
      sink = mgb_graph.OutputNode(mgb_graph.VarNode(rewired))
      compiled = graph.compile(sink.outputs[0])

      # NOTE(review): execute() is issued before the input value is provided.
      # This presumably relies on the async exec level configured above so that
      # execution waits for ``set_value``; keep this ordering.
      compiled.execute()
      inp.set_value(make_dev_tensor(5.0, device=device))
      result = sink.get_value().numpy()
      np.testing.assert_equal(result, np.array([(5.0 + 5.0) * (5.0 + 5.0) + 5.0]))
  def test_replace_oprs():
      """``replace_oprs`` should substitute one operator for another.

      In ``(a + a) * 1.25`` the operator producing ``a + a`` is replaced by the
      operator producing ``a * a``, so for ``a = 5`` the result is
      ``5 * 5 * 1.25``.
      """
      device, dtype = "xpux", np.float32
      graph = mgb_graph.Graph()
      graph.options.async_exec_level = 0b100
      inp = mgb_graph.InputNode(device=device, dtype=dtype, graph=graph)
      scale = graph.make_const(1.25, device=device)

      add = Elemwise(Elemwise.Mode.ADD)
      mul = Elemwise(Elemwise.Mode.MUL)
      doubled = apply_normal_varnode(add, inp.outputs[0], inp.outputs[0])[0]
      replaced_opr = doubled.op
      scaled = apply_normal_varnode(mul, doubled, scale)[0]
      squared = apply_normal_varnode(mul, inp.outputs[0], inp.outputs[0])[0]
      replacement_opr = squared.op

      # Map the ``a + a`` operator onto the ``a * a`` operator.
      opr_map = {replaced_opr._node: replacement_opr._node}
      (rewired,) = cgtools.replace_oprs([scaled._node], opr_map)
      sink = mgb_graph.OutputNode(mgb_graph.VarNode(rewired))
      compiled = graph.compile(sink.outputs[0])

      # NOTE(review): as in the var-replacement test, execution is started
      # before the input is set; presumably intentional given the async exec level.
      compiled.execute()
      inp.set_value(make_dev_tensor(5.0, device=device))
      result = sink.get_value().numpy()
      np.testing.assert_equal(result, np.array([5.0 * 5.0 * 1.25]))
  def test_graph_traversal():
      """``graph_traversal`` should record which input slot each var feeds.

      A traced ``Conv2d`` is dumped and reloaded. The var stored under key 1 in
      ``map_vars`` is expected to be consumed at input index 0 by its first
      reader.
      """
      conv = M.Conv2d(3, 32, 3)

      @trace(symbolic=True, capture_as_const=True)
      def fun(data):
          return conv(data)

      batch = np.random.random([1, 3, 224, 224]).astype(np.float32)
      # Run the traced function several times before dumping it.
      for _ in range(3):
          fun(megengine.tensor(batch))

      buf = io.BytesIO()
      fun.dump(buf, optimize_for_inference=False)
      buf.seek(0)
      outputs = mgb_graph.load_graph(buf).output_vars_list

      _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
      first_reader = var2oprs[map_vars[1].id][0]
      _, slot = first_reader
      assert slot == 0
  def test_load_refcnt():
      """A var from a loaded graph must remain usable after other refs are dropped.

      The graph is dumped and reloaded. Every Python reference to the load
      result and to the loaded graph is then deleted, and the surviving output
      var is accessed.
      """
      graph = mgb_graph.Graph()
      varnode = graph.make_const(0)
      buf, _ = mgb_graph.dump_graph([varnode])
      ret = mgb_graph.load_graph(io.BytesIO(buf))
      # Rebind ``graph``/``varnode`` to the *loaded* graph and its single output.
      graph, (varnode,) = ret.graph, ret.output_vars_list
      # Drop every other reference, so only ``varnode`` could keep the graph alive.
      del ret
      del graph
      # The result is intentionally unused: the test passes if this access does
      # not fail. Presumably it would fail if the graph had been freed.
      varnode.owner
  def test_get_opr_seq():
      """``get_oprs_seq`` should return different operator counts per flag value.

      A module reshapes a 16-element parameter to ``(N, N)``, where ``N`` is the
      leading dim of the input (4 here), and adds it to the input. After a
      dump/reload, the sequence has 5 operators with the second positional
      argument ``True`` and 6 with ``False``. That argument is presumably a
      pruning switch; confirm against ``cgtools.get_oprs_seq``.
      """

      class Net(M.Module):
          def __init__(self):
              super().__init__()
              self.data = megengine.tensor(
                  np.random.random((1, 1, 4, 4)), dtype=np.float32
              )

          # ``inp`` instead of ``input`` to avoid shadowing the builtin; the
          # argument is always passed positionally.
          def forward(self, inp):
              n = inp.shape[0]
              shape = astensor1d((n, n), self.data, dtype="int32", device=inp.device)
              x = F.reshape(self.data, shape)
              return inp + x

      net = Net()
      data = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)

      @trace(symbolic=True, capture_as_const=True)
      def func(inp, *, net=None):
          return net(inp)

      func(data, net=net)
      file = io.BytesIO()
      func.dump(file, optimize_for_inference=False)
      file.seek(0)
      outputs = mgb_graph.load_graph(file).output_vars_list

      seq_1 = cgtools.get_oprs_seq(outputs, True)
      assert len(seq_1) == 5
      seq_2 = cgtools.get_oprs_seq(outputs, False)
      assert len(seq_2) == 6
  def test_topological_sort():
      """``Network.all_oprs`` should list the dumped operators in topological order.

      The trace is dumped with operator names and priorities kept, then
      reloaded. The operator sequence is compared with the expected order:
      Elemwise operators by mode, all others by type.
      """

      @trace(symbolic=True, capture_as_const=True)
      def func(x, y):
          a = x + y
          a1 = F.relu(a)
          a2 = F.abs(a)
          a3 = F.ceil(a) * 2
          a4 = F.floor(a)
          r = a1 - a2
          r1 = a3 / a4
          return r, r1

      file = io.BytesIO()
      func(megengine.tensor(1.0), megengine.tensor(2.0))
      func.dump(
          file, optimize_for_inference=False, keep_opr_name=True, keep_opr_priority=True
      )
      file.seek(0)
      g = Network.load(file)
      oprseq1 = g.all_oprs
      gt = [
          "Host2DeviceCopy",
          "Host2DeviceCopy",
          "ADD",
          "RELU",
          "ABS",
          "CEIL",
          "ImmutableTensor",
          "MUL",
          "FLOOR",
          "SUB",
          "TRUE_DIV",
      ]
      # Compare the whole sequence, not a ``zip`` of the two lists: ``zip``
      # stops at the shorter list, so a missing, extra, or empty result would
      # otherwise pass unnoticed.
      actual = [
          op.params["mode"] if op.type == "Elemwise" else op.type for op in oprseq1
      ]
      assert actual == gt
  def test_graph_function():
      """``GraphInference`` should reproduce the traced outputs for each input style.

      Inputs can be passed by name, positionally, or mixed. Building the
      inference graph with ``outputs=["x"]`` leaves ``y`` out of the results.
      """

      class Net(M.Module):
          def forward(self, a, b):
              return a - b, a * b

      net = Net()

      @trace(symbolic=True, capture_as_const=True)
      def function(a, b, *, net=None):
          return net(a, b)

      a = np.array([1, 2, 3])
      b = np.array([3])
      x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)

      file = io.BytesIO()
      function.dump(
          file,
          arg_names=["a", "b"],
          output_names=["x", "y"],
          optimize_for_inference=False,
      )
      file.seek(0)
      graph = cgtools.GraphInference(file)

      # Named inputs, positional + named, and purely positional inputs.
      run_variants = (
          lambda: graph.run(inp_dict={"a": a, "b": b}),
          lambda: graph.run(a, inp_dict={"b": b}),
          lambda: graph.run(a, b),
      )
      for run in run_variants:
          results = run()
          np.testing.assert_equal(x.numpy(), results["x"])
          np.testing.assert_equal(y.numpy(), results["y"])

      # Reload, requesting only ``x``; ``y`` must not appear in the results.
      file.seek(0)
      graph1 = cgtools.GraphInference(file, outputs=["x"])
      results = graph1.run(inp_dict={"a": a, "b": b})
      np.testing.assert_equal(x.numpy(), results["x"])
      assert "y" not in results

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。