
tensor.py 19 kB

# -*- coding: utf-8 -*-
from ctypes import *

import numpy as np

from .base import _Ctensor, _lib, _LiteCObjBase
from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure

MAX_DIM = 7

_lite_type_to_nptypes = {
    LiteDataType.LITE_INT: np.int32,
    LiteDataType.LITE_FLOAT: np.float32,
    LiteDataType.LITE_UINT8: np.uint8,
    LiteDataType.LITE_INT8: np.int8,
    LiteDataType.LITE_INT16: np.int16,
    LiteDataType.LITE_UINT16: np.uint16,
    LiteDataType.LITE_HALF: np.float16,
}

_nptype_to_lite_type = {val: key for key, val in _lite_type_to_nptypes.items()}

_str_nptypes_to_lite_nptypes = {
    np.dtype("int32"): LiteDataType.LITE_INT,
    np.dtype("float32"): LiteDataType.LITE_FLOAT,
    np.dtype("uint8"): LiteDataType.LITE_UINT8,
    np.dtype("int8"): LiteDataType.LITE_INT8,
    np.dtype("int16"): LiteDataType.LITE_INT16,
    np.dtype("uint16"): LiteDataType.LITE_UINT16,
    np.dtype("float16"): LiteDataType.LITE_HALF,
}

ctype_to_lite_dtypes = {
    c_int: LiteDataType.LITE_INT,
    c_uint: LiteDataType.LITE_INT,
    c_float: LiteDataType.LITE_FLOAT,
    c_ubyte: LiteDataType.LITE_UINT8,
    c_byte: LiteDataType.LITE_INT8,
    c_short: LiteDataType.LITE_INT16,
    c_ushort: LiteDataType.LITE_UINT16,
}

_lite_dtypes_to_ctype = {
    LiteDataType.LITE_INT: c_int,
    LiteDataType.LITE_FLOAT: c_float,
    LiteDataType.LITE_UINT8: c_ubyte,
    LiteDataType.LITE_INT8: c_byte,
    LiteDataType.LITE_INT16: c_short,
    LiteDataType.LITE_UINT16: c_ushort,
}

class LiteLayout(Structure):
    """
    Description of the layout used in Lite. A Lite layout is fully defined
    by its shape and data type.

    Args:
        shape: the shape of the data.
        dtype: the data type.

    Note:
        The number of dims of shape should be less than 8. The supported
        data types are defined in LiteDataType.

    Examples:

        .. code-block:: python

            import numpy as np

            layout = LiteLayout([1, 4, 8, 8], LiteDataType.LITE_FLOAT)
            assert layout.shapes == [1, 4, 8, 8]
            assert layout.dtype == np.float32
    """

    _fields_ = [
        ("_shapes", c_size_t * MAX_DIM),
        ("ndim", c_size_t),
        ("data_type", c_int),
    ]

    def __init__(self, shape=None, dtype=None):
        if shape:
            shape = list(shape)
            assert len(shape) <= MAX_DIM, "Layout max dim is 7."
            self._shapes = (c_size_t * MAX_DIM)(*shape)
            self.ndim = len(shape)
        else:
            self._shapes = (c_size_t * MAX_DIM)()
            self.ndim = 0
        if not dtype:
            self.data_type = LiteDataType.LITE_FLOAT
        elif isinstance(dtype, LiteDataType):
            self.data_type = dtype
        elif type(dtype) == str:
            self.data_type = _str_nptypes_to_lite_nptypes[np.dtype(dtype)]
        elif isinstance(dtype, np.dtype):
            ctype = np.ctypeslib.as_ctypes_type(dtype)
            self.data_type = ctype_to_lite_dtypes[ctype]
        elif isinstance(dtype, type):
            self.data_type = _nptype_to_lite_type[dtype]
        else:
            raise RuntimeError("unknown data type")

    @property
    def dtype(self):
        return _lite_type_to_nptypes[LiteDataType(self.data_type)]

    @property
    def shapes(self):
        return list(self._shapes)[0 : self.ndim]

    @shapes.setter
    def shapes(self, shape):
        shape = list(shape)
        assert len(shape) <= MAX_DIM, "Layout max dim is 7."
        self._shapes = (c_size_t * MAX_DIM)(*shape)
        self.ndim = len(shape)

    def __repr__(self):
        data = {
            "shapes": self.shapes,
            "ndim": self.ndim,
            "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
        }
        return data.__repr__()
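
# --- Illustrative sketch (not part of the original module): the three dtype
# spellings accepted by LiteLayout.__init__ besides a LiteDataType. The
# helper name _example_layout_construction is hypothetical.
def _example_layout_construction():
    # a string is resolved through _str_nptypes_to_lite_nptypes
    l1 = LiteLayout([1, 3, 224, 224], "int8")
    # a np.dtype instance goes through np.ctypeslib.as_ctypes_type
    l2 = LiteLayout([1, 3, 224, 224], np.dtype("float32"))
    # a bare numpy scalar type is looked up in _nptype_to_lite_type
    l3 = LiteLayout([1, 3, 224, 224], np.float16)
    assert l1.dtype == np.int8
    assert l2.dtype == np.float32
    assert l3.dtype == np.float16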

class _LiteTensorDesc(Structure):
    """
    Wrapper of the MegEngine Tensor.

    Args:
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory. This is used to optimize H2D or D2H memory copies; if the
            device or layout is not set when copying from another device
            (CUDA) tensor, this tensor will automatically be set to a pinned
            tensor.
        layout(LiteLayout): layout of this tensor.
        device_type: type of the device.
        device_id: id of the device.
    """

    _fields_ = [
        ("is_pinned_host", c_int),
        ("layout", LiteLayout),
        ("device_type", c_int),
        ("device_id", c_int),
    ]

    def __init__(self):
        self.layout = LiteLayout()
        self.device_type = LiteDeviceType.LITE_CPU
        self.is_pinned_host = False
        self.device_id = 0
    def __repr__(self):
        # note: ctypes returns plain Python ints for c_int fields, and
        # self.layout is already a LiteLayout, so use both directly
        data = {
            "is_pinned_host": self.is_pinned_host,
            "layout": self.layout,
            "device_type": LiteDeviceType(self.device_type),
            "device_id": self.device_id,
        }
        return data.__repr__()

class _TensorAPI(_LiteCObjBase):
    """
    Get the tensor API from the lib.
    """

    _api_ = [
        ("LITE_make_tensor", [_LiteTensorDesc, POINTER(_Ctensor)]),
        ("LITE_set_tensor_layout", [_Ctensor, LiteLayout]),
        ("LITE_reset_tensor_memory", [_Ctensor, c_void_p, c_size_t]),
        ("LITE_reset_tensor", [_Ctensor, LiteLayout, c_void_p]),
        ("LITE_tensor_reshape", [_Ctensor, POINTER(c_int), c_int]),
        (
            "LITE_tensor_slice",
            [
                _Ctensor,
                POINTER(c_size_t),
                POINTER(c_size_t),
                POINTER(c_size_t),
                c_size_t,
                POINTER(_Ctensor),
            ],
        ),
        (
            "LITE_tensor_concat",
            [POINTER(_Ctensor), c_int, c_int, c_int, c_int, POINTER(_Ctensor)],
        ),
        ("LITE_tensor_fill_zero", [_Ctensor]),
        ("LITE_tensor_copy", [_Ctensor, _Ctensor]),
        ("LITE_tensor_share_memory_with", [_Ctensor, _Ctensor]),
        ("LITE_get_tensor_memory", [_Ctensor, POINTER(c_void_p)]),
        ("LITE_get_tensor_total_size_in_byte", [_Ctensor, POINTER(c_size_t)]),
        ("LITE_get_tensor_layout", [_Ctensor, POINTER(LiteLayout)]),
        ("LITE_get_tensor_device_type", [_Ctensor, POINTER(c_int)]),
        ("LITE_get_tensor_device_id", [_Ctensor, POINTER(c_int)]),
        ("LITE_destroy_tensor", [_Ctensor]),
        ("LITE_is_pinned_host", [_Ctensor, POINTER(c_int)]),
        # used by LiteTensor.is_continue below; this entry was missing from
        # the listing, so the signature is inferred from the other
        # pointer-out predicates such as LITE_is_pinned_host
        ("LITE_is_memory_continue", [_Ctensor, POINTER(c_int)]),
    ]
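
# --- Illustrative sketch (not part of the original module): how the _api_
# table is consumed. _TensorAPI()._lib exposes each listed C symbol with the
# declared argtypes; this mirrors what LiteTensor.__init__ does below. The
# helper name _example_raw_make_tensor is hypothetical.
def _example_raw_make_tensor():
    api = _TensorAPI()._lib
    desc = _LiteTensorDesc()
    desc.layout = LiteLayout([8], LiteDataType.LITE_FLOAT)
    handle = _Ctensor()
    api.LITE_make_tensor(desc, byref(handle))  # allocate through the C API
    api.LITE_destroy_tensor(handle)  # raw handles must be freed explicitly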

class LiteTensor(object):
    """
    Description of a block of data with the necessary information.

    Args:
        layout: layout of the Tensor.
        device_type: device type of the Tensor.
        device_id: device id of the Tensor.
        is_pinned_host: when set, the storage memory of the tensor is pinned
            memory. This is used to optimize H2D or D2H memory copies; if the
            device or layout is not set when copying from another device
            (CUDA) tensor, this tensor will automatically be set to a pinned
            tensor.
        shapes: the shape of the data.
        dtype: the data type.
        physic_construct: whether to actually allocate the underlying tensor
            through the C API at construction time.

    Note:
        The number of dims of shape should be less than 8. The supported
        data types are defined in LiteDataType.
    """

    _api = _TensorAPI()._lib

    def __init__(
        self,
        layout=None,
        device_type=LiteDeviceType.LITE_CPU,
        device_id=0,
        is_pinned_host=False,
        shapes=None,
        dtype=None,
        physic_construct=True,
    ):
        self._tensor = _Ctensor()
        self._layout = LiteLayout()
        if layout is not None:
            self._layout = layout
        elif shapes is not None:
            shapes = list(shapes)
            self._layout = LiteLayout(shapes, dtype)
        self._device_type = device_type
        self._device_id = device_id
        self._is_pinned_host = is_pinned_host

        tensor_desc = _LiteTensorDesc()
        tensor_desc.layout = self._layout
        tensor_desc.device_type = device_type
        tensor_desc.device_id = device_id
        tensor_desc.is_pinned_host = is_pinned_host

        if physic_construct:
            self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
            self.update()

    def __del__(self):
        self._api.LITE_destroy_tensor(self._tensor)
    def fill_zero(self):
        """
        fill the tensor memory buffer with zeros
        """
        self._api.LITE_tensor_fill_zero(self._tensor)
        self.update()

    def share_memory_with(self, src_tensor):
        """
        share the same memory with the ``src_tensor``; this tensor's own
        memory will be freed

        Args:
            src_tensor: the source tensor that will share memory with this
                tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_share_memory_with(self._tensor, src_tensor._tensor)
        self.update()
    @property
    def layout(self):
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        return self._layout

    @layout.setter
    def layout(self, layout):
        if isinstance(layout, LiteLayout):
            self._layout = layout
        elif isinstance(layout, list):
            self._layout.shapes = layout
        self._api.LITE_set_tensor_layout(self._tensor, self._layout)

    @property
    def is_pinned_host(self):
        """
        whether the tensor is a pinned tensor
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned
        return bool(self._is_pinned_host)

    @property
    def device_type(self):
        """
        get the device type of the tensor
        """
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type
        return LiteDeviceType(device_type.value)

    @property
    def device_id(self):
        """
        get the device id of the tensor
        """
        device_id = c_int()
        self._api.LITE_get_tensor_device_id(self._tensor, byref(device_id))
        self._device_id = device_id.value
        return device_id.value

    @property
    def is_continue(self):
        """
        whether the tensor memory is contiguous
        """
        is_continue = c_int()
        self._api.LITE_is_memory_continue(self._tensor, byref(is_continue))
        return bool(is_continue.value)

    @property
    def nbytes(self):
        """
        get the length of the memory in bytes
        """
        length = c_size_t()
        self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length))
        return length.value

    def update(self):
        """
        update the members from C; this is called automatically after slice
        and share operations
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        c_types = _lite_dtypes_to_ctype[self._layout.data_type]
        self.np_array_type = np.ctypeslib._ctype_ndarray(
            c_types, list(self._layout.shapes)[0 : self._layout.ndim]
        )
    def copy_from(self, src_tensor):
        """
        copy memory from the src_tensor

        Args:
            src_tensor: source tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_copy(self._tensor, src_tensor._tensor)
        self.update()

    def reshape(self, shape):
        """
        reshape the tensor without changing its data

        Args:
            shape: target shape
        """
        shape = list(shape)
        length = len(shape)
        c_shape = (c_int * length)(*shape)
        self._api.LITE_tensor_reshape(self._tensor, c_shape, length)
        self.update()
    def slice(self, start, end, step=None):
        """
        slice the tensor with the given start, end, and step

        Args:
            start: slice begin index of each dim
            end: slice end index of each dim
            step: slice step of each dim
        """
        start = list(start)
        end = list(end)
        length = len(start)
        assert length == len(end), "slice with different length of start and end."
        if step:
            assert length == len(step), "slice with different length of start and step."
            step = list(step)
        else:
            step = [1 for i in range(length)]
        c_start = (c_size_t * length)(*start)
        c_end = (c_size_t * length)(*end)
        c_step = (c_size_t * length)(*step)
        slice_tensor = LiteTensor(physic_construct=False)
        self._api.LITE_tensor_slice(
            self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor)
        )
        slice_tensor.update()
        return slice_tensor

    def get_ctypes_memory(self):
        """
        get the memory of the tensor, returned as a c_void_p pointing at the
        tensor memory
        """
        mem = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(mem))
        return mem
    def set_data_by_share(self, data, length=0, layout=None):
        """
        share the data to the tensor

        Args:
            data: the data that will be shared to the tensor; it should be a
                numpy.ndarray or ctypes data
            length: the length of the shared data in bytes
            layout: the layout to set before sharing
        """
        if isinstance(data, np.ndarray):
            assert data.flags[
                "C_CONTIGUOUS"
            ], "input numpy array is not contiguous, please call input = np.ascontiguousarray(input) before calling set_data_by_share"
            assert (
                self.is_continue
            ), "set_data_by_share can only be applied to a contiguous tensor."
            assert (
                self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
            ), "set_data_by_share can only be applied to a cpu tensor or pinned tensor."
            c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
            if self.nbytes != data.nbytes:
                self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type])
            self._shared_data = data
            data = data.ctypes.data_as(POINTER(c_type))
        if layout is not None:
            self.layout = layout
        else:
            assert length == 0 or length == self.nbytes, "the data length does not match."
        self._api.LITE_reset_tensor_memory(self._tensor, data, self.nbytes)
    def set_data_by_copy(self, data, data_length=0, layout=None):
        """
        copy the data to the tensor

        Args:
            data: the data to copy to the tensor; it should be a list, a
                numpy.ndarray, or ctypes data with length
            data_length: length of the data in bytes
            layout: layout of the data
        """
        if layout is not None:
            self.layout = layout
        assert self.is_continue, "set_data_by_copy can only be applied to a contiguous tensor."
        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
        cpu_tensor = LiteTensor(self._layout)
        tensor_length = self.nbytes
        if type(data) == list:
            length = len(data)
            assert (
                length * sizeof(c_type) <= tensor_length
            ), "the length of the input data to set to the tensor is too large."
            cdata = (c_type * length)(*data)
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, tensor_length)
            self.copy_from(cpu_tensor)
        elif type(data) == np.ndarray:
            assert data.flags[
                "C_CONTIGUOUS"
            ], "input numpy array is not contiguous, please call input = np.ascontiguousarray(input) before calling set_data_by_copy"
            self.layout = LiteLayout(data.shape, data.dtype)
            cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
            cdata = data.ctypes.data_as(POINTER(c_type))
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, self.nbytes)
            self.copy_from(cpu_tensor)
        else:
            assert (
                data_length == self.nbytes or layout is not None
            ), "when the input data is ctypes, the length of the input data or the layout must be set"
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, data, tensor_length)
            self.copy_from(cpu_tensor)
    def get_data_by_share(self):
        """
        get the data in the tensor by sharing it with a new numpy array, and
        return that numpy array

        Note:
            Be careful: the data in the numpy array is only valid until the
            tensor memory is written again, e.g. the next time the
            LiteNetwork forwards.
        """
        self.update()
        buffer = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
        buffer = self.np_array_type.from_address(buffer.value)
        return np.ctypeslib.as_array(buffer)

    def to_numpy(self):
        """
        get the buffer of the tensor as a numpy array copy
        """
        self.update()
        if self.nbytes <= 0:
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            return np.array([], dtype=np_type)
        if self.is_continue and (
            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
        ):
            ptr = c_void_p()
            self._api.LITE_get_tensor_memory(self._tensor, byref(ptr))
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            shape = [self._layout.shapes[i] for i in range(self._layout.ndim)]
            np_arr = np.zeros(shape, np_type)
            if np_arr.nbytes:
                memmove(np_arr.ctypes.data_as(c_void_p), ptr, np_arr.nbytes)
            return np_arr
        else:
            tmp_tensor = LiteTensor(self.layout)
            tmp_tensor.copy_from(self)
            return tmp_tensor.to_numpy()

    def __repr__(self):
        self.update()
        data = {
            "layout": self._layout,
            "device_type": LiteDeviceType(self._device_type.value),
            "device_id": int(self.device_id),
            "is_pinned_host": bool(self._is_pinned_host),
        }
        return data.__repr__()
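
# --- Illustrative sketch (not part of the original module): a typical CPU
# round trip through LiteTensor using only the methods defined above. The
# helper name _example_tensor_round_trip is hypothetical.
def _example_tensor_round_trip():
    tensor = LiteTensor(LiteLayout([2, 3], "float32"))
    # list data is staged through an internal cpu tensor, then copied in
    tensor.set_data_by_copy([1, 2, 3, 4, 5, 6])
    arr = tensor.to_numpy()  # independent numpy copy of the buffer
    assert arr.shape == (2, 3) and arr.dtype == np.float32
    # slice the first row; the result is a new LiteTensor of shape (1, 3)
    sub = tensor.slice([0, 0], [1, 3])
    assert sub.to_numpy().tolist() == [[1.0, 2.0, 3.0]]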

def LiteTensorConcat(
    tensors, dim, device_type=LiteDeviceType.LITE_DEVICE_DEFAULT, device_id=-1
):
    """
    concat the tensors along the expected dim into one tensor

    Args:
        tensors: the tensors to concat
        dim: the dim along which to concat
        device_type: the device type of the result tensor
        device_id: the device id of the result tensor
    """
    api = _TensorAPI()._lib
    length = len(tensors)
    c_tensors = [t._tensor for t in tensors]
    c_tensors = (_Ctensor * length)(*c_tensors)
    result_tensor = LiteTensor(physic_construct=False)
    api.LITE_tensor_concat(
        cast(byref(c_tensors), POINTER(c_void_p)),
        length,
        dim,
        device_type,
        device_id,
        byref(result_tensor._tensor),
    )
    result_tensor.update()
    return result_tensor
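
# --- Illustrative sketch (not part of the original module): concatenating
# two host tensors along dim 0 with LiteTensorConcat. The helper name
# _example_concat is hypothetical.
def _example_concat():
    a = LiteTensor(LiteLayout([1, 4], "int32"))
    b = LiteTensor(LiteLayout([1, 4], "int32"))
    a.set_data_by_copy([0, 1, 2, 3])
    b.set_data_by_copy([4, 5, 6, 7])
    out = LiteTensorConcat([a, b], dim=0)
    assert out.to_numpy().shape == (2, 4)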

def lite_dtype_2_numpy(dtype):
    """
    convert lite dtype to corresponding numpy dtype

    Args:
        dtype(LiteDataType): source dtype
    """
    assert isinstance(
        dtype, LiteDataType
    ), "input must be LiteDataType when using lite_dtype_2_numpy."
    return _lite_type_to_nptypes[dtype]
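
# --- Illustrative sketch (not part of the original module): lite_dtype_2_numpy
# simply indexes _lite_type_to_nptypes, so the mapping is exact. The helper
# name _example_dtype_roundtrip is hypothetical.
def _example_dtype_roundtrip():
    assert lite_dtype_2_numpy(LiteDataType.LITE_FLOAT) is np.float32
    assert lite_dtype_2_numpy(LiteDataType.LITE_UINT8) is np.uint8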