@@ -225,6 +225,7 @@ class LiteTensor(object):
        tensor_desc.device_id = device_id
        tensor_desc.is_pinned_host = is_pinned_host
        self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
        self.update()

    def __del__(self):
        self._api.LITE_destroy_tensor(self._tensor)

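For orientation: the constructor above fills a C tensor descriptor (device_id, is_pinned_host, ...), asks the lite C API to create the tensor via LITE_make_tensor, then refreshes the cached Python-side state with update(); __del__ releases the handle with LITE_destroy_tensor. A rough usage sketch follows; LiteLayout, LiteDeviceType and the keyword names are assumptions based on the descriptor fields set in this hunk, not part of the diff itself.

# hedged usage sketch, not part of this diff
from megenginelite import LiteLayout, LiteTensor, LiteDeviceType

layout = LiteLayout([1, 3, 224, 224], "float32")
t = LiteTensor(
    layout=layout,
    device_type=LiteDeviceType.LITE_CPU,
    device_id=0,
    is_pinned_host=False,
)
# LITE_make_tensor has created the underlying handle at this point;
# dropping the last reference triggers __del__, which calls LITE_destroy_tensor.
del t
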
@@ -318,6 +319,11 @@ class LiteTensor(object):
        self._device_type = device_type
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))

        c_types = _lite_dtypes_to_ctype[self._layout.data_type]
        self.np_array_type = np.ctypeslib._ctype_ndarray(
            c_types, list(self._layout.shapes)[0 : self._layout.ndim]
        )
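The np_array_type cached above is just a nested ctypes array type matching the tensor's dtype and shape; numpy's private _ctype_ndarray helper builds it by multiplying the element ctype by each dimension. A standalone sketch of the same mechanism (plain ctypes and numpy only, no lite calls; the shape and dtype are illustrative) shows how such a type later turns a raw address into a shared numpy view, which is the trick get_data_by_share relies on further down.

# illustration only: what the cached ctypes array type buys you for a
# float32 tensor of shape (2, 3); equivalent to _ctype_ndarray(c_float, [2, 3])
import ctypes
import numpy as np

np_array_type = ctypes.c_float * 3 * 2      # nested array type, shape (2, 3)

storage = np_array_type()                   # stand-in for memory owned by the C side
addr = ctypes.addressof(storage)

view = np_array_type.from_address(addr)     # reinterpret the address, no copy
np_view = np.ctypeslib.as_array(view)       # numpy array sharing that memory

np_view[:] = 1.0                            # writes land in `storage` directly
assert storage[0][0] == 1.0
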

    def copy_from(self, src_tensor):
        """
        copy memory from the src_tensor

@@ -447,15 +453,11 @@ class LiteTensor(object):
        return the numpy array. Be careful: the data in the numpy array is only
        valid until the tensor memory is written again, for example when the
        LiteNetwork runs forward the next time.
        """
        assert self.is_continue, "get_data_by_share can only apply in continue tensor."
        assert (
            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
        ), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor."

        memory = self.get_ctypes_memory()
        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
        pnt = cast(memory, POINTER(c_type))
        return np.ctypeslib.as_array(pnt, self._layout.shapes)
        buffer = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
        buffer = self.np_array_type.from_address(buffer.value)
        return np.ctypeslib.as_array(buffer)
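Both code paths above hand back a numpy view over the tensor's own storage rather than a copy, so the caveat in the docstring matters in practice. A hedged usage sketch follows; network, get_io_tensor, output_name, forward and wait stand in for the usual LiteNetwork flow and are assumptions, not part of this diff.

# hypothetical flow around a LiteNetwork `network` that has already run once
import numpy as np

out = network.get_io_tensor(output_name)   # assumed LiteNetwork API
shared = out.get_data_by_share()           # zero-copy view, valid right now
snapshot = np.array(shared)                # copy it if you need it later

network.forward()                          # the next run rewrites the memory
network.wait()                             # that `shared` still points at
# `shared` now reflects the new output; use `snapshot` for the previous one.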

    def to_numpy(self):
        """