You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cgtools.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. import megengine.functional as F
  13. import megengine.module as M
  14. import megengine.utils.comp_graph_tools as cgtools
  15. from megengine.core.tensor import megbrain_graph as mgb_graph
  16. from megengine.core.tensor.raw_tensor import as_raw_tensor
  17. from megengine.core.tensor.utils import astensor1d
  18. from megengine.jit import trace
  def make_dev_tensor(value, dtype=None, device=None):
      """Wrap ``value`` as a raw tensor and return its device tensor.

      ``dtype`` and ``device`` are forwarded unchanged to ``as_raw_tensor``;
      the result is whatever ``_dev_tensor()`` returns on that raw tensor.
      """
      raw = as_raw_tensor(value, dtype=dtype, device=device)
      return raw._dev_tensor()
  def test_replace_vars():
      """Check ``cgtools.replace_vars`` on a small hand-built graph.

      The graph computes ``(a + a) * const + a``.  After substituting the
      ``a + a`` var for ``const`` the output is ``(a + a) * (a + a) + a``,
      which equals 105 for ``a = 5``.
      """
      graph = mgb_graph.Graph()
      graph.options.async_exec_level = 0b100
      dev = "xpux"
      inp = mgb_graph.InputNode(device=dev, dtype=np.float32, graph=graph)
      scale = graph.make_const(1.234)

      # Build (a + a) * const + a.
      doubled = F.add(inp.outputs[0], inp.outputs[0])
      scaled = F.mul(doubled, scale)
      total = F.add(scaled, inp.outputs[0])

      # Swap the constant for the (a + a) var in the graph reaching `total`.
      (replaced,) = cgtools.replace_vars([total._node], {scale._node: doubled._node})
      sink = mgb_graph.OutputNode(mgb_graph.VarNode(replaced))
      compiled = graph.compile(sink.outputs[0])

      # NOTE(review): execute() is issued before the input value is supplied;
      # this order is kept from the original test -- presumably execution
      # waits on the input node, confirm against megbrain_graph.
      compiled.execute()
      inp.set_value(make_dev_tensor(5.0, device=dev))

      actual = sink.get_value().numpy()
      expected = np.array([105.0])
      np.testing.assert_equal(actual, expected)
  def test_replace_oprs():
      """Check ``cgtools.replace_oprs`` on a small hand-built graph.

      The graph computes ``(a + a) * const`` with ``const = 1.25``.  The
      operator producing ``a + a`` is replaced by the one producing ``a * a``,
      so for ``a = 5`` the output is expected to be ``5 * 5 * 1.25``.
      """
      graph = mgb_graph.Graph()
      graph.options.async_exec_level = 0b100
      dev = "xpux"
      inp = mgb_graph.InputNode(device=dev, dtype=np.float32, graph=graph)
      scale = graph.make_const(1.25)

      # Original computation: (a + a) * const.
      summed = F.add(inp.outputs[0], inp.outputs[0])
      old_opr = summed.op
      scaled = F.mul(summed, scale)

      # Replacement operator: a * a.
      squared = F.mul(inp.outputs[0], inp.outputs[0])
      new_opr = squared.op

      targets = [scaled._node]
      substitutions = {old_opr._node: new_opr._node}
      (replaced,) = cgtools.replace_oprs(targets, substitutions)

      sink = mgb_graph.OutputNode(mgb_graph.VarNode(replaced))
      compiled = graph.compile(sink.outputs[0])

      # NOTE(review): execute() precedes set_value() exactly as in the
      # original test; do not reorder without confirming graph semantics.
      compiled.execute()
      inp.set_value(make_dev_tensor(5.0, device=dev))

      actual = sink.get_value().numpy()
      expected = np.array([5.0 * 5.0 * 1.25])
      np.testing.assert_equal(actual, expected)
  def test_graph_traversal():
      """Trace a Conv2d, dump and reload it, then inspect ``graph_traversal``.

      The assertion checks that the first entry recorded in ``var2oprs`` for
      ``map_vars[1]`` carries index 0.
      """
      net = M.Conv2d(3, 32, 3)

      @trace(symbolic=True, capture_as_const=True)
      def fun(data):
          return net(data)

      sample = np.random.random([1, 3, 224, 224]).astype(np.float32)
      # The traced function is invoked three times before dumping, as in the
      # original test.
      for _ in range(3):
          fun(megengine.tensor(sample))

      # Round-trip the traced graph through an in-memory buffer.
      buffer = io.BytesIO()
      fun.dump(buffer, optimize_for_inference=False)
      buffer.seek(0)
      # Keep the loaded graph bound to a name for the rest of the test.
      loaded_graph, _, outputs = mgb_graph.load_graph(buffer)

      _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
      first_var = map_vars[1]
      _, var_idx = var2oprs[first_var.id][0]
      assert var_idx == 0
  def test_load_refcnt():
      """Access a loaded var's ``owner`` after its graph name is deleted.

      A constant is dumped and reloaded; the names ``graph`` and ``varnode``
      are deliberately rebound to the loaded objects, then ``graph`` is
      deleted before ``varnode.owner`` is read.  The test passes if that
      attribute access does not fail.
      """
      graph = mgb_graph.Graph()
      varnode = graph.make_const(0)
      serialized, _ = mgb_graph.dump_graph([varnode])

      stream = io.BytesIO(serialized)
      # Rebinding both names drops this function's references to the
      # original graph and var.
      graph, _, (varnode,) = mgb_graph.load_graph(stream)
      del graph

      # Bare attribute access on purpose: only the access itself is tested.
      varnode.owner
  def test_get_opr_seq():
      """Check the lengths of the sequences returned by ``get_oprs_seq``.

      A small module that reshapes a stored tensor and adds it to its input
      is traced, dumped and reloaded.  ``cgtools.get_oprs_seq`` is expected
      to return 5 entries when its second argument is ``True`` and 6 when it
      is ``False``.
      """

      class Net(M.Module):
          def __init__(self):
              super().__init__()
              self.data = megengine.tensor(
                  np.random.random((1, 1, 4, 4)), dtype=np.float32
              )

          def forward(self, input):
              # The target shape is built from the input's leading dimension.
              side = input.shape[0]
              target_shape = astensor1d(
                  (side, side), self.data, dtype="int32", device=input.device
              )
              reshaped = F.reshape(self.data, target_shape)
              return input + reshaped

      net = Net()
      sample = megengine.tensor(np.random.random((4, 4)), dtype=np.float32)

      @trace(symbolic=True, capture_as_const=True)
      def func(inp, *, net=None):
          return net(inp)

      func(sample, net=net)

      # Round-trip the traced graph through an in-memory buffer.
      buffer = io.BytesIO()
      func.dump(buffer, optimize_for_inference=False)
      buffer.seek(0)
      *_, outputs = mgb_graph.load_graph(buffer)

      seq_with_flag = cgtools.get_oprs_seq(outputs, True)
      assert len(seq_with_flag) == 5

      seq_without_flag = cgtools.get_oprs_seq(outputs, False)
      assert len(seq_without_flag) == 6

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台