and set_data_by_copy. pylite network input is not
correct when input np is not continuous
GitOrigin-RevId: 1bdeae970a
HuaHua404-patch-4
@@ -423,6 +423,9 @@ class LiteTensor(object): | |||||
numpy.ndarray or ctypes data | numpy.ndarray or ctypes data | ||||
""" | """ | ||||
if isinstance(data, np.ndarray): | if isinstance(data, np.ndarray): | ||||
assert data.flags[ | |||||
"C_CONTIGUOUS" | |||||
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_share" | |||||
assert ( | assert ( | ||||
self.is_continue | self.is_continue | ||||
), "set_data_by_share can only apply in continue tensor." | ), "set_data_by_share can only apply in continue tensor." | ||||
@@ -474,6 +477,9 @@ class LiteTensor(object): | |||||
self.copy_from(cpu_tensor) | self.copy_from(cpu_tensor) | ||||
elif type(data) == np.ndarray: | elif type(data) == np.ndarray: | ||||
assert data.flags[ | |||||
"C_CONTIGUOUS" | |||||
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_copy" | |||||
self.layout = LiteLayout(data.shape, data.dtype) | self.layout = LiteLayout(data.shape, data.dtype) | ||||
cpu_tensor.layout = LiteLayout(data.shape, data.dtype) | cpu_tensor.layout = LiteLayout(data.shape, data.dtype) | ||||
cdata = data.ctypes.data_as(POINTER(c_type)) | cdata = data.ctypes.data_as(POINTER(c_type)) | ||||
@@ -4,6 +4,7 @@ import functools | |||||
import numpy as np | import numpy as np | ||||
import pytest | |||||
from megenginelite import * | from megenginelite import * | ||||
@@ -89,6 +90,32 @@ def test_tensor_set_data(): | |||||
assert real_data[1][3] == 20 | assert real_data[1][3] == 20 | ||||
def test_set_data_by_copy_not_continuous(): | |||||
layout = LiteLayout() | |||||
tensor = LiteTensor(layout) | |||||
arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0) | |||||
with pytest.raises(AssertionError): | |||||
tensor.set_data_by_copy(arr) | |||||
arr = np.ascontiguousarray(arr) | |||||
tensor.set_data_by_copy(arr) | |||||
def test_set_data_by_share_not_continuous(): | |||||
layout = LiteLayout([2, 3], "int8") | |||||
tensor = LiteTensor(layout) | |||||
arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0) | |||||
with pytest.raises(AssertionError): | |||||
tensor.set_data_by_share(arr, 2 * 3) | |||||
arr = np.ascontiguousarray(arr) | |||||
tensor.set_data_by_share(arr.ctypes.data, 2 * 3) | |||||
def test_fill_zero(): | def test_fill_zero(): | ||||
layout = LiteLayout([4, 8], "int16") | layout = LiteLayout([4, 8], "int16") | ||||
tensor1 = LiteTensor(layout) | tensor1 = LiteTensor(layout) | ||||