You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

helper.cpp 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
#include "helper.h"

#include "megbrain/graph.h"
#include "megbrain/opr/io.h"

#include <pybind11/embed.h>
#include <pybind11/numpy.h>

#include <memory>
#include <utility>

namespace py = pybind11;
  8. namespace mgb {
  9. namespace imperative {
  10. namespace {
// Stringification helpers: XSTR expands its macro argument first, then STR
// turns the expansion into a string literal.
#define XSTR(s) STR(s)
#define STR(s) #s
#define CONCAT(a, b) a##b
// Builds the CPython init symbol name for a statically linked extension
// module, i.e. PyInit_<name>.
#define PYINIT(name) CONCAT(PyInit_, name)
// MODULE_NAME is expected to be provided by the build system
// (e.g. -DMODULE_NAME=...); `pyinit` expands to PyInit_<MODULE_NAME>.
#define pyinit PYINIT(MODULE_NAME)
#define UNUSED __attribute__((unused))
// Module init entry point; its definition is generated elsewhere in the
// build (pybind11 module), hence only a declaration here.
extern "C" PyObject* pyinit();
// Process-wide owner of the embedded Python interpreter. Created lazily on
// first use; registers the statically linked extension module with CPython
// before interpreter startup so it can be imported from embedded code.
class PyEnv {
    // Sole instance; nullptr until instance() is first called.
    static std::unique_ptr<PyEnv> m_instance;
    // RAII handle that initializes/finalizes the embedded interpreter.
    std::unique_ptr<py::scoped_interpreter> m_interpreter;
    PyEnv();

public:
    // Returns the singleton, constructing it on first call.
    // NOTE(review): lazy init is not synchronized — assumes single-threaded
    // (test) use; confirm before calling from multiple threads.
    static PyEnv& instance();
    // Ensures the interpreter is up, then imports and returns the
    // embedded module named MODULE_NAME.
    static py::module get();
};
std::unique_ptr<PyEnv> PyEnv::m_instance = nullptr;
  27. PyEnv::PyEnv() {
  28. mgb_assert(!m_instance);
  29. auto err = PyImport_AppendInittab(XSTR(MODULE_NAME), &pyinit);
  30. mgb_assert(!err);
  31. m_interpreter.reset(new py::scoped_interpreter());
  32. }
  33. PyEnv& PyEnv::instance() {
  34. if (!m_instance) {
  35. m_instance.reset(new PyEnv());
  36. }
  37. return *m_instance;
  38. }
// Imports the embedded module, booting the interpreter first if needed.
py::module PyEnv::get() {
    instance();  // side effect only: ensure interpreter exists
    return py::module::import(XSTR(MODULE_NAME));
}
// Converts an imperative Tensor to a numpy array by casting it to its
// Python binding and invoking the Python-side `numpy()` method.
py::array array(const Tensor& x) {
    PyEnv::get();  // ensure interpreter + module are initialized
    return py::cast(x).attr("numpy")();
}
// Host-tensor convenience overload: wrap into a Tensor, then convert.
py::array array(const HostTensorND& x) {
    return array(*Tensor::make(x));
}
// Device-tensor convenience overload: wrap into a Tensor, then convert.
py::array array(const DeviceTensorND& x) {
    return array(*Tensor::make(x));
}
// Debug helper: print a Tensor's contents from C++ (e.g. under gdb).
// `print(array(x))` receives a py::array, so unqualified lookup resolves
// to pybind11::print via ADL — no local overload takes py::array.
UNUSED void print(const Tensor& x) {
    return print(array(x));
}
// Debug helper: print a HostTensorND (dispatches to pybind11::print via
// ADL on the py::array argument).
UNUSED void print(const HostTensorND& x) {
    return print(array(x));
}
// Debug helper: print a DeviceTensorND (dispatches to pybind11::print via
// ADL on the py::array argument).
UNUSED void print(const DeviceTensorND& x) {
    return print(array(x));
}
// Debug helper: print a C string through Python's print(), making sure the
// interpreter is alive first.
UNUSED void print(const char* s) {
    PyEnv::instance();
    py::print(s);
}
  66. } // anonymous namespace
  67. OprChecker::OprChecker(std::shared_ptr<OpDef> opdef) : m_op(opdef) {}
// Cross-checks one OpDef between two execution paths:
//   1. symbolic  — apply_on_var_node on a compiled ComputingGraph;
//   2. imperative — apply_on_physical_tensor on eager tensors;
// and asserts both paths yield equal outputs and that imperative
// execution does not modify its inputs.
//
// inp_keys: per-input spec — either a concrete HostTensorND used as-is,
//           or a TensorShape from which a random tensor is generated.
// bypass:   output indices excluded from the equality check.
void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
    HostTensorGenerator<> gen;
    size_t nr_inps = inp_keys.size();
    SmallVector<HostTensorND> host_inp(nr_inps);
    VarNodeArray sym_inp(nr_inps);
    auto graph = ComputingGraph::make();
    // Disable graph optimization so the symbolic result reflects the op
    // itself rather than a rewritten graph.
    graph->options().graph_opt_level = 0;
    for (size_t i = 0; i < nr_inps; ++i) {
        // TODO: remove std::visit for support osx 10.12
        // Materialize each input from its variant spec.
        host_inp[i] = std::visit(
                [&gen](auto&& arg) -> HostTensorND {
                    using T = std::decay_t<decltype(arg)>;
                    if constexpr (std::is_same_v<TensorShape, T>) {
                        return *gen(arg);  // random tensor of that shape
                    } else {
                        static_assert(std::is_same_v<HostTensorND, T>);
                        return arg;  // provided tensor, used as-is
                    }
                },
                inp_keys[i]);
        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
    }
    // --- symbolic path: build output vars and host copies of them ---
    auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp);
    size_t nr_oups = sym_oup.size();
    ComputingGraph::OutputSpec oup_spec(nr_oups);
    SmallVector<HostTensorND> host_sym_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++i) {
        oup_spec[i] = make_callback_copy(sym_oup[i], host_sym_oup[i]);
    }
    auto func = graph->compile(oup_spec);
    // --- imperative path: run the op eagerly on physical tensors ---
    SmallVector<TensorPtr> imp_physical_inp(nr_inps);
    for (size_t i = 0; i < nr_inps; ++i) {
        imp_physical_inp[i] = Tensor::make(host_inp[i]);
    }
    SmallVector<LogicalTensorDesc> output_descs;
    auto imp_oup = OpDef::apply_on_physical_tensor(
            *m_op, imp_physical_inp, output_descs, false);
    mgb_assert(imp_oup.size() == nr_oups);
    // check input not modified
    for (size_t i = 0; i < imp_physical_inp.size(); ++i) {
        HostTensorND hv;
        hv.copy_from(imp_physical_inp[i]->dev_tensor()).sync();
        MGB_ASSERT_TENSOR_EQ(hv, host_inp[i]);
    }
    // Snapshot imperative outputs to host before the graph runs.
    SmallVector<HostTensorND> host_imp_oup(nr_oups);
    for (size_t i = 0; i < nr_oups; ++i) {
        host_imp_oup[i].copy_from(imp_oup[i]->dev_tensor()).sync();
    }
    func->execute().wait(); // run last because it may contain inplace operations
    // Compare the two paths output-by-output, honoring the bypass set.
    for (size_t i = 0; i < nr_oups; ++i) {
        if (bypass.find(i) != bypass.end())
            continue;
        MGB_ASSERT_TENSOR_EQ(host_sym_oup[i], host_imp_oup[i]);
    }
}
// Smoke test: the embedded module imports successfully and a
// DeviceTensorND can be cast to a Python object and printed.
TEST(TestHelper, PyModule) {
    py::module m = PyEnv::get();
    py::print(m);
    py::print(py::cast(DeviceTensorND()));
}
  128. } // namespace imperative
  129. } // namespace mgb
  130. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}