
utils.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import threading

import numpy as np

from .base import *
from .struct import *
from .tensor import *


class TensorBatchCollector:
    """
    A utility that collects individual sub-tensors into one contiguous
    batch tensor, filling free batch slots as data arrives.
    """

    def __init__(
        self,
        shape,
        dtype=LiteDataType.LITE_INT8,
        device_type=LiteDeviceType.LITE_CUDA,
        device_id=0,
        is_pinned_host=False,
        tensor=None,
    ):
        self._mutex = threading.Lock()
        self.dev_type = device_type
        self.is_pinned_host = is_pinned_host
        self.dev_id = device_id
        self.shape = shape
        self.dtype = LiteLayout(dtype=dtype).data_type
        # every batch index starts out free
        self._free_list = list(range(self.shape[0]))

        if tensor is not None:
            assert (
                tensor.layout.shapes[0 : tensor.layout.ndim] == shape
            ), "The tensor set to TensorBatchCollector is not right."
            self._tensor = tensor
            # take dtype and device info from the user-provided tensor
            self.dtype = tensor.layout.data_type
            self.dev_type = tensor.device_type
            self.dev_id = tensor.device_id
        else:
            self._tensor = LiteTensor(
                LiteLayout(shape, dtype), device_type, device_id, is_pinned_host
            )

    def collect_id(self, array, batch_id):
        """
        Copy `array` into the batch slot `batch_id` and mark that slot as used.
        """
        if isinstance(array, np.ndarray):
            shape = array.shape
            assert list(shape) == self.shape[1:]
            in_dtype = ctype_to_lite_dtypes[np.ctypeslib.as_ctypes_type(array.dtype)]
            assert in_dtype == self.dtype
            # claim the batch index
            with self._mutex:
                if batch_id in self._free_list:
                    self._free_list.remove(batch_id)
            # get the subtensor for this batch slot
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            if subtensor.device_type == LiteDeviceType.LITE_CPU:
                subtensor.set_data_by_copy(array)
            else:
                # stage the host data in a pinned tensor, then copy to device
                pinned_tensor = LiteTensor(
                    subtensor.layout, self.dev_type, self.dev_id, True
                )
                pinned_tensor.set_data_by_share(array)
                subtensor.copy_from(pinned_tensor)
        else:
            assert isinstance(array, LiteTensor)
            ndim = array.layout.ndim
            shape = list(array.layout.shapes)[0:ndim]
            assert list(shape) == self.shape[1:]
            in_dtype = array.layout.data_type
            assert in_dtype == self.dtype
            # claim the batch index
            with self._mutex:
                if batch_id in self._free_list:
                    self._free_list.remove(batch_id)
            # get the subtensor and copy the LiteTensor into it directly
            subtensor = self._tensor.slice([batch_id], [batch_id + 1])
            subtensor.copy_from(array)
        return batch_id

    def collect(self, array):
        """
        Copy `array` into the next free batch slot; return -1 if the batch is full.
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        return self.collect_id(array, idx)

    def collect_by_ctypes(self, data, length):
        """
        Collect raw ctypes data of `length` bytes into the next free batch slot;
        return -1 if the batch is full.
        """
        with self._mutex:
            if len(self._free_list) == 0:
                return -1
            idx = self._free_list.pop(0)
        # get the subtensor for this batch slot
        subtensor = self._tensor.slice([idx], [idx + 1])
        if subtensor.device_type == LiteDeviceType.LITE_CPU:
            subtensor.set_data_by_copy(data, length)
        else:
            # stage the host data in a pinned tensor, then copy to device
            pinned_tensor = LiteTensor(
                subtensor.layout, self.dev_type, self.dev_id, True
            )
            pinned_tensor.set_data_by_share(data, length)
            subtensor.copy_from(pinned_tensor)

    def free(self, indexes):
        """
        Return the given batch indexes to the free list so they can be reused.
        """
        with self._mutex:
            self._free_list.extend(indexes)

    def get(self):
        return self._tensor

    def to_numpy(self):
        return self._tensor.to_numpy()
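For illustration, here is a minimal usage sketch. It assumes TensorBatchCollector and the Lite enums are importable from the megenginelite package, and the shapes and values are made up; treat it as a sketch rather than a definitive example.

import numpy as np
from megenginelite import LiteDataType, LiteDeviceType
from megenginelite.utils import TensorBatchCollector

# collect four 3x224x224 int8 samples into one contiguous batch on the CPU
collector = TensorBatchCollector(
    shape=[4, 3, 224, 224],
    dtype=LiteDataType.LITE_INT8,
    device_type=LiteDeviceType.LITE_CPU,
)

for i in range(4):
    sample = np.random.randint(-128, 127, size=(3, 224, 224), dtype=np.int8)
    idx = collector.collect(sample)  # fills the next free batch slot
    assert idx == i

batch = collector.to_numpy()   # a (4, 3, 224, 224) ndarray
collector.free(list(range(4)))  # release all slots for reuse

Because collect() hands out slots from a lock-protected free list, several producer threads can each call collect() with their own sample and the batch is assembled without extra synchronization in user code.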

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
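As a quick sanity check before running GPU code, you can ask MegEngine whether a usable CUDA device is visible (a minimal sketch, assuming the megengine package is installed):

import megengine

# True only when MegEngine can see a usable CUDA device and driver
if megengine.is_cuda_available():
    print("CUDA device detected; GPU code can run")
else:
    print("no CUDA device detected; code will run on the CPU")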