You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_cgtools.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine.core.ops.builtin import Elemwise
  16. from megengine.core.tensor import megbrain_graph as mgb_graph
  17. from megengine.core.tensor.megbrain_graph import apply_normal_varnode
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. def make_dev_tensor(value, dtype=None, device=None):
  21. return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor()
  22. def test_replace_vars():
  23. g = mgb_graph.Graph()
  24. g.options.async_exec_level = 0b100
  25. device = "xpux"
  26. dtype = np.float32
  27. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  28. const = g.make_const(1.234, device=device)
  29. add_op = Elemwise(Elemwise.Mode.ADD)
  30. mul_op = Elemwise(Elemwise.Mode.MUL)
  31. a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
  32. a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
  33. rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0]
  34. (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
  35. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  36. func = g.compile(out.outputs[0])
  37. func.execute()
  38. x = make_dev_tensor(5.0, device=device)
  39. a.set_value(x)
  40. res = out.get_value().numpy()
  41. np.testing.assert_equal(res, np.array([105.0]))
  42. def test_replace_oprs():
  43. g = mgb_graph.Graph()
  44. g.options.async_exec_level = 0b100
  45. device = "xpux"
  46. dtype = np.float32
  47. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  48. const = g.make_const(1.25, device=device)
  49. add_op = Elemwise(Elemwise.Mode.ADD)
  50. mul_op = Elemwise(Elemwise.Mode.MUL)
  51. a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
  52. old_opr = a_plus_a.op
  53. a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
  54. a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0]
  55. new_opr = a_mul_a.op
  56. (new,) = cgtools.replace_oprs(
  57. [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
  58. )
  59. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  60. func = g.compile(out.outputs[0])
  61. func.execute()
  62. x = make_dev_tensor(5.0, device=device)
  63. a.set_value(x)
  64. res = out.get_value().numpy()
  65. np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
  66. def test_graph_traversal():
  67. net = M.Conv2d(3, 32, 3)
  68. @trace(symbolic=True, capture_as_const=True)
  69. def fun(data):
  70. x = net(data)
  71. return x
  72. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  73. for _ in range(3):
  74. fun(megengine.tensor(data))
  75. file = io.BytesIO()
  76. fun.dump(file, optimize_for_inference=False)
  77. file.seek(0)
  78. outputs = mgb_graph.load_graph(file).output_vars_list
  79. _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
  80. input_var = map_vars[1]
  81. _, var_idx = var2oprs[input_var.id][0]
  82. assert var_idx == 0
  83. def test_load_refcnt():
  84. graph = mgb_graph.Graph()
  85. varnode = graph.make_const(0)
  86. buf, _ = mgb_graph.dump_graph([varnode])
  87. ret = mgb_graph.load_graph(io.BytesIO(buf))
  88. graph, (varnode,) = ret.graph, ret.output_vars_list
  89. del ret
  90. del graph
  91. varnode.owner
  92. def test_get_opr_seq():
  93. class Net(M.Module):
  94. def __init__(self):
  95. super().__init__()
  96. self.data = megengine.tensor(
  97. np.random.random((1, 1, 4, 4)), dtype=np.float32
  98. )
  99. def forward(self, input):
  100. A = input.shape[0]
  101. shape = astensor1d((A, A), self.data, dtype="int32", device=input.device)
  102. x = F.reshape(self.data, shape)
  103. o = input + x
  104. return o
  105. net = Net()
  106. input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)
  107. @trace(symbolic=True, capture_as_const=True)
  108. def func(inp, *, net=None):
  109. return net(inp)
  110. func(input, net=net)
  111. file = io.BytesIO()
  112. func.dump(file, optimize_for_inference=False)
  113. file.seek(0)
  114. outputs = mgb_graph.load_graph(file).output_vars_list
  115. seq_1 = cgtools.get_oprs_seq(outputs, True)
  116. assert len(seq_1) == 5
  117. seq_2 = cgtools.get_oprs_seq(outputs, False)
  118. assert len(seq_2) == 6
  119. def test_graph_function():
  120. class Net(M.Module):
  121. def forward(self, a, b):
  122. return a - b, a * b
  123. net = Net()
  124. @trace(symbolic=True, capture_as_const=True)
  125. def function(a, b, *, net=None):
  126. return net(a, b)
  127. a = np.array([1, 2, 3])
  128. b = np.array([3])
  129. x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)
  130. file = io.BytesIO()
  131. function.dump(
  132. file,
  133. arg_names=["a", "b"],
  134. output_names=["x", "y"],
  135. optimize_for_inference=False,
  136. )
  137. file.seek(0)
  138. graph = cgtools.GraphInference(file)
  139. results = graph.run(inp_dict={"a": a, "b": b})
  140. np.testing.assert_equal(x.numpy(), results["x"])
  141. np.testing.assert_equal(y.numpy(), results["y"])
  142. results = graph.run(a, inp_dict={"b": b})
  143. np.testing.assert_equal(x.numpy(), results["x"])
  144. np.testing.assert_equal(y.numpy(), results["y"])
  145. results = graph.run(a, b)
  146. np.testing.assert_equal(x.numpy(), results["x"])
  147. np.testing.assert_equal(y.numpy(), results["y"])
  148. file.seek(0)
  149. graph1 = cgtools.GraphInference(file, outputs=["x"])
  150. results = graph1.run(inp_dict={"a": a, "b": b})
  151. np.testing.assert_equal(x.numpy(), results["x"])
  152. assert "y" not in results

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。