You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cgtools.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import io
  9. import numpy as np
  10. import pytest
  11. import megengine
  12. import megengine.functional as F
  13. import megengine.module as M
  14. from megengine import cgtools
  15. from megengine.core.tensor import megbrain_graph as mgb_graph
  16. from megengine.core.tensor.raw_tensor import as_raw_tensor
  17. from megengine.jit import trace
  18. def make_dev_tensor(value, dtype=None, device=None):
  19. return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
  20. def test_replace_vars():
  21. g = mgb_graph.Graph()
  22. g.options.async_exec_level = 0b100
  23. device = "xpux"
  24. dtype = np.float32
  25. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  26. const = g.make_const(1.234)
  27. a_plus_a = F.add(a.outputs[0], a.outputs[0])
  28. a_plus_a_mul_const = F.mul(a_plus_a, const)
  29. rst = F.add(a_plus_a_mul_const, a.outputs[0])
  30. (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
  31. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  32. func = g.compile(out.outputs[0])
  33. func.execute()
  34. x = make_dev_tensor(5.0, device=device)
  35. a.set_value(x)
  36. res = out.get_value().numpy()
  37. np.testing.assert_equal(res, np.array([105.0]))
  38. def test_replace_oprs():
  39. g = mgb_graph.Graph()
  40. g.options.async_exec_level = 0b100
  41. device = "xpux"
  42. dtype = np.float32
  43. a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  44. const = g.make_const(1.25)
  45. a_plus_a = F.add(a.outputs[0], a.outputs[0])
  46. old_opr = a_plus_a.op
  47. a_plus_a_mul_const = F.mul(a_plus_a, const)
  48. a_mul_a = F.mul(a.outputs[0], a.outputs[0])
  49. new_opr = a_mul_a.op
  50. (new,) = cgtools.replace_oprs(
  51. [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
  52. )
  53. out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
  54. func = g.compile(out.outputs[0])
  55. func.execute()
  56. x = make_dev_tensor(5.0, device=device)
  57. a.set_value(x)
  58. res = out.get_value().numpy()
  59. np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
  60. def test_graph_traversal():
  61. net = M.Conv2d(3, 32, 3)
  62. @trace(symbolic=True, capture_as_const=True)
  63. def fun(data):
  64. x = net(data)
  65. return x
  66. data = np.random.random([1, 3, 224, 224]).astype(np.float32)
  67. for _ in range(3):
  68. fun(megengine.tensor(data))
  69. file = io.BytesIO()
  70. fun.dump(file, optimize_for_inference=False)
  71. file.seek(0)
  72. cg, _, outputs = mgb_graph.load_graph(file)
  73. _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
  74. input_var = map_vars[1]
  75. _, var_idx = var2oprs[input_var.id][0]
  76. assert var_idx == 0
  77. def test_load_refcnt():
  78. graph = mgb_graph.Graph()
  79. varnode = graph.make_const(0)
  80. buf, _ = mgb_graph.dump_graph([varnode])
  81. graph, _, (varnode,) = mgb_graph.load_graph(io.BytesIO(buf))
  82. del graph
  83. varnode.owner

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。