From 2e9ba679b5b57c4af1d9c263089d437984a7388e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 25 Sep 2020 13:42:15 +0800 Subject: [PATCH] feat(mge/device): __repr__ method will show physical device GitOrigin-RevId: 050c3864a7d99234a02114197dd0f499cb88f413 --- imperative/python/megengine/core/_wrap.py | 4 +++- imperative/python/megengine/tensor.py | 10 +++++----- imperative/python/src/common.cpp | 6 ++++++ 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/core/_wrap.py b/imperative/python/megengine/core/_wrap.py index c4bf7564..538518a1 100644 --- a/imperative/python/megengine/core/_wrap.py +++ b/imperative/python/megengine/core/_wrap.py @@ -22,11 +22,13 @@ class Device: else: self._cn = CompNode(device) + self.logical_name = self._cn.logical_name + def to_c(self): return self._cn def __repr__(self): - return "{}({})".format(type(self).__qualname__, self) + return "{}({})".format(type(self).__qualname__, repr(self._cn)) def __str__(self): return str(self._cn) diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index 0d30a264..5d13530a 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -67,7 +67,7 @@ class Tensor(_Tensor): state = { "data": self.numpy(), - "device": str(self.device), + "device": self.device.logical_name, "dtype": self.dtype, "qdict": self.q_dict, } @@ -75,13 +75,13 @@ class Tensor(_Tensor): def __setstate__(self, state): data = state.pop("data") - device = state.pop("device") + logical_device = state.pop("device") if self.dmap_callback is not None: - assert isinstance(device, str) - device = self.dmap_callback(device) + assert isinstance(logical_device, str) + logical_device = self.dmap_callback(logical_device) dtype = state.pop("dtype") self.q_dict = state.pop("qdict") - super().__init__(data, dtype=dtype, device=device) + super().__init__(data, dtype=dtype, device=logical_device) def detach(self): r""" diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp index aeb1f9e9..136d6dde 100644 --- a/imperative/python/src/common.cpp +++ b/imperative/python/src/common.cpp @@ -55,10 +55,16 @@ void init_common(py::module m) { auto&& PyCompNode = py::class_(m, "CompNode") .def(py::init()) .def(py::init(py::overload_cast(&CompNode::load))) + .def_property_readonly("logical_name", [](const CompNode& cn) { + return cn.to_string_logical(); + }) .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) .def("_set_default_device", &set_default_device) .def("_get_default_device", &get_default_device) .def("__str__", &CompNode::to_string_logical) + .def("__repr__", [](const CompNode& cn) { + return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\""); + }) .def_static("_sync_all", &CompNode::sync_all) .def(py::self == py::self) .def_static("_get_device_count", &CompNode::get_device_count,