You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cgtools.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine.core.ops.builtin import Elemwise
  16. from megengine.core.tensor import megbrain_graph as mgb_graph
  17. from megengine.core.tensor.megbrain_graph import apply_normal_varnode
  18. from megengine.core.tensor.utils import astensor1d
  19. from megengine.jit import trace
  20. def make_dev_tensor(value, dtype=None, device=None):
  21. return megengine.tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_replace_vars():
    """Check that ``cgtools.replace_vars`` substitutes one var for another.

    Builds ``rst = (a + a) * const + a``, asks ``replace_vars`` to use
    ``a + a`` wherever ``const`` was used, then feeds ``a = 5`` and expects
    ``(5 + 5) * (5 + 5) + 5 == 105``.
    """
    g = mgb_graph.Graph()
    # NOTE(review): presumably selects asynchronous execution, which is what
    # lets ``func.execute()`` below be called before the input is fed — confirm.
    g.options.async_exec_level = 0b100
    device = "xpux"
    dtype = np.float32
    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    const = g.make_const(1.234, device=device)
    add_op = Elemwise(Elemwise.Mode.ADD)
    mul_op = Elemwise(Elemwise.Mode.MUL)
    a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
    a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
    rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0]
    # Rewrite the graph behind ``rst`` so uses of ``const`` become ``a_plus_a``;
    # one var is passed in, so exactly one rewritten var is unpacked.
    (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
    # Wrap the rewritten raw node and attach an output node to read its value.
    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
    func = g.compile(out.outputs[0])
    func.execute()
    x = make_dev_tensor(5.0, device=device)
    a.set_value(x)
    res = out.get_value().numpy()
    # 105 only holds if const (1.234) was fully replaced by a + a.
    np.testing.assert_equal(res, np.array([105.0]))
def test_replace_oprs():
    """Check that ``cgtools.replace_oprs`` swaps one operator for another.

    Builds ``(a + a) * const``, replaces the operator producing ``a + a`` with
    the operator producing ``a * a``, then feeds ``a = 5`` and expects
    ``5 * 5 * 1.25``.
    """
    g = mgb_graph.Graph()
    # NOTE(review): presumably selects asynchronous execution, which is what
    # lets ``func.execute()`` below be called before the input is fed — confirm.
    g.options.async_exec_level = 0b100
    device = "xpux"
    dtype = np.float32
    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    const = g.make_const(1.25, device=device)
    add_op = Elemwise(Elemwise.Mode.ADD)
    mul_op = Elemwise(Elemwise.Mode.MUL)
    a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
    # Operator to be replaced: the ADD producing a + a.
    old_opr = a_plus_a.op
    a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
    a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0]
    # Replacement operator: the MUL producing a * a.
    new_opr = a_mul_a.op
    # One var is passed in, so exactly one rewritten var is unpacked.
    (new,) = cgtools.replace_oprs(
        [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
    )
    # Wrap the rewritten raw node and attach an output node to read its value.
    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
    func = g.compile(out.outputs[0])
    func.execute()
    x = make_dev_tensor(5.0, device=device)
    a.set_value(x)
    res = out.get_value().numpy()
    # With a + a (=10) the result would be 12.5; a * a (=25) gives 31.25.
    np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
  66. def test_graph_traversal():
  67. net = M.Conv2d(3, 32, 3)
  68. @trace(symbolic=True, capture_as_const=True)
  69. def fun(data):
  70. x = net(data)
  71. return x
  72. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  73. for _ in range(3):
  74. fun(megengine.tensor(data))
  75. file = io.BytesIO()
  76. fun.dump(file, optimize_for_inference=False)
  77. file.seek(0)
  78. cg, _, outputs = mgb_graph.load_graph(file)
  79. _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
  80. input_var = map_vars[1]
  81. _, var_idx = var2oprs[input_var.id][0]
  82. assert var_idx == 0
def test_load_refcnt():
    """Check that a loaded var can still be accessed after its graph is deleted.

    A constant is dumped and reloaded; the loaded graph object is then
    deleted and ``varnode.owner`` is accessed.  The test passes as long as
    that access does not crash.
    """
    graph = mgb_graph.Graph()
    varnode = graph.make_const(0)
    buf, _ = mgb_graph.dump_graph([varnode])
    # Rebinding ``graph`` and ``varnode`` leaves these names pointing only at
    # the freshly loaded objects.
    graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
    del graph
    # The value is intentionally discarded; only the attribute access matters.
    varnode.owner
  90. def test_get_opr_seq():
  91. class Net(M.Module):
  92. def __init__(self):
  93. super().__init__()
  94. self.data = megengine.tensor(
  95. np.random.random((1, 1, 4, 4)), dtype=np.float32
  96. )
  97. def forward(self, input):
  98. A = input.shape[0]
  99. shape = astensor1d((A, A), self.data, dtype="int32", device=input.device)
  100. x = F.reshape(self.data, shape)
  101. o = input + x
  102. return o
  103. net = Net()
  104. input = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)
  105. @trace(symbolic=True, capture_as_const=True)
  106. def func(inp, *, net=None):
  107. return net(inp)
  108. func(input, net=net)
  109. file = io.BytesIO()
  110. func.dump(file, optimize_for_inference=False)
  111. file.seek(0)
  112. *_, outputs = mgb_graph.load_graph(file)
  113. seq_1 = cgtools.get_oprs_seq(outputs, True)
  114. assert len(seq_1) == 5
  115. seq_2 = cgtools.get_oprs_seq(outputs, False)
  116. assert len(seq_2) == 6
  117. def test_graph_function():
  118. class Net(M.Module):
  119. def forward(self, a, b):
  120. return a - b, a * b
  121. net = Net()
  122. @trace(symbolic=True, capture_as_const=True)
  123. def function(a, b, *, net=None):
  124. return net(a, b)
  125. a = np.array([1, 2, 3])
  126. b = np.array([3])
  127. x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)
  128. file = io.BytesIO()
  129. function.dump(
  130. file,
  131. arg_names=["a", "b"],
  132. output_names=["x", "y"],
  133. optimize_for_inference=False,
  134. )
  135. file.seek(0)
  136. graph = cgtools.GraphInference(file)
  137. results = graph.run(inp_dict={"a": a, "b": b})
  138. np.testing.assert_equal(x.numpy(), results["x"])
  139. np.testing.assert_equal(y.numpy(), results["y"])
  140. results = graph.run(a, inp_dict={"b": b})
  141. np.testing.assert_equal(x.numpy(), results["x"])
  142. np.testing.assert_equal(y.numpy(), results["y"])
  143. results = graph.run(a, b)
  144. np.testing.assert_equal(x.numpy(), results["x"])
  145. np.testing.assert_equal(y.numpy(), results["y"])
  146. file.seek(0)
  147. graph1 = cgtools.GraphInference(file, outputs=["x"])
  148. results = graph1.run(inp_dict={"a": a, "b": b})
  149. np.testing.assert_equal(x.numpy(), results["x"])
  150. assert "y" not in results

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。