Browse Source

feat(mge/device): __repr__ method will show physical device

GitOrigin-RevId: 050c3864a7
release-1.1
Megvii Engine Team 4 years ago
parent
commit
2e9ba679b5
3 changed files with 14 additions and 6 deletions
  1. +3
    -1
      imperative/python/megengine/core/_wrap.py
  2. +5
    -5
      imperative/python/megengine/tensor.py
  3. +6
    -0
      imperative/python/src/common.cpp

+ 3
- 1
imperative/python/megengine/core/_wrap.py View File

@@ -22,11 +22,13 @@ class Device:
else: else:
self._cn = CompNode(device) self._cn = CompNode(device)


self.logical_name = self._cn.logical_name

def to_c(self): def to_c(self):
return self._cn return self._cn


def __repr__(self): def __repr__(self):
return "{}({})".format(type(self).__qualname__, self)
return "{}({})".format(type(self).__qualname__, repr(self._cn))


def __str__(self): def __str__(self):
return str(self._cn) return str(self._cn)


+ 5
- 5
imperative/python/megengine/tensor.py View File

@@ -67,7 +67,7 @@ class Tensor(_Tensor):


state = { state = {
"data": self.numpy(), "data": self.numpy(),
"device": str(self.device),
"device": self.device.logical_name,
"dtype": self.dtype, "dtype": self.dtype,
"qdict": self.q_dict, "qdict": self.q_dict,
} }
@@ -75,13 +75,13 @@ class Tensor(_Tensor):


def __setstate__(self, state): def __setstate__(self, state):
data = state.pop("data") data = state.pop("data")
device = state.pop("device")
logical_device = state.pop("device")
if self.dmap_callback is not None: if self.dmap_callback is not None:
assert isinstance(device, str)
device = self.dmap_callback(device)
assert isinstance(logical_device, str)
logical_device = self.dmap_callback(logical_device)
dtype = state.pop("dtype") dtype = state.pop("dtype")
self.q_dict = state.pop("qdict") self.q_dict = state.pop("qdict")
super().__init__(data, dtype=dtype, device=device)
super().__init__(data, dtype=dtype, device=logical_device)


def detach(self): def detach(self):
r""" r"""


+ 6
- 0
imperative/python/src/common.cpp View File

@@ -55,10 +55,16 @@ void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init()) .def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical();
})
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device) .def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device) .def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical) .def("__str__", &CompNode::to_string_logical)
.def("__repr__", [](const CompNode& cn) {
return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\"");
})
.def_static("_sync_all", &CompNode::sync_all) .def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self) .def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count, .def_static("_get_device_count", &CompNode::get_device_count,


Loading…
Cancel
Save