
tensor.cpp

/**
 * \file src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite/tensor.h"
#include "function_base.h"
#include "tensor_impl_base.h"
#if LITE_BUILD_WITH_MGE
#include "megbrain/comp_node.h"
#include "megbrain/tensor.h"
#include "mge/function_dft.h"
#include "mge/tensor_impl.h"
#endif
#include <memory>

using namespace lite;
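// Return the size in bytes of a single element of this layout's data type.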
size_t Layout::get_elem_size() const {
    size_t elesize = 1;
    switch (data_type) {
        case LiteDataType::LITE_INT64:
            elesize = 8;
            break;
        case LiteDataType::LITE_FLOAT:
        case LiteDataType::LITE_INT:
        case LiteDataType::LITE_UINT:
            elesize = 4;
            break;
        case LiteDataType::LITE_HALF:
        case LiteDataType::LITE_INT16:
        case LiteDataType::LITE_UINT16:
            elesize = 2;
            break;
        case LiteDataType::LITE_INT8:
        case LiteDataType::LITE_UINT8:
            elesize = 1;
            break;
        default:
            LITE_THROW("unsupported data type.");
    }
    return elesize;
}

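// Two layouts are equal when they have the same ndim, the same data type, and
// identical shapes in every dimension.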
bool Layout::operator==(const Layout& other) const {
    bool equal = true;
    equal &= (ndim == other.ndim);
    equal &= (data_type == other.data_type);
    for (size_t i = 0; i < ndim; i++) {
        equal &= (shapes[i] == other.shapes[i]);
    }
    return equal;
}

Tensor::~Tensor() = default;

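// The constructors below only record the requested device/stream/layout and
// forward construction to the backend implementation (TensorImplDft for the
// default MegEngine backend) through call_func.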
Tensor::Tensor() {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor");
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host), m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(LiteDeviceType device_type, const Layout& layout, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(
        int device_id, LiteDeviceType device_type, const Layout& layout,
        bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_layout(layout),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_id, device_type, layout, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(
        int device_id, int stream_id, LiteDeviceType device_type, bool is_pinned_host)
        : m_is_pinned_host(is_pinned_host),
          m_device_id(device_id),
          m_device_type(device_type) {
    LITE_ERROR_HANDLER_BEGIN
    m_tensor_impl =
            call_func<TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                    "create_tensor", device_id, stream_id, device_type, is_pinned_host);
    LITE_ERROR_HANDLER_END
}

Tensor::Tensor(
        LiteBackend backend, LiteDeviceType device_type, int device_id,
        const Layout& layout, bool is_pinned_host) {
    if (backend == LiteBackend::LITE_DEFAULT) {
        m_tensor_impl = call_func<
                TensorImplDft, std::shared_ptr<lite::Tensor::TensorImplBase>>(
                "create_tensor", device_id, device_type, layout, is_pinned_host);
    } else {
        LITE_MARK_USED_VAR(device_type);
        LITE_MARK_USED_VAR(is_pinned_host);
        LITE_MARK_USED_VAR(layout);
        LITE_MARK_USED_VAR(device_id);
        LITE_THROW(ssprintf(
                "unknown backend, enum id is: %d.", static_cast<int>(backend)));
    }
}

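// Reshape the tensor in place. At most one dimension may be given as -1; its
// extent is inferred so that the total number of elements stays unchanged.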
void Tensor::reshape(const std::vector<int>& shape) {
    LITE_ASSERT(m_layout.ndim > 0, "The tensor to be reshaped is empty.");
    uint32_t length = shape.size();
    LITE_ASSERT(length < Layout::MAXDIM, "The ndim of the reshape input is too large.");
    Layout new_layout = m_layout;
    new_layout.ndim = length;
    size_t total_length = get_tensor_total_size_in_byte() / m_layout.get_elem_size();
    uint32_t unfixed_number = 0;
    uint32_t unfixed_index = 0;
    for (uint32_t i = 0; i < length; i++) {
        if (shape[i] == -1) {
            unfixed_number += 1;
            unfixed_index = i;
        } else {
            LITE_ASSERT(shape[i] > 0, "The reshape inputs are invalid.");
            new_layout.shapes[i] = shape[i];
        }
    }
    LITE_ASSERT(unfixed_number <= 1, "The reshape inputs are invalid.");
    if (unfixed_number) {
        size_t left = total_length;
        for (uint32_t i = 0; i < length; i++) {
            if (i == unfixed_index) {
                continue;
            } else {
                LITE_ASSERT(
                        left > 0 && (left % new_layout.shapes[i] == 0),
                        "The reshape inputs are invalid.");
                left = left / new_layout.shapes[i];
            }
        }
        LITE_ASSERT(left > 0, "The reshape inputs are invalid.");
        new_layout.shapes[unfixed_index] = left;
    }
    size_t new_total = 1;
    for (uint32_t i = 0; i < length; i++) {
        new_total *= new_layout.shapes[i];
    }
    LITE_ASSERT(new_total == total_length, "The reshape inputs are invalid.");
    m_layout = new_layout;
    m_tensor_impl->reshape(m_layout);
}

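// Total size in bytes: the product of all shape extents times the element
// size; an empty (ndim == 0) layout has size zero.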
size_t Tensor::get_tensor_total_size_in_byte() const {
    LITE_ERROR_HANDLER_BEGIN
    size_t elemsize = m_layout.get_elem_size();
    size_t total = m_layout.ndim == 0 ? 0 : 1;
    for (size_t i = 0; i < m_layout.ndim; i++) {
        total *= m_layout.shapes[i];
    }
    return total * elemsize;
    LITE_ERROR_HANDLER_END
}

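// Raw memory access. The first overload returns the base pointer; the second
// returns the address of the element at the given coordinate.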
void* Tensor::get_memory_ptr() const {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            m_layout.ndim != 0, "Tensor layout is not valid when getting the memory ptr.");
    return m_tensor_impl->get_memory_ptr();
    LITE_ERROR_HANDLER_END
}

void* Tensor::get_memory_ptr(const std::vector<size_t>& idx) const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->get_memory_ptr(idx);
    LITE_ERROR_HANDLER_END
}

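// Return a view covering [start, end) with the given step in each dimension;
// the returned tensor shares storage with this one.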
std::shared_ptr<Tensor> Tensor::slice(
        const std::vector<size_t>& start, const std::vector<size_t>& end,
        const std::vector<size_t>& step) {
    LITE_ERROR_HANDLER_BEGIN
    auto ret = m_tensor_impl->slice(start, end, step);
    ret->update_from_implement();
    return ret;
    LITE_ERROR_HANDLER_END
}

void Tensor::fill_zero() {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            m_layout.ndim > 0,
            "fill_zero can't be applied to a tensor with an empty layout.");
    m_tensor_impl->fill_zero();
    LITE_ERROR_HANDLER_END
}

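// Share the underlying storage of src_tensor instead of copying it, then sync
// this tensor's cached metadata from the implementation.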
void Tensor::share_memory_with(const Tensor& src_tensor) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            src_tensor.m_layout.ndim > 0, "The tensor to be shared has an empty layout.");
    m_tensor_impl->share_memory_with(src_tensor.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}

void Tensor::set_layout(const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->set_layout(layout);
    LITE_ERROR_HANDLER_END
}

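// Adopt caller-provided memory. The first overload keeps the current layout
// and requires the buffer to be at least as large as the tensor; the second
// also sets a new layout.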
void Tensor::reset(void* prepared_data, size_t data_length_in_byte) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(m_layout.ndim, "Tensor layout is empty, please reset with a layout.");
    LITE_ASSERT(
            data_length_in_byte >= get_tensor_total_size_in_byte(),
            "the memory given to reset the tensor is too small.");
    m_tensor_impl->reset(prepared_data);
    LITE_ERROR_HANDLER_END
}

void Tensor::reset(void* prepared_data, const Layout& layout) {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = layout;
    m_tensor_impl->reset(prepared_data, layout);
    LITE_ERROR_HANDLER_END
}

bool Tensor::is_continue_memory() const {
    LITE_ERROR_HANDLER_BEGIN
    return m_tensor_impl->is_continue_memory();
    LITE_ERROR_HANDLER_END
}

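// Copy the contents of src into this tensor; the cached metadata is refreshed
// from the implementation after the copy.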
void Tensor::copy_from(const Tensor& src) {
    LITE_ERROR_HANDLER_BEGIN
    LITE_ASSERT(
            src.get_layout().ndim != 0,
            "when copying a tensor, the src tensor layout is empty.");
    m_tensor_impl->copy_from(src.m_tensor_impl.get());
    update_from_implement();
    LITE_ERROR_HANDLER_END
}

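// Refresh the cached layout, device type, device id and pinned-host flag from
// the backend implementation.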
void Tensor::update_from_implement() {
    LITE_ERROR_HANDLER_BEGIN
    m_layout = m_tensor_impl->get_layout();
    m_device_type = m_tensor_impl->get_device_type();
    m_device_id = m_tensor_impl->get_device_id();
    m_is_pinned_host = m_tensor_impl->is_pinned_host();
    LITE_ERROR_HANDLER_END
}

void LiteAny::type_missmatch(size_t expect, size_t get) const {
    LITE_THROW(ssprintf(
            "The type stored in LiteAny does not match the visited type; type of "
            "storage enum is %zu, type of visit enum is %zu.",
            expect, get));
}

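// Map each supported C++ type to its LiteAny::Type enum value via explicit
// template specializations of get_type.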
namespace lite {
#define GET_TYPE(ctype, ENUM)                        \
    template <>                                      \
    LiteAny::Type LiteAny::get_type<ctype>() const { \
        return ENUM;                                 \
    }

GET_TYPE(std::string, STRING)
GET_TYPE(int32_t, INT32)
GET_TYPE(uint32_t, UINT32)
GET_TYPE(int8_t, INT8)
GET_TYPE(uint8_t, UINT8)
GET_TYPE(int64_t, INT64)
GET_TYPE(uint64_t, UINT64)
GET_TYPE(float, FLOAT)
GET_TYPE(bool, BOOL)
GET_TYPE(void*, VOID_PTR)
}  // namespace lite

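// Concatenate tensors along dim on the destination device. All inputs must
// have the same ndim and dtype and identical shapes except along dim; each
// input is copied into the matching slice of the result.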
std::shared_ptr<Tensor> TensorUtils::concat(
        const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
        int dst_device_id) {
    if (tensors.size() <= 0) {
        return std::make_shared<Tensor>();
    }
    if (dst_device == LiteDeviceType::LITE_DEVICE_DEFAULT) {
        dst_device = tensors.front().get_device_type();
    }
    if (dst_device_id == -1) {
        dst_device_id = tensors.front().get_device_id();
    }
    bool is_pinned_host = tensors.front().is_pinned_host();
    auto layout = tensors.front().get_layout();
    LITE_ASSERT(
            static_cast<int>(layout.ndim) > dim, "the dim argument to concat is invalid.");
    size_t sum_in_dim = layout.shapes[dim];
    for (size_t i = 1; i < tensors.size(); ++i) {
        auto other_layout = tensors[i].get_layout();
        LITE_ASSERT(
                other_layout.ndim == layout.ndim,
                "the ndim of the tensors is not the same!");
        LITE_ASSERT(
                other_layout.data_type == layout.data_type,
                "the dtype of the tensors is not the same!");
        for (size_t j = 0; j < other_layout.ndim; ++j) {
            if (dim == static_cast<int>(j)) {
                sum_in_dim += other_layout.shapes[j];
                continue;
            }
            LITE_ASSERT(
                    other_layout.shapes[j] == layout.shapes[j],
                    "the shape of the tensors is not the same!");
        }
    }
    layout.shapes[dim] = sum_in_dim;
    auto result =
            std::make_shared<Tensor>(dst_device_id, dst_device, layout, is_pinned_host);
    size_t index = 0;
    std::vector<size_t> start(dim + 1, 0);
    std::vector<size_t> end(dim + 1, 0);
    for (int i = 0; i < dim; i++) {
        end[i] = layout.shapes[i];
    }
    for (size_t i = 0; i < tensors.size(); ++i) {
        auto&& tensor = tensors[i];
        auto layout = tensor.get_layout();
        if (layout.shapes[dim] == 0)
            continue;
        start[dim] = index;
        end[dim] = index + layout.shapes[dim];
        auto&& sub_dst = result->slice(start, end);
        sub_dst->copy_from(tensor);
        index += layout.shapes[dim];
    }
    return result;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}