You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_dispatch.py 935 B

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from megengine.core.tensor.multipledispatch import Dispatcher
  2. def test_register_many():
  3. f = Dispatcher("f")
  4. log = []
  5. @f.register()
  6. def _(x: int):
  7. log.append("a")
  8. return log[-1]
  9. @f.register()
  10. def _(x: int):
  11. log.append("b")
  12. return log[-1]
  13. assert f(0) == "b"
  14. assert log == ["b"]
  15. def test_return_not_implemented():
  16. f = Dispatcher("f")
  17. log = []
  18. @f.register()
  19. def _(x: int):
  20. log.append("a")
  21. return log[-1]
  22. @f.register()
  23. def _(x: int):
  24. log.append("b")
  25. return NotImplemented
  26. assert f(0) == "a"
  27. assert log == ["b", "a"]
  28. def test_super():
  29. f = Dispatcher("f")
  30. log = []
  31. @f.register()
  32. def _(x: int):
  33. log.append("a")
  34. return log[-1]
  35. @f.register()
  36. def _(x: int):
  37. log.append("b")
  38. return f.super(x)
  39. assert f(0) == "a"
  40. assert log == ["b", "a"]

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台