You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /**
  2. * \file imperative/src/test/helper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "helper.h"
  12. #include "megbrain/graph.h"
  13. #include "megbrain/opr/io.h"
  14. #include <memory>
  15. #include <pybind11/embed.h>
  16. #include <pybind11/numpy.h>
  17. namespace py = pybind11;
  18. namespace mgb {
  19. namespace imperative {
  20. namespace {
  21. #define XSTR(s) STR(s)
  22. #define STR(s) #s
  23. #define CONCAT(a, b) a##b
  24. #define PYINIT(name) CONCAT(PyInit_, name)
  25. #define pyinit PYINIT(MODULE_NAME)
  26. #define UNUSED __attribute__((unused))
  27. extern "C" PyObject* pyinit();
/**
 * Process-wide embedded Python environment (singleton).
 *
 * Lazily boots a pybind11 interpreter exactly once and registers the
 * statically linked extension module named by MODULE_NAME with it.
 */
class PyEnv {
    static std::unique_ptr<PyEnv> m_instance;               // singleton storage
    std::unique_ptr<py::scoped_interpreter> m_interpreter;  // owns the embedded interpreter
    PyEnv();  // private: construct only through instance()
public:
    //! Get the singleton, constructing it (and booting Python) on first call.
    static PyEnv& instance();
    //! Ensure the interpreter is up, then import and return the test module.
    static py::module get();
};
std::unique_ptr<PyEnv> PyEnv::m_instance = nullptr;
  37. PyEnv::PyEnv() {
  38. mgb_assert(!m_instance);
  39. auto err = PyImport_AppendInittab(XSTR(MODULE_NAME), &pyinit);
  40. mgb_assert(!err);
  41. m_interpreter.reset(new py::scoped_interpreter());
  42. }
  43. PyEnv& PyEnv::instance() {
  44. if (!m_instance) {
  45. m_instance.reset(new PyEnv());
  46. }
  47. return *m_instance;
  48. }
  49. py::module PyEnv::get() {
  50. instance();
  51. return py::module::import(XSTR(MODULE_NAME));
  52. }
  53. py::array array(const Tensor& x) {
  54. PyEnv::get();
  55. return py::cast(x).attr("numpy")();
  56. }
  57. py::array array(const HostTensorND& x) {
  58. return array(*Tensor::make(x));
  59. }
  60. py::array array(const DeviceTensorND& x) {
  61. return array(*Tensor::make(x));
  62. }
  63. UNUSED void print(const Tensor& x) {
  64. return print(array(x));
  65. }
  66. UNUSED void print(const HostTensorND& x) {
  67. return print(array(x));
  68. }
  69. UNUSED void print(const DeviceTensorND& x) {
  70. return print(array(x));
  71. }
//! Debug helper: print a C string through the embedded Python interpreter.
UNUSED void print(const char* s) {
    PyEnv::instance();  // py::print requires a live interpreter
    py::print(s);
}
  76. } // anonymous namespace
  77. OprChecker::OprChecker(std::shared_ptr<OpDef> opdef)
  78. : m_op(opdef) {}
/**
 * Execute m_op through two independent paths -- the static computing graph
 * (apply_on_var_node + compiled graph) and the imperative path
 * (apply_on_physical_tensor) -- and assert that all outputs match and that
 * inputs are left unmodified.
 *
 * \param inp_keys one spec per input: either a TensorShape (a tensor of that
 *        shape is generated randomly) or a concrete HostTensorND used as-is.
 */
void OprChecker::run(std::vector<InputSpec> inp_keys) {
    HostTensorGenerator<> gen;
    size_t nr_inps = inp_keys.size();
    SmallVector<HostTensorND> host_inp(nr_inps);
    VarNodeArray sym_inp(nr_inps);
    auto graph = ComputingGraph::make();
    // disable graph optimization so the compiled graph computes exactly what
    // the imperative path computes
    graph->options().graph_opt_level = 0;
    for (size_t i = 0; i < nr_inps; ++ i) {
        // materialize each input: generate from a shape, or take the tensor given
        host_inp[i] = std::visit([&gen](auto&& arg) -> HostTensorND {
            using T = std::decay_t<decltype(arg)>;
            if constexpr (std::is_same_v<TensorShape, T>) {
                return *gen(arg);
            } else {
                static_assert(std::is_same_v<HostTensorND, T>);
                return arg;
            }
        }, inp_keys[i]);
        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
    }
    // symbolic path: build graph outputs from the op definition
    auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output();
    size_t nr_oups = sym_oup.size();
    ComputingGraph::OutputSpec oup_spec(nr_oups);
    SmallVector<HostTensorND> host_sym_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++ i) {
        // copy each symbolic output back to host when the graph executes
        oup_spec[i] = make_callback_copy(sym_oup[i], host_sym_oup[i]);
    }
    auto func = graph->compile(oup_spec);
    SmallVector<TensorPtr> imp_physical_inp(nr_inps);
    for (size_t i = 0; i < nr_inps; ++ i) {
        imp_physical_inp[i] = Tensor::make(host_inp[i]);
    }
    // imperative path: apply the op directly on physical tensors
    auto imp_oup = OpDef::apply_on_physical_tensor(*m_op, imp_physical_inp);
    mgb_assert(imp_oup.size() == nr_oups);
    // check input not modified
    for (size_t i = 0; i < imp_physical_inp.size(); ++i) {
        HostTensorND hv;
        hv.copy_from(imp_physical_inp[i]->dev_tensor()).sync();
        MGB_ASSERT_TENSOR_EQ(hv, host_inp[i]);
    }
    // read back imperative outputs before the graph runs, since graph
    // execution below may mutate shared inputs in place
    SmallVector<HostTensorND> host_imp_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++ i) {
        host_imp_oup[i].copy_from(imp_oup[i]->dev_tensor()).sync();
    }
    func->execute().wait(); // run last because it may contain inplace operations
    for(size_t i = 0; i < nr_oups; ++ i) {
        MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]);
    }
}
// Smoke test: the embedded interpreter boots, the statically linked test
// module imports, and a DeviceTensorND can be cast to a Python object.
TEST(TestHelper, PyModule) {
    py::module m = PyEnv::get();
    py::print(m);
    py::print(py::cast(DeviceTensorND()));
}
  132. } // namespace imperative
  133. } // namespace mgb
  134. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台