GitOrigin-RevId: 722c4debfa
tags/v1.0.0-rc1
@@ -17,8 +17,6 @@ __all__ = [ | |||||
"set_default_device", | "set_default_device", | ||||
] | ] | ||||
_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux") | |||||
def _valid_device(inp): | def _valid_device(inp): | ||||
if isinstance(inp, str) and len(inp) == 4: | if isinstance(inp, str) and len(inp) == 4: | ||||
@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"): | |||||
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | ||||
""" | """ | ||||
global _default_device # pylint: disable=global-statement | |||||
assert _valid_device(device), "Invalid device name {}".format(device) | assert _valid_device(device), "Invalid device name {}".format(device) | ||||
_default_device = device | |||||
CompNode._set_default_device(device) | |||||
def get_default_device() -> str: | def get_default_device() -> str: | ||||
@@ -86,4 +83,7 @@ def get_default_device() -> str: | |||||
It returns the value set by :func:`~.set_default_device`. | It returns the value set by :func:`~.set_default_device`. | ||||
""" | """ | ||||
return _default_device | |||||
return CompNode._get_default_device() | |||||
set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux")) |
@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) { | |||||
&XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | &XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | ||||
} | } | ||||
std::string default_device = "xpux"; | |||||
} // namespace | } // namespace | ||||
void set_default_device(const std::string &device) { | |||||
default_device = device; | |||||
} | |||||
std::string get_default_device() { | |||||
return default_device; | |||||
} | |||||
void init_common(py::module m) { | void init_common(py::module m) { | ||||
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | ||||
.def(py::init()) | .def(py::init()) | ||||
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | ||||
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | ||||
.def("_set_default_device", &set_default_device) | |||||
.def("_get_default_device", &get_default_device) | |||||
.def("__str__", &CompNode::to_string_logical) | .def("__str__", &CompNode::to_string_logical) | ||||
.def_static("_sync_all", &CompNode::sync_all) | .def_static("_sync_all", &CompNode::sync_all) | ||||
.def(py::self == py::self) | .def(py::self == py::self) | ||||
@@ -14,3 +14,6 @@ | |||||
#include "./helper.h" | #include "./helper.h" | ||||
void init_common(pybind11::module m); | void init_common(pybind11::module m); | ||||
void set_default_device(const std::string &device); | |||||
std::string get_default_device(); |
@@ -19,6 +19,7 @@ | |||||
#include "megbrain/imperative.h" | #include "megbrain/imperative.h" | ||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "megbrain/plugin/profiler.h" | #include "megbrain/plugin/profiler.h" | ||||
#include "./common.h" | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) { | |||||
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | ||||
if (!cn.valid()) { | if (!cn.valid()) { | ||||
cn = CompNode::load("xpux"); | |||||
cn = CompNode::load(get_default_device()); | |||||
} | } | ||||
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | ||||
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | ||||
@@ -21,6 +21,7 @@ | |||||
#include "megbrain/imperative/interpreter.h" | #include "megbrain/imperative/interpreter.h" | ||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "./helper.h" | #include "./helper.h" | ||||
#include "./common.h" | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) { | |||||
py::class_<Interpreter::Channel>(m, "Interpreter") | py::class_<Interpreter::Channel>(m, "Interpreter") | ||||
.def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { | .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { | ||||
if (!cn.valid()) { | if (!cn.valid()) { | ||||
cn = CompNode::load("xpux"); | |||||
cn = CompNode::load(get_default_device()); | |||||
} | } | ||||
constexpr int size_threshhold = TensorShape::MAX_NDIM; | constexpr int size_threshhold = TensorShape::MAX_NDIM; | ||||
if (data.size() > size_threshhold) { | if (data.size() > size_threshhold) { | ||||