Browse Source

feat(lite): add assert log for set_data_by_share

and set_data_by_copy. pylite network input is not
correct when input np is not continuous

GitOrigin-RevId: 1bdeae970a
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent
commit
10a0349eca
2 changed files with 33 additions and 0 deletions
  1. +6
    -0
      lite/pylite/megenginelite/tensor.py
  2. +27
    -0
      lite/pylite/test/test_tensor.py

+ 6
- 0
lite/pylite/megenginelite/tensor.py View File

@@ -423,6 +423,9 @@ class LiteTensor(object):
numpy.ndarray or ctypes data numpy.ndarray or ctypes data
""" """
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
assert data.flags[
"C_CONTIGUOUS"
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_share"
assert ( assert (
self.is_continue self.is_continue
), "set_data_by_share can only apply in continue tensor." ), "set_data_by_share can only apply in continue tensor."
@@ -474,6 +477,9 @@ class LiteTensor(object):
self.copy_from(cpu_tensor) self.copy_from(cpu_tensor)


elif type(data) == np.ndarray: elif type(data) == np.ndarray:
assert data.flags[
"C_CONTIGUOUS"
], "input numpy is not continuous, please call input = np.ascontiguousarray(input) before call set_data_by_copy"
self.layout = LiteLayout(data.shape, data.dtype) self.layout = LiteLayout(data.shape, data.dtype)
cpu_tensor.layout = LiteLayout(data.shape, data.dtype) cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
cdata = data.ctypes.data_as(POINTER(c_type)) cdata = data.ctypes.data_as(POINTER(c_type))


+ 27
- 0
lite/pylite/test/test_tensor.py View File

@@ -4,6 +4,7 @@ import functools


import numpy as np import numpy as np


import pytest
from megenginelite import * from megenginelite import *




@@ -89,6 +90,32 @@ def test_tensor_set_data():
assert real_data[1][3] == 20 assert real_data[1][3] == 20




def test_set_data_by_copy_not_continuous():
layout = LiteLayout()
tensor = LiteTensor(layout)

arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0)

with pytest.raises(AssertionError):
tensor.set_data_by_copy(arr)

arr = np.ascontiguousarray(arr)
tensor.set_data_by_copy(arr)


def test_set_data_by_share_not_continuous():
layout = LiteLayout([2, 3], "int8")
tensor = LiteTensor(layout)

arr = np.arange(6).reshape(2, 3).astype(np.uint8).transpose(1, 0)

with pytest.raises(AssertionError):
tensor.set_data_by_share(arr, 2 * 3)

arr = np.ascontiguousarray(arr)
tensor.set_data_by_share(arr.ctypes.data, 2 * 3)


def test_fill_zero(): def test_fill_zero():
layout = LiteLayout([4, 8], "int16") layout = LiteLayout([4, 8], "int16")
tensor1 = LiteTensor(layout) tensor1 = LiteTensor(layout)


Loading…
Cancel
Save