GitOrigin-RevId: 0ddbb75e82
tags/v1.8.0
@@ -8,6 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .base import * | from .base import * | ||||
from .base import version as __version__ | |||||
from .global_setting import * | from .global_setting import * | ||||
from .network import * | from .network import * | ||||
from .struct import * | from .struct import * | ||||
@@ -69,7 +69,9 @@ class LiteOptions(Structure): | |||||
"const_shape": bool(self.const_shape), | "const_shape": bool(self.const_shape), | ||||
"force_dynamic_alloc": bool(self.force_dynamic_alloc), | "force_dynamic_alloc": bool(self.force_dynamic_alloc), | ||||
"force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc), | "force_output_dynamic_alloc": bool(self.force_output_dynamic_alloc), | ||||
"force_output_nocopy": bool(self.force_output_nocopy), | |||||
"force_output_use_user_specified_memory": bool( | |||||
self.force_output_use_user_specified_memory | |||||
), | |||||
"no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change), | "no_profiling_on_shape_change": bool(self.no_profiling_on_shape_change), | ||||
"jit_level": self.jit_level, | "jit_level": self.jit_level, | ||||
"comp_node_seq_record_level": self.comp_node_seq_record_level, | "comp_node_seq_record_level": self.comp_node_seq_record_level, | ||||
@@ -99,7 +101,7 @@ class LiteConfig(Structure): | |||||
("device_id", c_int), | ("device_id", c_int), | ||||
("device_type", c_int), | ("device_type", c_int), | ||||
("backend", c_int), | ("backend", c_int), | ||||
("bare_model_cryption_name", c_char_p), | |||||
("_bare_model_cryption_name", c_char_p), | |||||
("options", LiteOptions), | ("options", LiteOptions), | ||||
] | ] | ||||
@@ -110,18 +112,30 @@ class LiteConfig(Structure): | |||||
else: | else: | ||||
self.options = LiteOptions() | self.options = LiteOptions() | ||||
self.bare_model_cryption_name = c_char_p(b"") | |||||
self._bare_model_cryption_name = c_char_p(b"") | |||||
self.use_loader_dynamic_param = 0 | self.use_loader_dynamic_param = 0 | ||||
self.has_compression = 0 | self.has_compression = 0 | ||||
self.backend = LiteBackend.LITE_DEFAULT | self.backend = LiteBackend.LITE_DEFAULT | ||||
@property | |||||
def bare_model_cryption_name(self): | |||||
return self._bare_model_cryption_name.decode("utf-8") | |||||
@bare_model_cryption_name.setter | |||||
def bare_model_cryption_name(self, name): | |||||
if isinstance(name, str): | |||||
self._bare_model_cryption_name = name.encode("utf-8") | |||||
else: | |||||
assert isinstance(name, bytes), "name should be str or bytes type." | |||||
self._bare_model_cryption_name = name | |||||
def __repr__(self): | def __repr__(self): | ||||
data = { | data = { | ||||
"has_compression": bool(self.has_compression), | "has_compression": bool(self.has_compression), | ||||
"device_id": LiteDeviceType(self.device_id), | "device_id": LiteDeviceType(self.device_id), | ||||
"device_type": LiteDeviceType(self.device_type), | "device_type": LiteDeviceType(self.device_type), | ||||
"backend": LiteBackend(self.backend), | "backend": LiteBackend(self.backend), | ||||
"bare_model_cryption_name": self.bare_model_cryption_name.decode("utf-8"), | |||||
"bare_model_cryption_name": self.bare_model_cryption_name, | |||||
"options": self.options, | "options": self.options, | ||||
} | } | ||||
return data.__repr__() | return data.__repr__() | ||||
@@ -149,7 +163,7 @@ class LiteIO(Structure): | |||||
""" | """ | ||||
_fields_ = [ | _fields_ = [ | ||||
("name", c_char_p), | |||||
("_name", c_char_p), | |||||
("is_host", c_int), | ("is_host", c_int), | ||||
("io_type", c_int), | ("io_type", c_int), | ||||
("config_layout", LiteLayout), | ("config_layout", LiteLayout), | ||||
@@ -159,9 +173,9 @@ class LiteIO(Structure): | |||||
self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None | self, name, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None | ||||
): | ): | ||||
if type(name) == str: | if type(name) == str: | ||||
self.name = c_char_p(name.encode("utf-8")) | |||||
self._name = c_char_p(name.encode("utf-8")) | |||||
else: | else: | ||||
self.name = c_char_p(name) | |||||
self._name = c_char_p(name) | |||||
if layout: | if layout: | ||||
self.config_layout = layout | self.config_layout = layout | ||||
@@ -171,6 +185,18 @@ class LiteIO(Structure): | |||||
self.is_host = is_host | self.is_host = is_host | ||||
self.io_type = io_type | self.io_type = io_type | ||||
@property | |||||
def name(self): | |||||
return self._name.decode("utf-8") | |||||
@name.setter | |||||
def name(self, name): | |||||
if isinstance(name, str): | |||||
self._name = name.encode("utf-8") | |||||
else: | |||||
assert isinstance(name, bytes), "name should be str or bytes type." | |||||
self._name = name | |||||
def __repr__(self): | def __repr__(self): | ||||
data = { | data = { | ||||
"name": self.name, | "name": self.name, | ||||
@@ -208,17 +234,45 @@ class LiteNetworkIO(object): | |||||
the input and output information for user to construct _LiteNetWorkIO | the input and output information for user to construct _LiteNetWorkIO | ||||
""" | """ | ||||
def __init__(self): | |||||
def __init__(self, inputs=None, outputs=None): | |||||
self.inputs = [] | self.inputs = [] | ||||
self.outputs = [] | self.outputs = [] | ||||
if inputs: | |||||
for i in inputs: | |||||
if isinstance(i, list): | |||||
self.inputs.append(LiteIO(*i)) | |||||
else: | |||||
assert isinstance( | |||||
i, LiteIO | |||||
), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO." | |||||
self.inputs.append(i) | |||||
if outputs: | |||||
for i in outputs: | |||||
if isinstance(i, list): | |||||
self.outputs.append(LiteIO(*i)) | |||||
else: | |||||
assert isinstance( | |||||
i, LiteIO | |||||
), "the param to construct LiteNetworkIO must be list of the LiteIO member or the LiteIO." | |||||
self.outputs.append(i) | |||||
def add_input( | |||||
self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None | |||||
): | |||||
if isinstance(obj, LiteIO): | |||||
self.inputs.append(obj) | |||||
else: | |||||
name = obj | |||||
self.add_input(LiteIO(name, is_host, io_type, layout)) | |||||
def add_input(self, input_io): | |||||
assert isinstance(input_io, LiteIO) | |||||
self.inputs.append(input_io) | |||||
def add_output(self, output_io): | |||||
assert isinstance(output_io, LiteIO) | |||||
self.outputs.append(output_io) | |||||
def add_output( | |||||
self, obj, is_host=True, io_type=LiteIOType.LITE_IO_VALUE, layout=None | |||||
): | |||||
if isinstance(obj, LiteIO): | |||||
self.outputs.append(obj) | |||||
else: | |||||
name = obj | |||||
self.add_output(LiteIO(name, is_host, io_type, layout)) | |||||
def _create_network_io(self): | def _create_network_io(self): | ||||
network_io = _LiteNetworkIO() | network_io = _LiteNetworkIO() | ||||
@@ -48,6 +48,15 @@ ctype_to_lite_dtypes = { | |||||
c_ushort: LiteDataType.LITE_UINT16, | c_ushort: LiteDataType.LITE_UINT16, | ||||
} | } | ||||
_lite_dtypes_to_ctype = { | |||||
LiteDataType.LITE_INT: c_int, | |||||
LiteDataType.LITE_FLOAT: c_float, | |||||
LiteDataType.LITE_UINT8: c_ubyte, | |||||
LiteDataType.LITE_INT8: c_byte, | |||||
LiteDataType.LITE_INT16: c_short, | |||||
LiteDataType.LITE_UINT16: c_ushort, | |||||
} | |||||
class LiteLayout(Structure): | class LiteLayout(Structure): | ||||
""" | """ | ||||
@@ -55,7 +64,7 @@ class LiteLayout(Structure): | |||||
""" | """ | ||||
_fields_ = [ | _fields_ = [ | ||||
("shapes", c_size_t * MAX_DIM), | |||||
("_shapes", c_size_t * MAX_DIM), | |||||
("ndim", c_size_t), | ("ndim", c_size_t), | ||||
("data_type", c_int), | ("data_type", c_int), | ||||
] | ] | ||||
@@ -64,10 +73,10 @@ class LiteLayout(Structure): | |||||
if shape: | if shape: | ||||
shape = list(shape) | shape = list(shape) | ||||
assert len(shape) <= MAX_DIM, "Layout max dim is 7." | assert len(shape) <= MAX_DIM, "Layout max dim is 7." | ||||
self.shapes = (c_size_t * MAX_DIM)(*shape) | |||||
self._shapes = (c_size_t * MAX_DIM)(*shape) | |||||
self.ndim = len(shape) | self.ndim = len(shape) | ||||
else: | else: | ||||
self.shapes = (c_size_t * MAX_DIM)() | |||||
self._shapes = (c_size_t * MAX_DIM)() | |||||
self.ndim = 0 | self.ndim = 0 | ||||
if not dtype: | if not dtype: | ||||
self.data_type = LiteDataType.LITE_FLOAT | self.data_type = LiteDataType.LITE_FLOAT | ||||
@@ -83,9 +92,24 @@ class LiteLayout(Structure): | |||||
else: | else: | ||||
raise RuntimeError("unkonw data type") | raise RuntimeError("unkonw data type") | ||||
@property | |||||
def dtype(self): | |||||
return _lite_type_to_nptypes[LiteDataType(self.data_type)] | |||||
@property | |||||
def shapes(self): | |||||
return list(self._shapes)[0 : self.ndim] | |||||
@shapes.setter | |||||
def shapes(self, shape): | |||||
shape = list(shape) | |||||
assert len(shape) <= MAX_DIM, "Layout max dim is 7." | |||||
self._shapes = (c_size_t * MAX_DIM)(*shape) | |||||
self.ndim = len(shape) | |||||
def __repr__(self): | def __repr__(self): | ||||
data = { | data = { | ||||
"shapes": list(self.shapes)[0 : self.ndim], | |||||
"shapes": self.shapes, | |||||
"ndim": self.ndim, | "ndim": self.ndim, | ||||
"data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)], | "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)], | ||||
} | } | ||||
@@ -177,15 +201,20 @@ class LiteTensor(object): | |||||
device_type=LiteDeviceType.LITE_CPU, | device_type=LiteDeviceType.LITE_CPU, | ||||
device_id=0, | device_id=0, | ||||
is_pinned_host=False, | is_pinned_host=False, | ||||
shapes=None, | |||||
dtype=None, | |||||
): | ): | ||||
""" | """ | ||||
create a Tensor with layout, device, is_pinned_host param | |||||
create a Tensor with layout, device, is_pinned_host or shapes, dtype, | |||||
device_type, device_id, is_pinned_host param | |||||
""" | """ | ||||
self._tensor = _Ctensor() | self._tensor = _Ctensor() | ||||
if layout: | |||||
self._layout = LiteLayout() | |||||
if layout is not None: | |||||
self._layout = layout | self._layout = layout | ||||
else: | |||||
self._layout = LiteLayout() | |||||
elif shapes is not None: | |||||
shapes = list(shapes) | |||||
self._layout = LiteLayout(shapes, dtype) | |||||
self._device_type = device_type | self._device_type = device_type | ||||
self._device_id = device_id | self._device_id = device_id | ||||
self._is_pinned_host = is_pinned_host | self._is_pinned_host = is_pinned_host | ||||
@@ -222,9 +251,12 @@ class LiteTensor(object): | |||||
@layout.setter | @layout.setter | ||||
def layout(self, layout): | def layout(self, layout): | ||||
assert isinstance(layout, LiteLayout) | |||||
self._layout = layout | |||||
self._api.LITE_set_tensor_layout(self._tensor, layout) | |||||
if isinstance(layout, LiteLayout): | |||||
self._layout = layout | |||||
elif isinstance(layout, list): | |||||
self._layout.shapes = layout | |||||
self._api.LITE_set_tensor_layout(self._tensor, self._layout) | |||||
@property | @property | ||||
def is_pinned_host(self): | def is_pinned_host(self): | ||||
@@ -270,7 +302,6 @@ class LiteTensor(object): | |||||
""" | """ | ||||
get the length of the meomry in byte | get the length of the meomry in byte | ||||
""" | """ | ||||
self.update() | |||||
length = c_size_t() | length = c_size_t() | ||||
self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length)) | self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length)) | ||||
return length.value | return length.value | ||||
@@ -336,7 +367,6 @@ class LiteTensor(object): | |||||
""" | """ | ||||
get the memory of the tensor, return c_void_p of the tensor memory | get the memory of the tensor, return c_void_p of the tensor memory | ||||
""" | """ | ||||
self.update() | |||||
mem = c_void_p() | mem = c_void_p() | ||||
self._api.LITE_get_tensor_memory(self._tensor, byref(mem)) | self._api.LITE_get_tensor_memory(self._tensor, byref(mem)) | ||||
return mem | return mem | ||||
@@ -347,7 +377,6 @@ class LiteTensor(object): | |||||
param data: the data will shared to the tensor, it should be a | param data: the data will shared to the tensor, it should be a | ||||
numpy.ndarray or ctypes data | numpy.ndarray or ctypes data | ||||
""" | """ | ||||
self.update() | |||||
if isinstance(data, np.ndarray): | if isinstance(data, np.ndarray): | ||||
assert ( | assert ( | ||||
self.is_continue | self.is_continue | ||||
@@ -356,8 +385,7 @@ class LiteTensor(object): | |||||
self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU | self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU | ||||
), "set_data_by_share can only apply in cpu tensor or pinned tensor." | ), "set_data_by_share can only apply in cpu tensor or pinned tensor." | ||||
np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)] | |||||
c_type = np.ctypeslib.as_ctypes_type(np_type) | |||||
c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)] | |||||
if self.nbytes != data.nbytes: | if self.nbytes != data.nbytes: | ||||
self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type]) | self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type]) | ||||
@@ -377,7 +405,6 @@ class LiteTensor(object): | |||||
param data: the data to copy to tensor, it should be list, | param data: the data to copy to tensor, it should be list, | ||||
numpy.ndarraya or ctypes with length | numpy.ndarraya or ctypes with length | ||||
""" | """ | ||||
self.update() | |||||
if layout is not None: | if layout is not None: | ||||
self.layout = layout | self.layout = layout | ||||
@@ -386,8 +413,7 @@ class LiteTensor(object): | |||||
self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU | self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU | ||||
), "set_data_by_copy can only apply in cpu tensor or pinned tensor." | ), "set_data_by_copy can only apply in cpu tensor or pinned tensor." | ||||
np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)] | |||||
c_type = np.ctypeslib.as_ctypes_type(np_type) | |||||
c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)] | |||||
tensor_memory = c_void_p() | tensor_memory = c_void_p() | ||||
@@ -415,6 +441,22 @@ class LiteTensor(object): | |||||
self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory)) | self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory)) | ||||
memmove(tensor_memory, data, data_length) | memmove(tensor_memory, data, data_length) | ||||
def get_data_by_share(self): | |||||
""" | |||||
get the data in the tensor, add share the data with a new numpy, and | |||||
return the numpy arrray, be careful, the data in numpy is valid before | |||||
the tensor memory is write again, such as LiteNetwok forward next time. | |||||
""" | |||||
assert self.is_continue, "get_data_by_share can only apply in continue tensor." | |||||
assert ( | |||||
self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU | |||||
), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor." | |||||
memory = self.get_ctypes_memory() | |||||
c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)] | |||||
pnt = cast(memory, POINTER(c_type)) | |||||
return np.ctypeslib.as_array(pnt, self._layout.shapes) | |||||
def to_numpy(self): | def to_numpy(self): | ||||
""" | """ | ||||
get the buffer of the tensor | get the buffer of the tensor | ||||
@@ -475,3 +517,13 @@ def LiteTensorConcat( | |||||
) | ) | ||||
result_tensor.update() | result_tensor.update() | ||||
return result_tensor | return result_tensor | ||||
def lite_dtype_2_numpy(dtype): | |||||
""" | |||||
convert lite dtype to corresponding numpy dtype | |||||
""" | |||||
assert isinstance( | |||||
dtype, LiteDataType | |||||
), "input must be LiteDataType when using lite_dtype_2_numpy." | |||||
return _lite_type_to_nptypes[dtype] |
@@ -21,6 +21,12 @@ def test_version(): | |||||
print("Lite verson: {}".format(version)) | print("Lite verson: {}".format(version)) | ||||
def test_config(): | |||||
config = LiteConfig() | |||||
config.bare_model_cryption_name = "nothing" | |||||
print(config) | |||||
def test_network_io(): | def test_network_io(): | ||||
input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE) | input_io1 = LiteIO("data1", is_host=False, io_type=LiteIOType.LITE_IO_VALUE) | ||||
input_io2 = LiteIO( | input_io2 = LiteIO( | ||||
@@ -32,6 +38,7 @@ def test_network_io(): | |||||
io = LiteNetworkIO() | io = LiteNetworkIO() | ||||
io.add_input(input_io1) | io.add_input(input_io1) | ||||
io.add_input(input_io2) | io.add_input(input_io2) | ||||
io.add_input("data3", False) | |||||
output_io1 = LiteIO("out1", is_host=False) | output_io1 = LiteIO("out1", is_host=False) | ||||
output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000])) | output_io2 = LiteIO("out2", is_host=True, layout=LiteLayout([1, 1000])) | ||||
@@ -39,7 +46,7 @@ def test_network_io(): | |||||
io.add_output(output_io1) | io.add_output(output_io1) | ||||
io.add_output(output_io2) | io.add_output(output_io2) | ||||
assert len(io.inputs) == 2 | |||||
assert len(io.inputs) == 3 | |||||
assert len(io.outputs) == 2 | assert len(io.outputs) == 2 | ||||
assert io.inputs[0] == input_io1 | assert io.inputs[0] == input_io1 | ||||
@@ -47,9 +54,25 @@ def test_network_io(): | |||||
c_io = io._create_network_io() | c_io = io._create_network_io() | ||||
assert c_io.input_size == 2 | |||||
assert c_io.input_size == 3 | |||||
assert c_io.output_size == 2 | assert c_io.output_size == 2 | ||||
ins = [["data1", True], ["data2", False, LiteIOType.LITE_IO_SHAPE]] | |||||
outs = [["out1", True], ["out2", False, LiteIOType.LITE_IO_VALUE]] | |||||
io2 = LiteNetworkIO(ins, outs) | |||||
assert len(io2.inputs) == 2 | |||||
assert len(io2.outputs) == 2 | |||||
io3 = LiteNetworkIO([input_io1, input_io2], [output_io1, output_io2]) | |||||
assert len(io3.inputs) == 2 | |||||
assert len(io3.outputs) == 2 | |||||
test_io = LiteIO("test") | |||||
assert test_io.name == "test" | |||||
test_io.name = "test2" | |||||
assert test_io.name == "test2" | |||||
class TestShuffleNet(unittest.TestCase): | class TestShuffleNet(unittest.TestCase): | ||||
source_dir = os.getenv("LITE_TEST_RESOURCE") | source_dir = os.getenv("LITE_TEST_RESOURCE") | ||||
@@ -319,9 +342,9 @@ class TestNetwork(TestShuffleNet): | |||||
data = ios[key].to_numpy().flatten() | data = ios[key].to_numpy().flatten() | ||||
input_data = self.input_data.flatten() | input_data = self.input_data.flatten() | ||||
assert data.size == input_data.size | assert data.size == input_data.size | ||||
assert io.name.decode("utf-8") == "data" | |||||
assert io.name == "data" | |||||
for i in range(data.size): | for i in range(data.size): | ||||
assert data[i] == input_data[i] | |||||
assert abs(data[i] - input_data[i]) < 1e-5 | |||||
return 0 | return 0 | ||||
network.set_start_callback(start_callback) | network.set_start_callback(start_callback) | ||||
@@ -343,7 +366,7 @@ class TestNetwork(TestShuffleNet): | |||||
output_data = self.correct_data.flatten() | output_data = self.correct_data.flatten() | ||||
assert data.size == output_data.size | assert data.size == output_data.size | ||||
for i in range(data.size): | for i in range(data.size): | ||||
assert data[i] == output_data[i] | |||||
assert abs(data[i] - output_data[i]) < 1e-5 | |||||
return 0 | return 0 | ||||
network.set_finish_callback(finish_callback) | network.set_finish_callback(finish_callback) | ||||
@@ -404,3 +427,27 @@ class TestNetwork(TestShuffleNet): | |||||
binary_equal_between_batch=True, | binary_equal_between_batch=True, | ||||
) | ) | ||||
self.do_forward(network) | self.do_forward(network) | ||||
def test_device_tensor_no_copy(self): | |||||
# construct LiteOption | |||||
net_config = LiteConfig() | |||||
net_config.options.force_output_use_user_specified_memory = True | |||||
network = LiteNetwork(config=net_config) | |||||
network.load(self.model_path) | |||||
input_tensor = network.get_io_tensor("data") | |||||
# fill input_data with device data | |||||
input_tensor.set_data_by_share(self.input_data) | |||||
output_tensor = network.get_io_tensor(network.get_output_name(0)) | |||||
out_array = np.zeros(output_tensor.layout.shapes, output_tensor.layout.dtype) | |||||
output_tensor.set_data_by_share(out_array) | |||||
# inference | |||||
for i in range(2): | |||||
network.forward() | |||||
network.wait() | |||||
self.check_correct(out_array) |
@@ -54,6 +54,16 @@ def test_tensor_make(): | |||||
tensor = LiteTensor(layout, device_id=1) | tensor = LiteTensor(layout, device_id=1) | ||||
assert tensor.device_id == 1 | assert tensor.device_id == 1 | ||||
tensor.layout = [8, 14] | |||||
assert tensor.layout.shapes[0] == 8 | |||||
assert tensor.layout.shapes[1] == 14 | |||||
assert tensor.layout.data_type == LiteDataType.LITE_FLOAT | |||||
tensor_new = LiteTensor(shapes=[1, 3, 224], dtype=np.int8) | |||||
assert tensor_new.layout.shapes[1] == 3 | |||||
assert tensor_new.layout.shapes[2] == 224 | |||||
assert tensor_new.layout.data_type == LiteDataType.LITE_INT8 | |||||
def test_tensor_set_data(): | def test_tensor_set_data(): | ||||
layout = LiteLayout([2, 16], "int8") | layout = LiteLayout([2, 16], "int8") | ||||
@@ -292,3 +302,24 @@ def test_tensor_concat(): | |||||
for i in range(128): | for i in range(128): | ||||
index = j * 128 + i | index = j * 128 + i | ||||
assert real_data[index // 32][index % 32] == j | assert real_data[index // 32][index % 32] == j | ||||
def test_tensor_get_memory_by_share(): | |||||
layout = LiteLayout([4, 32], "int16") | |||||
tensor = LiteTensor(layout) | |||||
assert tensor.nbytes == 4 * 32 * 2 | |||||
arr = np.ones([4, 32], "int16") | |||||
for i in range(128): | |||||
arr[i // 32][i % 32] = i | |||||
tensor.set_data_by_copy(arr) | |||||
test_data = tensor.get_data_by_share() | |||||
real_data = tensor.to_numpy() | |||||
for i in range(128): | |||||
assert real_data[i // 32][i % 32] == test_data[i // 32][i % 32] | |||||
arr[1][18] = 5 | |||||
arr[3][7] = 345 | |||||
tensor.set_data_by_copy(arr) | |||||
assert test_data[1][18] == 5 | |||||
assert test_data[3][7] == 345 |