diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py index 008920fe..731b8985 100644 --- a/imperative/python/megengine/device.py +++ b/imperative/python/megengine/device.py @@ -17,8 +17,6 @@ __all__ = [ "set_default_device", ] -_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") - def _valid_device(inp): if isinstance(inp, str) and len(inp) == 4: @@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"): It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. """ - global _default_device # pylint: disable=global-statement assert _valid_device(device), "Invalid device name {}".format(device) - _default_device = device + CompNode._set_default_device(device) def get_default_device() -> str: @@ -86,4 +83,7 @@ def get_default_device() -> str: It returns the value set by :func:`~.set_default_device`. """ - return _default_device + return CompNode._get_default_device() + + +set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux")) diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index 10807a2e..cdaaf819 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) { &XTensorND::template copy_from_fixlayout)); } +std::string default_device = "xpux"; + } // namespace +void set_default_device(const std::string &device) { + default_device = device; +} + +std::string get_default_device() { + return default_device; +} + void init_common(py::module m) { auto&& PyCompNode = py::class_(m, "CompNode") .def(py::init()) .def(py::init(py::overload_cast(&CompNode::load))) .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) + .def("_set_default_device", &set_default_device) + .def("_get_default_device", &get_default_device) .def("__str__", &CompNode::to_string_logical) .def_static("_sync_all", &CompNode::sync_all) .def(py::self == py::self) diff --git a/imperative/python/src/common.h b/imperative/python/src/common.h index 5837de34..d5201b19 100644 --- a/imperative/python/src/common.h +++ b/imperative/python/src/common.h @@ -14,3 +14,6 @@ #include "./helper.h" void init_common(pybind11::module m); + +void set_default_device(const std::string &device); +std::string get_default_device(); \ No newline at end of file diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index deb29bba..500d8dd3 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -19,6 +19,7 @@ #include "megbrain/imperative.h" #include "./helper.h" #include "megbrain/plugin/profiler.h" +#include "./common.h" namespace py = pybind11; @@ -230,7 +231,7 @@ void init_graph_rt(py::module m) { m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { if (!cn.valid()) { - cn = CompNode::load("xpux"); + cn = CompNode::load(get_default_device()); } auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp index 6055b411..c96efbfe 100644 --- a/imperative/python/src/imperative_rt.cpp +++ b/imperative/python/src/imperative_rt.cpp @@ -21,6 +21,7 @@ #include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/ops/opr_attr.h" #include "./helper.h" +#include "./common.h" namespace py = pybind11; @@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) { py::class_(m, "Interpreter") .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { if (!cn.valid()) { - cn = CompNode::load("xpux"); + cn = CompNode::load(get_default_device()); } constexpr int size_threshhold = TensorShape::MAX_NDIM; if (data.size() > size_threshhold) {