From d52ba79d8906a1a758051f3c3189c7440ecc1adc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 13 Apr 2022 18:25:08 +0800 Subject: [PATCH] fix(lite): support set data by copy on device tensor GitOrigin-RevId: 88b7f73d364d63bb3cad95ec063f283048895f94 --- lite/pylite/megenginelite/tensor.py | 32 +++++++++++++++----------------- lite/pylite/test/test_tensor.py | 24 ++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/lite/pylite/megenginelite/tensor.py b/lite/pylite/megenginelite/tensor.py index 897ac44b..4a1b7c78 100644 --- a/lite/pylite/megenginelite/tensor.py +++ b/lite/pylite/megenginelite/tensor.py @@ -407,7 +407,7 @@ class LiteTensor(object): def set_data_by_copy(self, data, data_length=0, layout=None): """ - copy the data to the tensor + copy the data to the tensor, the memory of the tensor must be continue param data: the data to copy to tensor, it should be list, numpy.ndarraya or ctypes with length """ @@ -415,37 +415,34 @@ class LiteTensor(object): self.layout = layout assert self.is_continue, "set_data_by_copy can only apply in continue tensor." - assert ( - self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU - ), "set_data_by_copy can only apply in cpu tensor or pinned tensor." c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)] - tensor_memory = c_void_p() + cpu_tensor = LiteTensor(self._layout) + tensor_length = self.nbytes if type(data) == list: length = len(data) - self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory)) - tensor_length = self.nbytes assert ( length * sizeof(c_type) <= tensor_length ), "the length of input data to set to the tensor is too large." - arr = (c_type * length)(*data) - memmove(tensor_memory, arr, sizeof(c_type) * length) + cdata = (c_type * length)(*data) + self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, tensor_length) + self.copy_from(cpu_tensor) elif type(data) == np.ndarray: - if self.nbytes != data.nbytes: - self.layout = LiteLayout(data.shape, data.dtype) - arr = data.ctypes.data_as(POINTER(c_type)) - self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory)) - assert self.nbytes == data.nbytes - memmove(tensor_memory, arr, self.nbytes) + self.layout = LiteLayout(data.shape, data.dtype) + cpu_tensor.layout = LiteLayout(data.shape, data.dtype) + cdata = data.ctypes.data_as(POINTER(c_type)) + self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, self.nbytes) + self.copy_from(cpu_tensor) + else: assert ( data_length == self.nbytes or layout is not None ), "when input data is ctypes, the length of input data or layout must set" - self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory)) - memmove(tensor_memory, data, data_length) + self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, data, tensor_length) + self.copy_from(cpu_tensor) def get_data_by_share(self): """ @@ -454,6 +451,7 @@ class LiteTensor(object): the tensor memory is write again, such as LiteNetwok forward next time. """ + self.update() buffer = c_void_p() self._api.LITE_get_tensor_memory(self._tensor, byref(buffer)) buffer = self.np_array_type.from_address(buffer.value) diff --git a/lite/pylite/test/test_tensor.py b/lite/pylite/test/test_tensor.py index 86af29c3..281f670f 100644 --- a/lite/pylite/test/test_tensor.py +++ b/lite/pylite/test/test_tensor.py @@ -323,3 +323,27 @@ def test_tensor_get_memory_by_share(): tensor.set_data_by_copy(arr) assert test_data[1][18] == 5 assert test_data[3][7] == 345 + + +@require_cuda +def test_tensor_set_data_device(): + layout = LiteLayout([2, 16], "int8") + tensor = LiteTensor(layout, device_type=LiteDeviceType.LITE_CUDA) + assert tensor.nbytes == 2 * 16 + + data = [i for i in range(32)] + tensor.set_data_by_copy(data) + real_data = tensor.to_numpy() + for i in range(32): + assert real_data[i // 16][i % 16] == i + + arr = np.ones([2, 16], "int8") + tensor.set_data_by_copy(arr) + real_data = tensor.to_numpy() + for i in range(32): + assert real_data[i // 16][i % 16] == 1 + + tensor.set_data_by_copy(list(range(32))) + real_data = tensor.to_numpy() + for i in range(32): + assert real_data[i // 16][i % 16] == i