GitOrigin-RevId: 050c3864a7
release-1.1
@@ -22,11 +22,13 @@ class Device: | |||||
else: | else: | ||||
self._cn = CompNode(device) | self._cn = CompNode(device) | ||||
self.logical_name = self._cn.logical_name | |||||
def to_c(self): | def to_c(self): | ||||
return self._cn | return self._cn | ||||
def __repr__(self): | def __repr__(self): | ||||
return "{}({})".format(type(self).__qualname__, self) | |||||
return "{}({})".format(type(self).__qualname__, repr(self._cn)) | |||||
def __str__(self): | def __str__(self): | ||||
return str(self._cn) | return str(self._cn) | ||||
@@ -67,7 +67,7 @@ class Tensor(_Tensor): | |||||
state = { | state = { | ||||
"data": self.numpy(), | "data": self.numpy(), | ||||
"device": str(self.device), | |||||
"device": self.device.logical_name, | |||||
"dtype": self.dtype, | "dtype": self.dtype, | ||||
"qdict": self.q_dict, | "qdict": self.q_dict, | ||||
} | } | ||||
@@ -75,13 +75,13 @@ class Tensor(_Tensor): | |||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
data = state.pop("data") | data = state.pop("data") | ||||
device = state.pop("device") | |||||
logical_device = state.pop("device") | |||||
if self.dmap_callback is not None: | if self.dmap_callback is not None: | ||||
assert isinstance(device, str) | |||||
device = self.dmap_callback(device) | |||||
assert isinstance(logical_device, str) | |||||
logical_device = self.dmap_callback(logical_device) | |||||
dtype = state.pop("dtype") | dtype = state.pop("dtype") | ||||
self.q_dict = state.pop("qdict") | self.q_dict = state.pop("qdict") | ||||
super().__init__(data, dtype=dtype, device=device) | |||||
super().__init__(data, dtype=dtype, device=logical_device) | |||||
def detach(self): | def detach(self): | ||||
r""" | r""" | ||||
@@ -55,10 +55,16 @@ void init_common(py::module m) { | |||||
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | ||||
.def(py::init()) | .def(py::init()) | ||||
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | .def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | ||||
.def_property_readonly("logical_name", [](const CompNode& cn) { | |||||
return cn.to_string_logical(); | |||||
}) | |||||
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | ||||
.def("_set_default_device", &set_default_device) | .def("_set_default_device", &set_default_device) | ||||
.def("_get_default_device", &get_default_device) | .def("_get_default_device", &get_default_device) | ||||
.def("__str__", &CompNode::to_string_logical) | .def("__str__", &CompNode::to_string_logical) | ||||
.def("__repr__", [](const CompNode& cn) { | |||||
return py::str("\"" + cn.to_string() + "\" from \"" + cn.to_string_logical() + "\""); | |||||
}) | |||||
.def_static("_sync_all", &CompNode::sync_all) | .def_static("_sync_all", &CompNode::sync_all) | ||||
.def(py::self == py::self) | .def(py::self == py::self) | ||||
.def_static("_get_device_count", &CompNode::get_device_count, | .def_static("_get_device_count", &CompNode::get_device_count, | ||||