You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_megbrain_graph.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from concurrent.futures import Future
  10. import numpy as np
  11. import megengine.functional as F
  12. from megengine.core._imperative_rt import DeviceTensorND
  13. from megengine.core.tensor import megbrain_graph as mgb_graph
  14. from megengine.core.tensor.raw_tensor import as_raw_tensor
  15. def make_dev_tensor(value, dtype=None, device=None):
  16. return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
  17. def test_io():
  18. g = mgb_graph.Graph()
  19. x = make_dev_tensor(np.random.randn(3).astype("float32"), device="xpux")
  20. vx, _ = mgb_graph.input_callback(
  21. lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
  22. )
  23. y = Future()
  24. v = mgb_graph.output_callback(y.set_result, vx)
  25. f = g.compile(v)
  26. f()
  27. np.testing.assert_equal(x.numpy(), y.result().numpy())
  28. def test_io2():
  29. g = mgb_graph.Graph()
  30. g.options.async_exec_level = 0b100
  31. dtype, device = "float32", "xpux"
  32. px = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  33. py = mgb_graph.OutputNode(px.outputs[0])
  34. f = g.compile(py.outputs[0])
  35. for _ in range(3):
  36. f.execute()
  37. x = make_dev_tensor(np.random.randn(10).astype(dtype), device=device)
  38. px.set_value(x)
  39. y = py.get_value()
  40. np.testing.assert_equal(x.numpy(), y.numpy())
  41. f.wait()
  42. def test_attr_output():
  43. g = mgb_graph.Graph()
  44. g.options.async_exec_level = 0b100
  45. dtype, device = "float32", "xpux"
  46. px = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
  47. py = mgb_graph.AttrOutputNode(px.outputs[0])
  48. f = g.compile(py.outputs[0])
  49. for shape in [(2,), (3,), (5,)]:
  50. f.execute()
  51. x = make_dev_tensor(np.random.randn(*shape).astype(dtype), device=device)
  52. px.set_value(x)
  53. ay = py.get_value()
  54. assert ay.shape == shape
  55. assert ay.dtype == np.dtype(dtype)
  56. assert ay.device == device
  57. f.wait()
  58. def test_op():
  59. g = mgb_graph.Graph()
  60. x = make_dev_tensor(np.random.randn(10).astype("float32"), device="xpux")
  61. v, _ = mgb_graph.input_callback(
  62. lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
  63. )
  64. v = F.neg(v)
  65. y = Future()
  66. v = mgb_graph.output_callback(y.set_result, v)
  67. f = g.compile(v)
  68. f()
  69. np.testing.assert_equal(x.numpy(), -y.result().numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。