Browse Source

fix(imperative): fix hardcode of default device

GitOrigin-RevId: 722c4debfa
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
2cc85487d1
5 changed files with 24 additions and 7 deletions
  1. +5
    -5
      imperative/python/megengine/device.py
  2. +12
    -0
      imperative/python/src/common.cpp
  3. +3
    -0
      imperative/python/src/common.h
  4. +2
    -1
      imperative/python/src/graph_rt.cpp
  5. +2
    -1
      imperative/python/src/imperative_rt.cpp

+ 5
- 5
imperative/python/megengine/device.py View File

@@ -17,8 +17,6 @@ __all__ = [
"set_default_device",
]

_default_device = os.getenv("MGE_DEFAULT_DEVICE", "xpux")


def _valid_device(inp):
if isinstance(inp, str) and len(inp) == 4:
@@ -76,9 +74,8 @@ def set_default_device(device: str = "xpux"):

It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
"""
global _default_device # pylint: disable=global-statement
assert _valid_device(device), "Invalid device name {}".format(device)
_default_device = device
CompNode._set_default_device(device)


def get_default_device() -> str:
@@ -86,4 +83,7 @@ def get_default_device() -> str:

It returns the value set by :func:`~.set_default_device`.
"""
return _default_device
return CompNode._get_default_device()


set_default_device(os.getenv("MGE_DEFAULT_DEVICE", "xpux"))

+ 12
- 0
imperative/python/src/common.cpp View File

@@ -39,13 +39,25 @@ auto def_TensorND(py::object parent, const char* name) {
&XTensorND::template copy_from_fixlayout<HostTensorStorage>));
}

std::string default_device = "xpux";

} // namespace

void set_default_device(const std::string &device) {
default_device = device;
}

std::string get_default_device() {
return default_device;
}

void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical)
.def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self)


+ 3
- 0
imperative/python/src/common.h View File

@@ -14,3 +14,6 @@
#include "./helper.h"

void init_common(pybind11::module m);

void set_default_device(const std::string &device);
std::string get_default_device();

+ 2
- 1
imperative/python/src/graph_rt.cpp View File

@@ -19,6 +19,7 @@
#include "megbrain/imperative.h"
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"

namespace py = pybind11;

@@ -230,7 +231,7 @@ void init_graph_rt(py::module m) {

m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
if (!cn.valid()) {
cn = CompNode::load("xpux");
cn = CompNode::load(get_default_device());
}
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();


+ 2
- 1
imperative/python/src/imperative_rt.cpp View File

@@ -21,6 +21,7 @@
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "./helper.h"
#include "./common.h"

namespace py = pybind11;

@@ -53,7 +54,7 @@ void init_imperative_rt(py::module m) {
py::class_<Interpreter::Channel>(m, "Interpreter")
.def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
if (!cn.valid()) {
cn = CompNode::load("xpux");
cn = CompNode::load(get_default_device());
}
constexpr int size_threshhold = TensorShape::MAX_NDIM;
if (data.size() > size_threshhold) {


Loading…
Cancel
Save