Browse Source

feat(mge/device): __repr__ method will show physical device

GitOrigin-RevId: 050c3864a7
release-1.1
Megvii Engine Team 4 years ago
parent
commit
2e9ba679b5
3 changed files with 14 additions and 6 deletions
  1. +3
    -1
      imperative/python/megengine/core/_wrap.py
  2. +5
    -5
      imperative/python/megengine/tensor.py
  3. +6
    -0
      imperative/python/src/common.cpp

+ 3
- 1
imperative/python/megengine/core/_wrap.py View File

@@ -22,11 +22,13 @@ class Device:
else:
self._cn = CompNode(device)

self.logical_name = self._cn.logical_name

def to_c(self):
return self._cn

def __repr__(self):
return "{}({})".format(type(self).__qualname__, self)
return "{}({})".format(type(self).__qualname__, repr(self._cn))

def __str__(self):
return str(self._cn)


+ 5
- 5
imperative/python/megengine/tensor.py View File

@@ -67,7 +67,7 @@ class Tensor(_Tensor):

state = {
"data": self.numpy(),
"device": str(self.device),
"device": self.device.logical_name,
"dtype": self.dtype,
"qdict": self.q_dict,
}
@@ -75,13 +75,13 @@ class Tensor(_Tensor):

def __setstate__(self, state):
data = state.pop("data")
device = state.pop("device")
logical_device = state.pop("device")
if self.dmap_callback is not None:
assert isinstance(device, str)
device = self.dmap_callback(device)
assert isinstance(logical_device, str)
logical_device = self.dmap_callback(logical_device)
dtype = state.pop("dtype")
self.q_dict = state.pop("qdict")
super().__init__(data, dtype=dtype, device=device)
super().__init__(data, dtype=dtype, device=logical_device)

def detach(self):
r"""


+ 6
- 0
imperative/python/src/common.cpp View File

@@ -55,10 +55,16 @@ void init_common(py::module m) {
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
.def(py::init())
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
.def_property_readonly("logical_name", [](const CompNode& cn) {
return cn.to_string_logical();
})
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
.def("_set_default_device", &set_default_device)
.def("_get_default_device", &get_default_device)
.def("__str__", &CompNode::to_string_logical)
.def("__repr__", [](const CompNode& cn) {
return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\"");
})
.def_static("_sync_all", &CompNode::sync_all)
.def(py::self == py::self)
.def_static("_get_device_count", &CompNode::get_device_count,


Loading…
Cancel
Save