You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

module.cpp 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /**
  2. * \file imperative/python/src/module.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <pybind11/eval.h>
  12. #define DO_IMPORT_ARRAY
  13. #include "./helper.h"
  14. #include "./numpy_dtypes.h"
  15. #include "./common.h"
  16. #include "./graph_rt.h"
  17. #include "./imperative_rt.h"
  18. #include "./ops.h"
  19. #include "./utils.h"
  20. #include "./tensor.h"
  21. namespace py = pybind11;
  22. using namespace mgb::imperative::python;
  23. #ifndef MODULE_NAME
  24. #define MODULE_NAME imperative_rt
  25. #endif
  26. PYBIND11_MODULE(MODULE_NAME, m) {
  27. // initialize numpy
  28. if ([]() {
  29. import_array1(1);
  30. return 0;
  31. }()) {
  32. throw py::error_already_set();
  33. }
  34. py::module::import("sys").attr("modules")[m.attr("__name__")] = m;
  35. m.attr("__package__") = m.attr("__name__");
  36. m.attr("__builtins__") = py::module::import("builtins");
  37. auto atexit = py::module::import("atexit");
  38. atexit.attr("register")(py::cpp_function([]() {
  39. py::gil_scoped_release _;
  40. py_task_q.wait_all_task_finish();
  41. }));
  42. auto common = submodule(m, "common");
  43. auto utils = submodule(m, "utils");
  44. auto imperative = submodule(m, "imperative");
  45. auto graph = submodule(m, "graph");
  46. auto ops = submodule(m, "ops");
  47. init_common(common);
  48. init_utils(utils);
  49. init_imperative_rt(imperative);
  50. init_graph_rt(graph);
  51. init_ops(ops);
  52. py::exec(
  53. R"(
  54. from .common import *
  55. from .utils import *
  56. from .imperative import *
  57. from .graph import *
  58. from .ops import OpDef
  59. )",
  60. py::getattr(m, "__dict__"));
  61. init_tensor(submodule(m, "core2"));
  62. }

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。