You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tensor.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from ctypes import *
  10. import numpy as np
  11. from .base import _Ctensor, _lib, _LiteCObjBase
  12. from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure
  13. MAX_DIM = 7
  14. _lite_type_to_nptypes = {
  15. LiteDataType.LITE_INT: np.int32,
  16. LiteDataType.LITE_FLOAT: np.float32,
  17. LiteDataType.LITE_UINT8: np.uint8,
  18. LiteDataType.LITE_INT8: np.int8,
  19. LiteDataType.LITE_INT16: np.int16,
  20. LiteDataType.LITE_UINT16: np.uint16,
  21. LiteDataType.LITE_HALF: np.float16,
  22. }
  23. _nptype_to_lite_type = {val: key for key, val in _lite_type_to_nptypes.items()}
  24. _str_nptypes_to_lite_nptypes = {
  25. np.dtype("int32"): LiteDataType.LITE_INT,
  26. np.dtype("float32"): LiteDataType.LITE_FLOAT,
  27. np.dtype("uint8"): LiteDataType.LITE_UINT8,
  28. np.dtype("int8"): LiteDataType.LITE_INT8,
  29. np.dtype("int16"): LiteDataType.LITE_INT16,
  30. np.dtype("uint16"): LiteDataType.LITE_UINT16,
  31. np.dtype("float16"): LiteDataType.LITE_HALF,
  32. }
  33. ctype_to_lite_dtypes = {
  34. c_int: LiteDataType.LITE_INT,
  35. c_uint: LiteDataType.LITE_INT,
  36. c_float: LiteDataType.LITE_FLOAT,
  37. c_ubyte: LiteDataType.LITE_UINT8,
  38. c_byte: LiteDataType.LITE_INT8,
  39. c_short: LiteDataType.LITE_INT16,
  40. c_ushort: LiteDataType.LITE_UINT16,
  41. }
  42. _lite_dtypes_to_ctype = {
  43. LiteDataType.LITE_INT: c_int,
  44. LiteDataType.LITE_FLOAT: c_float,
  45. LiteDataType.LITE_UINT8: c_ubyte,
  46. LiteDataType.LITE_INT8: c_byte,
  47. LiteDataType.LITE_INT16: c_short,
  48. LiteDataType.LITE_UINT16: c_ushort,
  49. }
  50. class LiteLayout(Structure):
  51. """
  52. the simple layout description
  53. """
  54. _fields_ = [
  55. ("_shapes", c_size_t * MAX_DIM),
  56. ("ndim", c_size_t),
  57. ("data_type", c_int),
  58. ]
  59. def __init__(self, shape=None, dtype=None):
  60. if shape:
  61. shape = list(shape)
  62. assert len(shape) <= MAX_DIM, "Layout max dim is 7."
  63. self._shapes = (c_size_t * MAX_DIM)(*shape)
  64. self.ndim = len(shape)
  65. else:
  66. self._shapes = (c_size_t * MAX_DIM)()
  67. self.ndim = 0
  68. if not dtype:
  69. self.data_type = LiteDataType.LITE_FLOAT
  70. elif isinstance(dtype, LiteDataType):
  71. self.data_type = dtype
  72. elif type(dtype) == str:
  73. self.data_type = _str_nptypes_to_lite_nptypes[np.dtype(dtype)]
  74. elif isinstance(dtype, np.dtype):
  75. ctype = np.ctypeslib.as_ctypes_type(dtype)
  76. self.data_type = ctype_to_lite_dtypes[ctype]
  77. elif isinstance(dtype, type):
  78. self.data_type = _nptype_to_lite_type[dtype]
  79. else:
  80. raise RuntimeError("unkonw data type")
  81. @property
  82. def dtype(self):
  83. return _lite_type_to_nptypes[LiteDataType(self.data_type)]
  84. @property
  85. def shapes(self):
  86. return list(self._shapes)[0 : self.ndim]
  87. @shapes.setter
  88. def shapes(self, shape):
  89. shape = list(shape)
  90. assert len(shape) <= MAX_DIM, "Layout max dim is 7."
  91. self._shapes = (c_size_t * MAX_DIM)(*shape)
  92. self.ndim = len(shape)
  93. def __repr__(self):
  94. data = {
  95. "shapes": self.shapes,
  96. "ndim": self.ndim,
  97. "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
  98. }
  99. return data.__repr__()
  100. class _LiteTensorDesc(Structure):
  101. """
  102. warpper of the MegEngine Tensor
  103. :is_pinned_host: when set, the storage memory of the tensor is pinned memory,
  104. this is used to Optimize the H2D or D2H memory copy, if the device or layout
  105. is not set, when copy form other device(CUDA) tensor, this tensor
  106. will be automatically set to pinned tensor
  107. """
  108. _fields_ = [
  109. ("is_pinned_host", c_int),
  110. ("layout", LiteLayout),
  111. ("device_type", c_int),
  112. ("device_id", c_int),
  113. ]
  114. def __init__(self):
  115. self.layout = LiteLayout()
  116. self.device_type = LiteDeviceType.LITE_CPU
  117. self.is_pinned_host = False
  118. self.device_id = 0
  119. def __repr__(self):
  120. data = {
  121. "is_pinned_host": self.is_pinned_host,
  122. "layout": LiteLayout(self.layout),
  123. "device_type": LiteDeviceType(self.device_type.value),
  124. "device_id": self.device_id,
  125. }
  126. return data.__repr__()
  127. class _TensorAPI(_LiteCObjBase):
  128. """
  129. get the api from the lib
  130. """
  131. _api_ = [
  132. ("LITE_make_tensor", [_LiteTensorDesc, POINTER(_Ctensor)]),
  133. ("LITE_set_tensor_layout", [_Ctensor, LiteLayout]),
  134. ("LITE_reset_tensor_memory", [_Ctensor, c_void_p, c_size_t]),
  135. ("LITE_reset_tensor", [_Ctensor, LiteLayout, c_void_p]),
  136. ("LITE_tensor_reshape", [_Ctensor, POINTER(c_int), c_int]),
  137. (
  138. "LITE_tensor_slice",
  139. [
  140. _Ctensor,
  141. POINTER(c_size_t),
  142. POINTER(c_size_t),
  143. POINTER(c_size_t),
  144. c_size_t,
  145. POINTER(_Ctensor),
  146. ],
  147. ),
  148. (
  149. "LITE_tensor_concat",
  150. [POINTER(_Ctensor), c_int, c_int, c_int, c_int, POINTER(_Ctensor),],
  151. ),
  152. ("LITE_tensor_fill_zero", [_Ctensor]),
  153. ("LITE_tensor_copy", [_Ctensor, _Ctensor]),
  154. ("LITE_tensor_share_memory_with", [_Ctensor, _Ctensor]),
  155. ("LITE_get_tensor_memory", [_Ctensor, POINTER(c_void_p)]),
  156. ("LITE_get_tensor_total_size_in_byte", [_Ctensor, POINTER(c_size_t)]),
  157. ("LITE_get_tensor_layout", [_Ctensor, POINTER(LiteLayout)]),
  158. ("LITE_get_tensor_device_type", [_Ctensor, POINTER(c_int)]),
  159. ("LITE_get_tensor_device_id", [_Ctensor, POINTER(c_int)]),
  160. ("LITE_destroy_tensor", [_Ctensor]),
  161. ("LITE_is_pinned_host", [_Ctensor, POINTER(c_int)]),
  162. ]
  163. class LiteTensor(object):
  164. """
  165. the tensor to hold a block of data
  166. """
  167. _api = _TensorAPI()._lib
  168. def __init__(
  169. self,
  170. layout=None,
  171. device_type=LiteDeviceType.LITE_CPU,
  172. device_id=0,
  173. is_pinned_host=False,
  174. shapes=None,
  175. dtype=None,
  176. ):
  177. """
  178. create a Tensor with layout, device, is_pinned_host or shapes, dtype,
  179. device_type, device_id, is_pinned_host param
  180. """
  181. self._tensor = _Ctensor()
  182. self._layout = LiteLayout()
  183. if layout is not None:
  184. self._layout = layout
  185. elif shapes is not None:
  186. shapes = list(shapes)
  187. self._layout = LiteLayout(shapes, dtype)
  188. self._device_type = device_type
  189. self._device_id = device_id
  190. self._is_pinned_host = is_pinned_host
  191. tensor_desc = _LiteTensorDesc()
  192. tensor_desc.layout = self._layout
  193. tensor_desc.device_type = device_type
  194. tensor_desc.device_id = device_id
  195. tensor_desc.is_pinned_host = is_pinned_host
  196. self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
  197. def __del__(self):
  198. self._api.LITE_destroy_tensor(self._tensor)
  199. def fill_zero(self):
  200. """
  201. fill the buffer memory with zero
  202. """
  203. self._api.LITE_tensor_fill_zero(self._tensor)
  204. self.update()
  205. def share_memory_with(self, src_tensor):
  206. """
  207. share the same memory with the src_tensor, the self memory will be freed
  208. """
  209. assert isinstance(src_tensor, LiteTensor)
  210. self._api.LITE_tensor_share_memory_with(self._tensor, src_tensor._tensor)
  211. self.update()
  212. @property
  213. def layout(self):
  214. self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
  215. return self._layout
  216. @layout.setter
  217. def layout(self, layout):
  218. if isinstance(layout, LiteLayout):
  219. self._layout = layout
  220. elif isinstance(layout, list):
  221. self._layout.shapes = layout
  222. self._api.LITE_set_tensor_layout(self._tensor, self._layout)
  223. @property
  224. def is_pinned_host(self):
  225. """
  226. whether the tensor is pinned tensor
  227. """
  228. pinned = c_int()
  229. self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
  230. self._is_pinned_host = pinned
  231. return bool(self._is_pinned_host)
  232. @property
  233. def device_type(self):
  234. """
  235. get device of the tensor
  236. """
  237. device_type = c_int()
  238. self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
  239. self._device_type = device_type
  240. return LiteDeviceType(device_type.value)
  241. @property
  242. def device_id(self):
  243. """
  244. get device id of the tensor
  245. """
  246. device_id = c_int()
  247. self._api.LITE_get_tensor_device_id(self._tensor, byref(device_id))
  248. self._device_id = device_id.value
  249. return device_id.value
  250. @property
  251. def is_continue(self):
  252. """
  253. whether the tensor memory is continue
  254. """
  255. is_continue = c_int()
  256. self._api.LITE_is_memory_continue(self._tensor, byref(is_continue))
  257. return bool(is_continue.value)
  258. @property
  259. def nbytes(self):
  260. """
  261. get the length of the meomry in byte
  262. """
  263. length = c_size_t()
  264. self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length))
  265. return length.value
  266. def update(self):
  267. """
  268. update the member from C, this will auto used after slice, share
  269. """
  270. pinned = c_int()
  271. self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
  272. self._is_pinned_host = pinned
  273. device_type = c_int()
  274. self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
  275. self._device_type = device_type
  276. self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
  277. def copy_from(self, src_tensor):
  278. """
  279. copy memory form the src_tensor
  280. """
  281. assert isinstance(src_tensor, LiteTensor)
  282. self._api.LITE_tensor_copy(self._tensor, src_tensor._tensor)
  283. self.update()
  284. def reshape(self, shape):
  285. """
  286. reshape the tensor with data not change, only change the shape
  287. :param shape: int arrary of dst_shape
  288. """
  289. shape = list(shape)
  290. length = len(shape)
  291. c_shape = (c_int * length)(*shape)
  292. self._api.LITE_tensor_reshape(self._tensor, c_shape, length)
  293. self.update()
  294. def slice(self, start, end, step=None):
  295. """
  296. slice the tensor with gaven start, end, step
  297. :param start: silce begin index of each dim
  298. :param end: silce end index of each dim
  299. :param step: silce step of each dim
  300. """
  301. start = list(start)
  302. end = list(end)
  303. length = len(start)
  304. assert length == len(end), "slice with different length of start and end."
  305. if step:
  306. assert length == len(step), "slice with different length of start and step."
  307. step = list(step)
  308. else:
  309. step = [1 for i in range(length)]
  310. c_start = (c_size_t * length)(*start)
  311. c_end = (c_size_t * length)(*end)
  312. c_step = (c_size_t * length)(*step)
  313. slice_tensor = LiteTensor()
  314. self._api.LITE_tensor_slice(
  315. self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor)
  316. )
  317. slice_tensor.update()
  318. return slice_tensor
  319. def get_ctypes_memory(self):
  320. """
  321. get the memory of the tensor, return c_void_p of the tensor memory
  322. """
  323. mem = c_void_p()
  324. self._api.LITE_get_tensor_memory(self._tensor, byref(mem))
  325. return mem
  326. def set_data_by_share(self, data, length=0, layout=None):
  327. """
  328. share the data to the tensor
  329. param data: the data will shared to the tensor, it should be a
  330. numpy.ndarray or ctypes data
  331. """
  332. if isinstance(data, np.ndarray):
  333. assert (
  334. self.is_continue
  335. ), "set_data_by_share can only apply in continue tensor."
  336. assert (
  337. self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
  338. ), "set_data_by_share can only apply in cpu tensor or pinned tensor."
  339. c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
  340. if self.nbytes != data.nbytes:
  341. self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type])
  342. self._shared_data = data
  343. data = data.ctypes.data_as(POINTER(c_type))
  344. if layout is not None:
  345. self.layout = layout
  346. else:
  347. assert length == 0 or length == self.nbytes, "the data length is not match."
  348. self._api.LITE_reset_tensor_memory(self._tensor, data, self.nbytes)
  349. def set_data_by_copy(self, data, data_length=0, layout=None):
  350. """
  351. copy the data to the tensor
  352. param data: the data to copy to tensor, it should be list,
  353. numpy.ndarraya or ctypes with length
  354. """
  355. if layout is not None:
  356. self.layout = layout
  357. assert self.is_continue, "set_data_by_copy can only apply in continue tensor."
  358. assert (
  359. self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
  360. ), "set_data_by_copy can only apply in cpu tensor or pinned tensor."
  361. c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
  362. tensor_memory = c_void_p()
  363. if type(data) == list:
  364. length = len(data)
  365. self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
  366. tensor_length = self.nbytes
  367. assert (
  368. length * sizeof(c_type) <= tensor_length
  369. ), "the length of input data to set to the tensor is too large."
  370. arr = (c_type * length)(*data)
  371. memmove(tensor_memory, arr, sizeof(c_type) * length)
  372. elif type(data) == np.ndarray:
  373. if self.nbytes != data.nbytes:
  374. self.layout = LiteLayout(data.shape, data.dtype)
  375. arr = data.ctypes.data_as(POINTER(c_type))
  376. self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
  377. assert self.nbytes == data.nbytes
  378. memmove(tensor_memory, arr, self.nbytes)
  379. else:
  380. assert (
  381. data_length == self.nbytes or layout is not None
  382. ), "when input data is ctypes, the length of input data or layout must set"
  383. self._api.LITE_get_tensor_memory(self._tensor, byref(tensor_memory))
  384. memmove(tensor_memory, data, data_length)
  385. def get_data_by_share(self):
  386. """
  387. get the data in the tensor, add share the data with a new numpy, and
  388. return the numpy arrray, be careful, the data in numpy is valid before
  389. the tensor memory is write again, such as LiteNetwok forward next time.
  390. """
  391. assert self.is_continue, "get_data_by_share can only apply in continue tensor."
  392. assert (
  393. self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
  394. ), "get_data_by_share can only apply in CPU tensor or cpu pinned tensor."
  395. memory = self.get_ctypes_memory()
  396. c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]
  397. pnt = cast(memory, POINTER(c_type))
  398. return np.ctypeslib.as_array(pnt, self._layout.shapes)
  399. def to_numpy(self):
  400. """
  401. get the buffer of the tensor
  402. """
  403. self.update()
  404. if self.nbytes <= 0:
  405. np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
  406. return np.array([], dtype=np_type)
  407. if self.is_continue and (
  408. self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
  409. ):
  410. ptr = c_void_p()
  411. self._api.LITE_get_tensor_memory(self._tensor, byref(ptr))
  412. np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
  413. shape = [self._layout.shapes[i] for i in range(self._layout.ndim)]
  414. np_arr = np.zeros(shape, np_type)
  415. if np_arr.nbytes:
  416. memmove(np_arr.ctypes.data_as(c_void_p), ptr, np_arr.nbytes)
  417. return np_arr
  418. else:
  419. tmp_tensor = LiteTensor(self.layout)
  420. tmp_tensor.copy_from(self)
  421. return tmp_tensor.to_numpy()
  422. def __repr__(self):
  423. self.update()
  424. data = {
  425. "layout": self._layout,
  426. "device_type": LiteDeviceType(self._device_type.value),
  427. "device_id": int(self.device_id),
  428. "is_pinned_host": bool(self._is_pinned_host),
  429. }
  430. return data.__repr__()
  431. def LiteTensorConcat(
  432. tensors, dim, device_type=LiteDeviceType.LITE_DEVICE_DEFAULT, device_id=-1
  433. ):
  434. """
  435. concat tensor in input dim to one tensor
  436. dim : the dim to act concat
  437. device_type: the result tensor device type
  438. device_id: the result tensor device id
  439. """
  440. api = _TensorAPI()._lib
  441. length = len(tensors)
  442. c_tensors = [t._tensor for t in tensors]
  443. c_tensors = (_Ctensor * length)(*c_tensors)
  444. result_tensor = LiteTensor()
  445. api.LITE_tensor_concat(
  446. cast(byref(c_tensors), POINTER(c_void_p)),
  447. length,
  448. dim,
  449. device_type,
  450. device_id,
  451. byref(result_tensor._tensor),
  452. )
  453. result_tensor.update()
  454. return result_tensor
  455. def lite_dtype_2_numpy(dtype):
  456. """
  457. convert lite dtype to corresponding numpy dtype
  458. """
  459. assert isinstance(
  460. dtype, LiteDataType
  461. ), "input must be LiteDataType when using lite_dtype_2_numpy."
  462. return _lite_type_to_nptypes[dtype]