|
|
@@ -14,6 +14,7 @@ |
|
|
|
#define DO_IMPORT_ARRAY |
|
|
|
#include "./helper.h" |
|
|
|
#include "./numpy_dtypes.h" |
|
|
|
#include "megdnn/handle.h" |
|
|
|
|
|
|
|
#include "./common.h" |
|
|
|
#include "./graph_rt.h" |
|
|
@@ -30,6 +31,10 @@ using namespace mgb::imperative::python; |
|
|
|
#define MODULE_NAME imperative_rt |
|
|
|
#endif |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
extern const std::shared_ptr<Handle>& inplace_cpu_handle(int debug_level = 0); |
|
|
|
} |
|
|
|
|
|
|
|
PYBIND11_MODULE(MODULE_NAME, m) { |
|
|
|
// initialize numpy |
|
|
|
if ([]() { |
|
|
@@ -39,6 +44,8 @@ PYBIND11_MODULE(MODULE_NAME, m) { |
|
|
|
throw py::error_already_set(); |
|
|
|
} |
|
|
|
|
|
|
|
megdnn::inplace_cpu_handle(); |
|
|
|
|
|
|
|
py::module::import("sys").attr("modules")[m.attr("__name__")] = m; |
|
|
|
|
|
|
|
m.attr("__package__") = m.attr("__name__"); |
|
|
|