You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_imperative_rt.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. import pytest
  11. import megengine.core.tensor.raw_tensor
  12. from megengine.core.tensor.core import apply
  13. def elemwise(*args, mode):
  14. from megengine.core.ops.builtin import Elemwise
  15. from megengine.core._imperative_rt.imperative import apply_op
  16. return apply_op(Elemwise(mode=mode).to_c(), args)
  17. def test_basic_interface():
  18. cf = megengine.core._imperative_rt.OperatorNodeConfig()
  19. cf.name = "megengine.core"
  20. cf.dtype = "float32"
  21. cf.comp_node_arr = ["xpux"]
  22. print(cf.name)
  23. print(cf.dtype)
  24. print(cf.comp_node_arr)
  25. print(cf.comp_node)
  26. cf.comp_node_arr = ["xpux", "xpux:1"]
  27. with pytest.raises(ValueError):
  28. cf.comp_node
  29. def test_opr_attr():
  30. from megengine.core.ops.builtin import Elemwise
  31. assert Elemwise(mode="add") == Elemwise(mode="add")
  32. def test_simple_arith():
  33. x = np.random.rand(10).astype("float32")
  34. xx = megengine.core._imperative_rt.put(x)
  35. (yy,) = elemwise(xx, xx, mode="mul")
  36. np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy))
  37. megengine.core._imperative_rt.delete(xx)
  38. megengine.core._imperative_rt.delete(yy)
  39. def test_tensor_on_device():
  40. device = megengine.core._imperative_rt.CompNode("cpu0:1")
  41. x = np.random.rand(10).astype("float32")
  42. xx = megengine.core._imperative_rt.put(x, device=device)
  43. assert str(megengine.core._imperative_rt.get_device(xx)) == "cpu0:1"
  44. np.testing.assert_equal(x, megengine.core._imperative_rt.get_value(xx))
  45. megengine.core._imperative_rt.delete(xx)
  46. def test_raw_tensor():
  47. from megengine.core.tensor.raw_tensor import as_raw_tensor
  48. from megengine.core.ops.builtin import Elemwise
  49. x = np.random.rand(10).astype("float32")
  50. xx = as_raw_tensor(x)
  51. (yy,) = apply(Elemwise(mode="mul"), xx, xx)
  52. np.testing.assert_allclose(x * x, yy.numpy())
  53. (yy,) = apply(Elemwise(mode="mul"), xx, xx)
  54. np.testing.assert_allclose(x * x, yy.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台