
tensor.py

# -*- coding: utf-8 -*-
from ctypes import *

import numpy as np

from .base import _Ctensor, _lib, _LiteCObjBase
from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure

MAX_DIM = 7

_lite_type_to_nptypes = {
    LiteDataType.LITE_INT: np.int32,
    LiteDataType.LITE_FLOAT: np.float32,
    LiteDataType.LITE_UINT8: np.uint8,
    LiteDataType.LITE_INT8: np.int8,
    LiteDataType.LITE_INT16: np.int16,
    LiteDataType.LITE_UINT16: np.uint16,
    LiteDataType.LITE_HALF: np.float16,
}

_nptype_to_lite_type = {val: key for key, val in _lite_type_to_nptypes.items()}

_str_nptypes_to_lite_nptypes = {
    np.dtype("int32"): LiteDataType.LITE_INT,
    np.dtype("float32"): LiteDataType.LITE_FLOAT,
    np.dtype("uint8"): LiteDataType.LITE_UINT8,
    np.dtype("int8"): LiteDataType.LITE_INT8,
    np.dtype("int16"): LiteDataType.LITE_INT16,
    np.dtype("uint16"): LiteDataType.LITE_UINT16,
    np.dtype("float16"): LiteDataType.LITE_HALF,
}

ctype_to_lite_dtypes = {
    c_int: LiteDataType.LITE_INT,
    c_uint: LiteDataType.LITE_INT,
    c_float: LiteDataType.LITE_FLOAT,
    c_ubyte: LiteDataType.LITE_UINT8,
    c_byte: LiteDataType.LITE_INT8,
    c_short: LiteDataType.LITE_INT16,
    c_ushort: LiteDataType.LITE_UINT16,
}

_lite_dtypes_to_ctype = {
    LiteDataType.LITE_INT: c_int,
    LiteDataType.LITE_FLOAT: c_float,
    LiteDataType.LITE_UINT8: c_ubyte,
    LiteDataType.LITE_INT8: c_byte,
    LiteDataType.LITE_INT16: c_short,
    LiteDataType.LITE_UINT16: c_ushort,
}
class LiteLayout(Structure):
    """
    Description of a layout used in Lite. A Lite layout is fully defined by
    its shape and data type.

    Args:
        shape: the shape of the data.
        dtype: the data type of the data.

    Note:
        The shape may have at most 7 dims. The supported data types are
        defined in LiteDataType.

    Examples:

        .. code-block:: python

            import numpy as np

            layout = LiteLayout([1, 4, 8, 8], LiteDataType.LITE_FLOAT)
            assert layout.shapes == [1, 4, 8, 8]
            assert layout.dtype == np.float32
    """

    _fields_ = [
        ("_shapes", c_size_t * MAX_DIM),
        ("ndim", c_size_t),
        ("data_type", c_int),
    ]

    def __init__(self, shape=None, dtype=None):
        if shape:
            shape = list(shape)
            assert len(shape) <= MAX_DIM, "Layout max dim is 7."
            self._shapes = (c_size_t * MAX_DIM)(*shape)
            self.ndim = len(shape)
        else:
            self._shapes = (c_size_t * MAX_DIM)()
            self.ndim = 0
        if not dtype:
            self.data_type = LiteDataType.LITE_FLOAT
        elif isinstance(dtype, LiteDataType):
            self.data_type = dtype
        elif type(dtype) == str:
            self.data_type = _str_nptypes_to_lite_nptypes[np.dtype(dtype)]
        elif isinstance(dtype, np.dtype):
            ctype = np.ctypeslib.as_ctypes_type(dtype)
            self.data_type = ctype_to_lite_dtypes[ctype]
        elif isinstance(dtype, type):
            self.data_type = _nptype_to_lite_type[dtype]
        else:
            raise RuntimeError("unknown data type")

    @property
    def dtype(self):
        return _lite_type_to_nptypes[LiteDataType(self.data_type)]

    @property
    def shapes(self):
        return list(self._shapes)[0 : self.ndim]

    @shapes.setter
    def shapes(self, shape):
        shape = list(shape)
        assert len(shape) <= MAX_DIM, "Layout max dim is 7."
        self._shapes = (c_size_t * MAX_DIM)(*shape)
        self.ndim = len(shape)

    def __repr__(self):
        data = {
            "shapes": self.shapes,
            "ndim": self.ndim,
            "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
        }
        return data.__repr__()
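
# Usage sketch (illustrative only, not part of the module): the dtype
# argument of LiteLayout accepts a LiteDataType, a dtype string, a
# numpy.dtype, or a numpy scalar type, matching the branches in __init__
# above. The numpy.dtype path assumes a platform where c_int is 32-bit.
#
#   import numpy as np
#   a = LiteLayout([2, 3], LiteDataType.LITE_INT)
#   b = LiteLayout([2, 3], "int32")
#   c = LiteLayout([2, 3], np.dtype("int32"))
#   d = LiteLayout([2, 3], np.int32)
#   assert a.data_type == b.data_type == c.data_type == d.data_type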
class _LiteTensorDesc(Structure):
    """
    Wrapper of the MegEngine Tensor.

    Args:
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory, which is used to optimize H2D or D2H memory copies. If
            the device or layout is not set when copying from another
            device (e.g. CUDA) tensor, this tensor is automatically set to
            be a pinned tensor.
        layout(LiteLayout): layout of this tensor.
        device_type: type of the device.
        device_id: id of the device.
    """

    _fields_ = [
        ("is_pinned_host", c_int),
        ("layout", LiteLayout),
        ("device_type", c_int),
        ("device_id", c_int),
    ]

    def __init__(self):
        self.layout = LiteLayout()
        self.device_type = LiteDeviceType.LITE_CPU
        self.is_pinned_host = False
        self.device_id = 0

    def __repr__(self):
        # ctypes exposes c_int fields as plain ints, so no .value here
        data = {
            "is_pinned_host": self.is_pinned_host,
            "layout": self.layout,
            "device_type": LiteDeviceType(self.device_type),
            "device_id": self.device_id,
        }
        return data.__repr__()
class _TensorAPI(_LiteCObjBase):
    """
    Get the tensor API from the lib.
    """

    _api_ = [
        ("LITE_make_tensor", [_LiteTensorDesc, POINTER(_Ctensor)]),
        ("LITE_set_tensor_layout", [_Ctensor, LiteLayout]),
        ("LITE_reset_tensor_memory", [_Ctensor, c_void_p, c_size_t]),
        ("LITE_reset_tensor", [_Ctensor, LiteLayout, c_void_p]),
        ("LITE_tensor_reshape", [_Ctensor, POINTER(c_int), c_int]),
        (
            "LITE_tensor_slice",
            [
                _Ctensor,
                POINTER(c_size_t),
                POINTER(c_size_t),
                POINTER(c_size_t),
                c_size_t,
                POINTER(_Ctensor),
            ],
        ),
        (
            "LITE_tensor_concat",
            [POINTER(_Ctensor), c_int, c_int, c_int, c_int, POINTER(_Ctensor)],
        ),
        ("LITE_tensor_fill_zero", [_Ctensor]),
        ("LITE_tensor_copy", [_Ctensor, _Ctensor]),
        ("LITE_tensor_share_memory_with", [_Ctensor, _Ctensor]),
        ("LITE_get_tensor_memory", [_Ctensor, POINTER(c_void_p)]),
        ("LITE_get_tensor_total_size_in_byte", [_Ctensor, POINTER(c_size_t)]),
        ("LITE_get_tensor_layout", [_Ctensor, POINTER(LiteLayout)]),
        ("LITE_get_tensor_device_type", [_Ctensor, POINTER(c_int)]),
        ("LITE_get_tensor_device_id", [_Ctensor, POINTER(c_int)]),
        ("LITE_destroy_tensor", [_Ctensor]),
        ("LITE_is_pinned_host", [_Ctensor, POINTER(c_int)]),
        # used by LiteTensor.is_continue below; missing from the original
        # listing
        ("LITE_is_memory_continue", [_Ctensor, POINTER(c_int)]),
    ]
class LiteTensor(object):
    """
    Description of a block of data with the necessary information.

    Args:
        layout: layout of the Tensor.
        device_type: device type of the Tensor.
        device_id: device id of the Tensor.
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory, which is used to optimize H2D or D2H memory copies. If
            the device or layout is not set when copying from another
            device (e.g. CUDA) tensor, this tensor is automatically set to
            be a pinned tensor.
        shapes: the shape of the data.
        dtype: the data type of the data.

    Note:
        The shape may have at most 7 dims. The supported data types are
        defined in LiteDataType.
    """

    _api = _TensorAPI()._lib

    def __init__(
        self,
        layout=None,
        device_type=LiteDeviceType.LITE_CPU,
        device_id=0,
        is_pinned_host=False,
        shapes=None,
        dtype=None,
    ):
        self._tensor = _Ctensor()
        self._layout = LiteLayout()
        if layout is not None:
            self._layout = layout
        elif shapes is not None:
            shapes = list(shapes)
            self._layout = LiteLayout(shapes, dtype)
        self._device_type = device_type
        self._device_id = device_id
        self._is_pinned_host = is_pinned_host

        tensor_desc = _LiteTensorDesc()
        tensor_desc.layout = self._layout
        tensor_desc.device_type = device_type
        tensor_desc.device_id = device_id
        tensor_desc.is_pinned_host = is_pinned_host
        self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
        self.update()

    def __del__(self):
        self._api.LITE_destroy_tensor(self._tensor)

    def fill_zero(self):
        """
        fill the buffer memory with zero
        """
        self._api.LITE_tensor_fill_zero(self._tensor)
        self.update()

    def share_memory_with(self, src_tensor):
        """
        share the same memory with the ``src_tensor``; the memory owned by
        self will be freed

        Args:
            src_tensor: the source tensor that will share memory with this
                tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_share_memory_with(self._tensor, src_tensor._tensor)
        self.update()
    @property
    def layout(self):
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        return self._layout

    @layout.setter
    def layout(self, layout):
        if isinstance(layout, LiteLayout):
            self._layout = layout
        elif isinstance(layout, list):
            self._layout.shapes = layout
        self._api.LITE_set_tensor_layout(self._tensor, self._layout)

    @property
    def is_pinned_host(self):
        """
        whether the tensor is a pinned tensor
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned.value
        return bool(self._is_pinned_host)

    @property
    def device_type(self):
        """
        get the device type of the tensor
        """
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type.value
        return LiteDeviceType(device_type.value)

    @property
    def device_id(self):
        """
        get the device id of the tensor
        """
        device_id = c_int()
        self._api.LITE_get_tensor_device_id(self._tensor, byref(device_id))
        self._device_id = device_id.value
        return device_id.value

    @property
    def is_continue(self):
        """
        whether the tensor memory is contiguous
        """
        is_continue = c_int()
        self._api.LITE_is_memory_continue(self._tensor, byref(is_continue))
        return bool(is_continue.value)

    @property
    def nbytes(self):
        """
        get the length of the memory in bytes
        """
        length = c_size_t()
        self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length))
        return length.value

    def update(self):
        """
        update the members from C; this is called automatically after slice,
        share, copy, etc.
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned.value
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type.value
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        c_types = _lite_dtypes_to_ctype[self._layout.data_type]
        self.np_array_type = np.ctypeslib._ctype_ndarray(
            c_types, list(self._layout.shapes)[0 : self._layout.ndim]
        )

    def copy_from(self, src_tensor):
        """
        copy memory from the src_tensor

        Args:
            src_tensor: source tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_copy(self._tensor, src_tensor._tensor)
        self.update()

    def reshape(self, shape):
        """
        reshape the tensor without changing the data.

        Args:
            shape: target shape
        """
        shape = list(shape)
        length = len(shape)
        c_shape = (c_int * length)(*shape)
        self._api.LITE_tensor_reshape(self._tensor, c_shape, length)
        self.update()
    def slice(self, start, end, step=None):
        """
        slice the tensor with the given start, end and step

        Args:
            start: slice begin index of each dim
            end: slice end index of each dim
            step: slice step of each dim
        """
        start = list(start)
        end = list(end)
        length = len(start)
        assert length == len(end), "slice with different length of start and end."
        if step:
            assert length == len(step), "slice with different length of start and step."
            step = list(step)
        else:
            step = [1 for i in range(length)]
        c_start = (c_size_t * length)(*start)
        c_end = (c_size_t * length)(*end)
        c_step = (c_size_t * length)(*step)
        slice_tensor = LiteTensor()
        self._api.LITE_tensor_slice(
            self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor)
        )
        slice_tensor.update()
        return slice_tensor
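
    # Usage sketch (illustrative only, not part of the module): slicing a
    # 2-D CPU tensor. start/end/step are per-dim index lists; with the
    # usual end-exclusive slicing semantics, this takes rows [0, 2) and
    # every second column in [2, 6), giving a [2, 2] result.
    #
    #   t = LiteTensor(shapes=[4, 8], dtype="float32")
    #   sub = t.slice([0, 2], [2, 6], [1, 2])
    #   assert sub.layout.shapes == [2, 2]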
    def get_ctypes_memory(self):
        """
        get the memory of the tensor, returned as a c_void_p pointing to the
        tensor memory
        """
        mem = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(mem))
        return mem

    def set_data_by_share(self, data, length=0, layout=None):
        """
        share the data to the tensor

        Args:
            data: the data to share to the tensor; it should be a
                numpy.ndarray or ctypes data
            length: length of the data in bytes; when nonzero it must match
                the tensor's nbytes
            layout: layout to set to the tensor before sharing
        """
        if isinstance(data, np.ndarray):
            assert data.flags[
                "C_CONTIGUOUS"
            ], "input numpy array is not contiguous, please call input = np.ascontiguousarray(input) before calling set_data_by_share"
            assert (
                self.is_continue
            ), "set_data_by_share can only be applied to a contiguous tensor."
            assert (
                self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
            ), "set_data_by_share can only be applied to a cpu tensor or pinned tensor."
            c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
            if self.nbytes != data.nbytes:
                self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type])
            self._shared_data = data
            data = data.ctypes.data_as(POINTER(c_type))
        if layout is not None:
            self.layout = layout
        else:
            assert length == 0 or length == self.nbytes, "the data length does not match."
        self._api.LITE_reset_tensor_memory(self._tensor, data, self.nbytes)
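
    # Usage sketch (illustrative only, not part of the module): zero-copy
    # input feeding. The tensor keeps a reference to the array in
    # self._shared_data, so the array must stay alive while the tensor
    # uses its memory.
    #
    #   import numpy as np
    #   arr = np.arange(12, dtype=np.float32).reshape(3, 4)
    #   t = LiteTensor(shapes=[3, 4], dtype="float32")
    #   t.set_data_by_share(arr)
    #   assert t.to_numpy().tolist() == arr.tolist()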
    def set_data_by_copy(self, data, data_length=0, layout=None):
        """
        copy the data to the tensor

        Args:
            data: the data to copy to the tensor; it should be a list, a
                numpy.ndarray, or ctypes data with a length
            data_length: length of the data in bytes
            layout: layout of the data
        """
        if layout is not None:
            self.layout = layout

        assert self.is_continue, "set_data_by_copy can only be applied to a contiguous tensor."

        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
        cpu_tensor = LiteTensor(self._layout)
        tensor_length = self.nbytes

        if type(data) == list:
            length = len(data)
            assert (
                length * sizeof(c_type) <= tensor_length
            ), "the length of the input data to set to the tensor is too large."
            cdata = (c_type * length)(*data)
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, tensor_length)
            self.copy_from(cpu_tensor)
        elif type(data) == np.ndarray:
            assert data.flags[
                "C_CONTIGUOUS"
            ], "input numpy array is not contiguous, please call input = np.ascontiguousarray(input) before calling set_data_by_copy"
            self.layout = LiteLayout(data.shape, data.dtype)
            cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
            cdata = data.ctypes.data_as(POINTER(c_type))
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, self.nbytes)
            self.copy_from(cpu_tensor)
        else:
            assert (
                data_length == self.nbytes or layout is not None
            ), "when the input data is ctypes, the length of the input data or the layout must be set"
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, data, tensor_length)
            self.copy_from(cpu_tensor)
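
    # Usage sketch (illustrative only, not part of the module):
    # set_data_by_copy stages the data in a temporary CPU tensor and then
    # copies it into this tensor, so unlike set_data_by_share it also works
    # for device tensors.
    #
    #   t = LiteTensor(shapes=[2, 2], dtype="int32")
    #   t.set_data_by_copy([1, 2, 3, 4])
    #   assert t.to_numpy().tolist() == [[1, 2], [3, 4]]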
    def get_data_by_share(self):
        """
        get the data in the tensor by sharing it with a new numpy array, and
        return that numpy array

        Note:
            Be careful: the data in the numpy array is only valid until the
            tensor memory is written again, e.g. when LiteNetwork runs
            forward next time.
        """
        self.update()
        buffer = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
        buffer = self.np_array_type.from_address(buffer.value)
        return np.ctypeslib.as_array(buffer)

    def to_numpy(self):
        """
        get the buffer of the tensor as a numpy array (copied)
        """
        self.update()
        if self.nbytes <= 0:
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            return np.array([], dtype=np_type)
        if self.is_continue and (
            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
        ):
            ptr = c_void_p()
            self._api.LITE_get_tensor_memory(self._tensor, byref(ptr))
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            shape = [self._layout.shapes[i] for i in range(self._layout.ndim)]
            np_arr = np.zeros(shape, np_type)
            if np_arr.nbytes:
                memmove(np_arr.ctypes.data_as(c_void_p), ptr, np_arr.nbytes)
            return np_arr
        else:
            # not a contiguous cpu/pinned tensor: stage through a cpu copy
            tmp_tensor = LiteTensor(self.layout)
            tmp_tensor.copy_from(self)
            return tmp_tensor.to_numpy()
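
    # Usage sketch (illustrative only, not part of the module): to_numpy
    # returns a copy, while get_data_by_share aliases the tensor memory, so
    # writes through the shared array are visible to the tensor.
    #
    #   t = LiteTensor(shapes=[2], dtype="float32")
    #   t.fill_zero()
    #   shared = t.get_data_by_share()
    #   shared[0] = 1.0
    #   assert t.to_numpy()[0] == 1.0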
    def __repr__(self):
        self.update()
        data = {
            "layout": self._layout,
            "device_type": LiteDeviceType(self._device_type),
            "device_id": int(self.device_id),
            "is_pinned_host": bool(self._is_pinned_host),
        }
        return data.__repr__()
def LiteTensorConcat(
    tensors, dim, device_type=LiteDeviceType.LITE_DEVICE_DEFAULT, device_id=-1
):
    """
    concat the tensors along the expected dim into one tensor

    Args:
        tensors: the tensors to concat
        dim: the dim along which to concat
        device_type: the device type of the result tensor
        device_id: the device id of the result tensor
    """
    api = _TensorAPI()._lib
    length = len(tensors)
    c_tensors = [t._tensor for t in tensors]
    c_tensors = (_Ctensor * length)(*c_tensors)
    result_tensor = LiteTensor()
    api.LITE_tensor_concat(
        cast(byref(c_tensors), POINTER(c_void_p)),
        length,
        dim,
        device_type,
        device_id,
        byref(result_tensor._tensor),
    )
    result_tensor.update()
    return result_tensor
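
# Usage sketch (illustrative only, not part of the module): concat two
# [2, 2] tensors along dim 0 into a [4, 2] tensor.
#
#   a = LiteTensor(shapes=[2, 2], dtype="float32")
#   b = LiteTensor(shapes=[2, 2], dtype="float32")
#   a.fill_zero()
#   b.fill_zero()
#   c = LiteTensorConcat([a, b], dim=0)
#   assert c.layout.shapes == [4, 2]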
def lite_dtype_2_numpy(dtype):
    """
    convert a lite dtype to the corresponding numpy dtype

    Args:
        dtype(LiteDataType): source dtype
    """
    assert isinstance(
        dtype, LiteDataType
    ), "input must be LiteDataType when using lite_dtype_2_numpy."
    return _lite_type_to_nptypes[dtype]
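
# Usage sketch (illustrative only, not part of the module):
#
#   assert lite_dtype_2_numpy(LiteDataType.LITE_FLOAT) == np.float32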