
tensor.py 18 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ctypes import *

import numpy as np

from .base import _Ctensor, _lib, _LiteCObjBase
from .struct import LiteDataType, LiteDeviceType, LiteIOType, Structure

MAX_DIM = 7

_lite_type_to_nptypes = {
    LiteDataType.LITE_INT: np.int32,
    LiteDataType.LITE_FLOAT: np.float32,
    LiteDataType.LITE_UINT8: np.uint8,
    LiteDataType.LITE_INT8: np.int8,
    LiteDataType.LITE_INT16: np.int16,
    LiteDataType.LITE_UINT16: np.uint16,
    LiteDataType.LITE_HALF: np.float16,
}

_nptype_to_lite_type = {val: key for key, val in _lite_type_to_nptypes.items()}

_str_nptypes_to_lite_nptypes = {
    np.dtype("int32"): LiteDataType.LITE_INT,
    np.dtype("float32"): LiteDataType.LITE_FLOAT,
    np.dtype("uint8"): LiteDataType.LITE_UINT8,
    np.dtype("int8"): LiteDataType.LITE_INT8,
    np.dtype("int16"): LiteDataType.LITE_INT16,
    np.dtype("uint16"): LiteDataType.LITE_UINT16,
    np.dtype("float16"): LiteDataType.LITE_HALF,
}

ctype_to_lite_dtypes = {
    c_int: LiteDataType.LITE_INT,
    c_uint: LiteDataType.LITE_INT,
    c_float: LiteDataType.LITE_FLOAT,
    c_ubyte: LiteDataType.LITE_UINT8,
    c_byte: LiteDataType.LITE_INT8,
    c_short: LiteDataType.LITE_INT16,
    c_ushort: LiteDataType.LITE_UINT16,
}

_lite_dtypes_to_ctype = {
    LiteDataType.LITE_INT: c_int,
    LiteDataType.LITE_FLOAT: c_float,
    LiteDataType.LITE_UINT8: c_ubyte,
    LiteDataType.LITE_INT8: c_byte,
    LiteDataType.LITE_INT16: c_short,
    LiteDataType.LITE_UINT16: c_ushort,
}
class LiteLayout(Structure):
    """
    the simple layout description
    """

    _fields_ = [
        ("_shapes", c_size_t * MAX_DIM),
        ("ndim", c_size_t),
        ("data_type", c_int),
    ]

    def __init__(self, shape=None, dtype=None):
        if shape:
            shape = list(shape)
            assert len(shape) <= MAX_DIM, "Layout max dim is 7."
            self._shapes = (c_size_t * MAX_DIM)(*shape)
            self.ndim = len(shape)
        else:
            self._shapes = (c_size_t * MAX_DIM)()
            self.ndim = 0
        if not dtype:
            self.data_type = LiteDataType.LITE_FLOAT
        elif isinstance(dtype, LiteDataType):
            self.data_type = dtype
        elif type(dtype) == str:
            self.data_type = _str_nptypes_to_lite_nptypes[np.dtype(dtype)]
        elif isinstance(dtype, np.dtype):
            ctype = np.ctypeslib.as_ctypes_type(dtype)
            self.data_type = ctype_to_lite_dtypes[ctype]
        elif isinstance(dtype, type):
            self.data_type = _nptype_to_lite_type[dtype]
        else:
            raise RuntimeError("unknown data type")

    @property
    def dtype(self):
        return _lite_type_to_nptypes[LiteDataType(self.data_type)]

    @property
    def shapes(self):
        return list(self._shapes)[0 : self.ndim]

    @shapes.setter
    def shapes(self, shape):
        shape = list(shape)
        assert len(shape) <= MAX_DIM, "Layout max dim is 7."
        self._shapes = (c_size_t * MAX_DIM)(*shape)
        self.ndim = len(shape)

    def __repr__(self):
        data = {
            "shapes": self.shapes,
            "ndim": self.ndim,
            "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
        }
        return data.__repr__()
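
# A minimal usage sketch of LiteLayout (illustrative comment, not part of the
# original module): __init__ above accepts the dtype as a LiteDataType, a
# dtype string, a np.dtype, or a numpy scalar type.
#
#     layout = LiteLayout([1, 3, 224, 224], "float32")
#     assert layout.shapes == [1, 3, 224, 224]
#     assert layout.dtype == np.float32
#     layout.shapes = [4, 16]  # the setter also updates ndim
#     assert layout.ndim == 2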
class _LiteTensorDesc(Structure):
    """
    wrapper of the MegEngine Tensor

    :is_pinned_host: when set, the storage memory of the tensor is pinned
        memory, which is used to optimize the H2D or D2H memory copy; if the
        device or layout is not set, this tensor will automatically be set
        to a pinned tensor when copying from another device (CUDA) tensor
    """

    _fields_ = [
        ("is_pinned_host", c_int),
        ("layout", LiteLayout),
        ("device_type", c_int),
        ("device_id", c_int),
    ]

    def __init__(self):
        self.layout = LiteLayout()
        self.device_type = LiteDeviceType.LITE_CPU
        self.is_pinned_host = False
        self.device_id = 0

    def __repr__(self):
        data = {
            "is_pinned_host": self.is_pinned_host,
            "layout": self.layout,
            # c_int fields read back from a Structure are plain Python ints,
            # so there is no .value attribute to dereference here
            "device_type": LiteDeviceType(self.device_type),
            "device_id": self.device_id,
        }
        return data.__repr__()
class _TensorAPI(_LiteCObjBase):
    """
    get the api from the lib
    """

    _api_ = [
        ("LITE_make_tensor", [_LiteTensorDesc, POINTER(_Ctensor)]),
        ("LITE_set_tensor_layout", [_Ctensor, LiteLayout]),
        ("LITE_reset_tensor_memory", [_Ctensor, c_void_p, c_size_t]),
        ("LITE_reset_tensor", [_Ctensor, LiteLayout, c_void_p]),
        ("LITE_tensor_reshape", [_Ctensor, POINTER(c_int), c_int]),
        (
            "LITE_tensor_slice",
            [
                _Ctensor,
                POINTER(c_size_t),
                POINTER(c_size_t),
                POINTER(c_size_t),
                c_size_t,
                POINTER(_Ctensor),
            ],
        ),
        (
            "LITE_tensor_concat",
            [POINTER(_Ctensor), c_int, c_int, c_int, c_int, POINTER(_Ctensor)],
        ),
        ("LITE_tensor_fill_zero", [_Ctensor]),
        ("LITE_tensor_copy", [_Ctensor, _Ctensor]),
        ("LITE_tensor_share_memory_with", [_Ctensor, _Ctensor]),
        ("LITE_get_tensor_memory", [_Ctensor, POINTER(c_void_p)]),
        ("LITE_get_tensor_total_size_in_byte", [_Ctensor, POINTER(c_size_t)]),
        ("LITE_get_tensor_layout", [_Ctensor, POINTER(LiteLayout)]),
        ("LITE_get_tensor_device_type", [_Ctensor, POINTER(c_int)]),
        ("LITE_get_tensor_device_id", [_Ctensor, POINTER(c_int)]),
        ("LITE_destroy_tensor", [_Ctensor]),
        ("LITE_is_pinned_host", [_Ctensor, POINTER(c_int)]),
        # restored entry: LiteTensor.is_continue below calls this symbol,
        # so it must be declared here for the loader to bind it
        ("LITE_is_memory_continue", [_Ctensor, POINTER(c_int)]),
    ]
class LiteTensor(object):
    """
    the tensor to hold a block of data
    """

    _api = _TensorAPI()._lib

    def __init__(
        self,
        layout=None,
        device_type=LiteDeviceType.LITE_CPU,
        device_id=0,
        is_pinned_host=False,
        shapes=None,
        dtype=None,
    ):
        """
        create a Tensor from a layout, or from shapes and dtype, together
        with the device_type, device_id and is_pinned_host params
        """
        self._tensor = _Ctensor()
        self._layout = LiteLayout()
        if layout is not None:
            self._layout = layout
        elif shapes is not None:
            shapes = list(shapes)
            self._layout = LiteLayout(shapes, dtype)
        self._device_type = device_type
        self._device_id = device_id
        self._is_pinned_host = is_pinned_host

        tensor_desc = _LiteTensorDesc()
        tensor_desc.layout = self._layout
        tensor_desc.device_type = device_type
        tensor_desc.device_id = device_id
        tensor_desc.is_pinned_host = is_pinned_host
        self._api.LITE_make_tensor(tensor_desc, byref(self._tensor))
        self.update()

    def __del__(self):
        self._api.LITE_destroy_tensor(self._tensor)

    def fill_zero(self):
        """
        fill the buffer memory with zero
        """
        self._api.LITE_tensor_fill_zero(self._tensor)
        self.update()

    def share_memory_with(self, src_tensor):
        """
        share the same memory with the src_tensor; the self memory will be
        freed
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_share_memory_with(self._tensor, src_tensor._tensor)
        self.update()

    @property
    def layout(self):
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        return self._layout

    @layout.setter
    def layout(self, layout):
        if isinstance(layout, LiteLayout):
            self._layout = layout
        elif isinstance(layout, list):
            self._layout.shapes = layout
        self._api.LITE_set_tensor_layout(self._tensor, self._layout)

    @property
    def is_pinned_host(self):
        """
        whether the tensor is a pinned tensor
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned
        return bool(self._is_pinned_host)

    @property
    def device_type(self):
        """
        get the device type of the tensor
        """
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type
        return LiteDeviceType(device_type.value)

    @property
    def device_id(self):
        """
        get the device id of the tensor
        """
        device_id = c_int()
        self._api.LITE_get_tensor_device_id(self._tensor, byref(device_id))
        self._device_id = device_id.value
        return device_id.value

    @property
    def is_continue(self):
        """
        whether the tensor memory is contiguous
        """
        is_continue = c_int()
        self._api.LITE_is_memory_continue(self._tensor, byref(is_continue))
        return bool(is_continue.value)

    @property
    def nbytes(self):
        """
        get the length of the memory in bytes
        """
        length = c_size_t()
        self._api.LITE_get_tensor_total_size_in_byte(self._tensor, byref(length))
        return length.value

    def update(self):
        """
        update the members from C; this is called automatically after slice
        and share
        """
        pinned = c_int()
        self._api.LITE_is_pinned_host(self._tensor, byref(pinned))
        self._is_pinned_host = pinned
        device_type = c_int()
        self._api.LITE_get_tensor_device_type(self._tensor, byref(device_type))
        self._device_type = device_type
        self._api.LITE_get_tensor_layout(self._tensor, byref(self._layout))
        c_types = _lite_dtypes_to_ctype[self._layout.data_type]
        self.np_array_type = np.ctypeslib._ctype_ndarray(
            c_types, list(self._layout.shapes)[0 : self._layout.ndim]
        )

    def copy_from(self, src_tensor):
        """
        copy memory from the src_tensor
        """
        assert isinstance(src_tensor, LiteTensor)
        self._api.LITE_tensor_copy(self._tensor, src_tensor._tensor)
        self.update()

    def reshape(self, shape):
        """
        reshape the tensor, changing only the shape, not the data

        :param shape: int array of the destination shape
        """
        shape = list(shape)
        length = len(shape)
        c_shape = (c_int * length)(*shape)
        self._api.LITE_tensor_reshape(self._tensor, c_shape, length)
        self.update()

    def slice(self, start, end, step=None):
        """
        slice the tensor with the given start, end and step

        :param start: slice begin index of each dim
        :param end: slice end index of each dim
        :param step: slice step of each dim
        """
        start = list(start)
        end = list(end)
        length = len(start)
        assert length == len(end), "slice with different length of start and end."
        if step:
            assert length == len(step), "slice with different length of start and step."
            step = list(step)
        else:
            step = [1 for i in range(length)]
        c_start = (c_size_t * length)(*start)
        c_end = (c_size_t * length)(*end)
        c_step = (c_size_t * length)(*step)
        slice_tensor = LiteTensor()
        self._api.LITE_tensor_slice(
            self._tensor, c_start, c_end, c_step, length, byref(slice_tensor._tensor)
        )
        slice_tensor.update()
        return slice_tensor

    def get_ctypes_memory(self):
        """
        get the memory of the tensor, returned as a c_void_p to the tensor
        memory
        """
        mem = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(mem))
        return mem

    def set_data_by_share(self, data, length=0, layout=None):
        """
        share the data with the tensor

        :param data: the data to be shared with the tensor; it should be a
            numpy.ndarray or ctypes data
        """
        if isinstance(data, np.ndarray):
            assert (
                self.is_continue
            ), "set_data_by_share can only be used on a contiguous tensor."
            assert (
                self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
            ), "set_data_by_share can only be used on a cpu tensor or pinned tensor."

            c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]

            if self.nbytes != data.nbytes:
                self.layout = LiteLayout(data.shape, ctype_to_lite_dtypes[c_type])

            self._shared_data = data
            data = data.ctypes.data_as(POINTER(c_type))

        if layout is not None:
            self.layout = layout
        else:
            assert length == 0 or length == self.nbytes, "the data length does not match."

        self._api.LITE_reset_tensor_memory(self._tensor, data, self.nbytes)

    def set_data_by_copy(self, data, data_length=0, layout=None):
        """
        copy the data into the tensor; the memory of the tensor must be
        contiguous

        :param data: the data to copy into the tensor; it should be a list,
            a numpy.ndarray, or ctypes data with a length
        """
        if layout is not None:
            self.layout = layout

        assert self.is_continue, "set_data_by_copy can only be used on a contiguous tensor."

        c_type = _lite_dtypes_to_ctype[LiteDataType(self._layout.data_type)]

        cpu_tensor = LiteTensor(self._layout)
        tensor_length = self.nbytes

        if type(data) == list:
            length = len(data)
            assert (
                length * sizeof(c_type) <= tensor_length
            ), "the length of input data to set to the tensor is too large."
            cdata = (c_type * length)(*data)
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, tensor_length)
            self.copy_from(cpu_tensor)
        elif type(data) == np.ndarray:
            self.layout = LiteLayout(data.shape, data.dtype)
            cpu_tensor.layout = LiteLayout(data.shape, data.dtype)
            cdata = data.ctypes.data_as(POINTER(c_type))
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, cdata, self.nbytes)
            self.copy_from(cpu_tensor)
        else:
            assert (
                data_length == self.nbytes or layout is not None
            ), "when the input data is ctypes, the length of the input data or the layout must be set"
            self._api.LITE_reset_tensor_memory(cpu_tensor._tensor, data, tensor_length)
            self.copy_from(cpu_tensor)

    def get_data_by_share(self):
        """
        get the data in the tensor by sharing it with a new numpy array, and
        return that numpy array; be careful, the data in the numpy array is
        only valid until the tensor memory is written again, e.g. by the
        next LiteNetwork forward
        """
        self.update()
        buffer = c_void_p()
        self._api.LITE_get_tensor_memory(self._tensor, byref(buffer))
        buffer = self.np_array_type.from_address(buffer.value)
        return np.ctypeslib.as_array(buffer)

    def to_numpy(self):
        """
        get the buffer of the tensor as a numpy array
        """
        self.update()
        if self.nbytes <= 0:
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            return np.array([], dtype=np_type)
        if self.is_continue and (
            self.is_pinned_host or self.device_type == LiteDeviceType.LITE_CPU
        ):
            ptr = c_void_p()
            self._api.LITE_get_tensor_memory(self._tensor, byref(ptr))
            np_type = _lite_type_to_nptypes[LiteDataType(self._layout.data_type)]
            shape = [self._layout.shapes[i] for i in range(self._layout.ndim)]
            np_arr = np.zeros(shape, np_type)
            if np_arr.nbytes:
                memmove(np_arr.ctypes.data_as(c_void_p), ptr, np_arr.nbytes)
            return np_arr
        else:
            # non-contiguous or device tensor: copy through a host tensor first
            tmp_tensor = LiteTensor(self.layout)
            tmp_tensor.copy_from(self)
            return tmp_tensor.to_numpy()

    def __repr__(self):
        self.update()
        data = {
            "layout": self._layout,
            "device_type": LiteDeviceType(self._device_type.value),
            "device_id": int(self.device_id),
            "is_pinned_host": bool(self._is_pinned_host),
        }
        return data.__repr__()
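
# A minimal end-to-end sketch of LiteTensor on CPU (illustrative comment;
# assumes the lite shared library has been loaded by _TensorAPI):
#
#     t = LiteTensor(LiteLayout([2, 4], "float32"))
#     t.set_data_by_copy(np.arange(8, dtype=np.float32).reshape(2, 4))
#     arr = t.to_numpy()             # detached copy of the buffer
#     view = t.get_data_by_share()   # numpy view; valid only until the
#                                    # tensor memory is rewritten
#     row = t.slice([0, 0], [1, 4])  # sub-tensor holding the first row
#     t.fill_zero()                  # clear the buffer in place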
def LiteTensorConcat(
    tensors, dim, device_type=LiteDeviceType.LITE_DEVICE_DEFAULT, device_id=-1
):
    """
    concat the input tensors along the given dim into one tensor

    :param dim: the dim along which to concat
    :param device_type: the device type of the result tensor
    :param device_id: the device id of the result tensor
    """
    api = _TensorAPI()._lib
    length = len(tensors)
    c_tensors = [t._tensor for t in tensors]
    c_tensors = (_Ctensor * length)(*c_tensors)
    result_tensor = LiteTensor()
    api.LITE_tensor_concat(
        cast(byref(c_tensors), POINTER(c_void_p)),
        length,
        dim,
        device_type,
        device_id,
        byref(result_tensor._tensor),
    )
    result_tensor.update()
    return result_tensor
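
# Illustrative sketch (as a comment) of concatenating two CPU tensors along
# dim 0; both inputs are assumed to already hold data of the same dtype:
#
#     a = LiteTensor(LiteLayout([1, 4], "float32"))
#     b = LiteTensor(LiteLayout([1, 4], "float32"))
#     a.fill_zero()
#     b.fill_zero()
#     c = LiteTensorConcat([a, b], 0)  # result layout shapes == [2, 4]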
def lite_dtype_2_numpy(dtype):
    """
    convert lite dtype to corresponding numpy dtype
    """
    assert isinstance(
        dtype, LiteDataType
    ), "input must be LiteDataType when using lite_dtype_2_numpy."
    return _lite_type_to_nptypes[dtype]
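
# Illustrative sketch of the dtype helper, following the mapping table at the
# top of this module:
#
#     assert lite_dtype_2_numpy(LiteDataType.LITE_FLOAT) is np.float32
#     assert lite_dtype_2_numpy(LiteDataType.LITE_INT8) is np.int8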