You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

small_vector.h 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. /**
  2. * \file dnn/include/megdnn/thin/small_vector.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. //===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===//
  12. //
  13. // The LLVM Compiler Infrastructure
  14. //
  15. // This file is distributed under the University of Illinois Open Source
  16. // License. See LICENSE.TXT for details.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. //
  20. // This file defines the SmallVector class.
  21. //
  22. //===----------------------------------------------------------------------===//
  23. /**
  24. * \file include/megdnn/thin/small_vector.h
  25. *
  26. * This file is part of MegDNN, a deep neural network run-time library
  27. * developed by Megvii.
  28. *
  29. * \brief thin megdnn function
  30. *
  31. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  32. */
  33. #pragma once
  34. #include "megdnn/arch.h"
  35. #include <algorithm>
  36. #include <cstdlib>
  37. #include <cstring>
  38. #include <iterator>
  39. #include <limits>
  40. #include <memory>
  41. #include <type_traits>
  42. #include "megdnn/internal/visibility_prologue.h"
  43. namespace megdnn {
  44. class SmallVectorBase {
  45. protected:
  46. void *m_begin_ptr, *m_end_ptr, *m_capacity_ptr;
  47. MGE_WIN_DECLSPEC_FUC MEGDNN_NORETURN static void on_invalid_at(
  48. size_t idx, size_t size);
  49. protected:
  50. SmallVectorBase(void* first_elm, size_t size)
  51. : m_begin_ptr(first_elm),
  52. m_end_ptr(first_elm),
  53. m_capacity_ptr(static_cast<char*>(first_elm) + size) {}
  54. MGE_WIN_DECLSPEC_FUC void grow_pod(
  55. void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size);
  56. public:
  57. size_t size_in_bytes() const {
  58. return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr));
  59. }
  60. size_t capacity_in_bytes() const {
  61. return size_t(
  62. static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr));
  63. }
  64. bool empty() const { return m_begin_ptr == m_end_ptr; }
  65. };
  66. template <typename T, typename = void>
  67. class SmallVectorTemplateCommon : public SmallVectorBase {
  68. private:
  69. template <typename, unsigned>
  70. friend struct SmallVectorStorage;
  71. using U = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
  72. U m_first_elm;
  73. protected:
  74. SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {}
  75. void grow_pod(size_t min_sz_in_bytes, size_t type_size) {
  76. SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size);
  77. }
  78. bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); }
  79. void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; }
  80. void set_end(T* p) { m_end_ptr = p; }
  81. public:
  82. using size_type = size_t;
  83. using difference_type = std::ptrdiff_t;
  84. using value_type = T;
  85. using iterator = T*;
  86. using const_iterator = const T*;
  87. using reverse_iterator = std::reverse_iterator<iterator>;
  88. using const_reverse_iterator = std::reverse_iterator<const_iterator>;
  89. using reference = T&;
  90. using const_reference = const T&;
  91. using pointer = T*;
  92. using const_pointer = const T*;
  93. size_t capacity() const { return capacity_ptr() - begin(); }
  94. protected:
  95. iterator capacity_ptr() { return static_cast<iterator>(m_capacity_ptr); }
  96. const_iterator capacity_ptr() const {
  97. return static_cast<const_iterator>(m_capacity_ptr);
  98. }
  99. public:
  100. // forwarding iterator creation
  101. iterator begin() { return static_cast<iterator>(m_begin_ptr); }
  102. const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); }
  103. const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); }
  104. iterator end() { return static_cast<iterator>(m_end_ptr); }
  105. const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); }
  106. const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); }
  107. reference at(size_type idx) {
  108. if (idx >= size()) {
  109. on_invalid_at(idx, size());
  110. }
  111. return begin()[idx];
  112. }
  113. const_reference at(size_type idx) const {
  114. if (idx >= size()) {
  115. on_invalid_at(idx, size());
  116. }
  117. return begin()[idx];
  118. }
  119. reference operator[](size_type idx) { return begin()[idx]; }
  120. const_reference operator[](size_type idx) const { return begin()[idx]; }
  121. reference front() { return begin()[0]; }
  122. const_reference front() const { return begin()[0]; }
  123. reference back() { return rbegin()[0]; }
  124. const_reference back() const { return rbegin()[0]; }
  125. // reverse iterator creation method.
  126. reverse_iterator rbegin() { return reverse_iterator(end()); }
  127. const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); }
  128. reverse_iterator rend() { return reverse_iterator(begin()); }
  129. const_reverse_iterator rend() const { return const_reverse_iterator(begin()); }
  130. pointer data() { return pointer(begin()); }
  131. const_pointer data() const { return const_pointer(begin()); }
  132. size_type size() const { return end() - begin(); }
  133. size_type max_size() const {
  134. return std::numeric_limits<size_type>::max() / sizeof(T);
  135. }
  136. template <typename in_iter>
  137. in_iter find(in_iter first, in_iter last, const T& value) const {
  138. while (first != last) {
  139. if (*first == value)
  140. return first;
  141. ++first;
  142. }
  143. return last;
  144. }
  145. };
  146. template <typename T, bool is_pod>
  147. class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> {
  148. protected:
  149. SmallVectorTemplateBase(size_t size) : SmallVectorTemplateCommon<T>(size) {}
  150. static void destroy_range(T* start, T* end) {
  151. while (start != end) {
  152. --end;
  153. end->~T();
  154. }
  155. }
  156. template <typename It1, typename It2>
  157. static void uninitialized_move(It1 first, It1 last, It2 dest) {
  158. std::uninitialized_copy(
  159. std::make_move_iterator(first), std::make_move_iterator(last), dest);
  160. }
  161. template <typename It1, typename It2>
  162. static void uninitialized_copy(It1 first, It1 last, It2 dest) {
  163. std::uninitialized_copy(first, last, dest);
  164. }
  165. void grow(size_t min_sz = 0);
  166. public:
  167. void push_back(const T& _elm) {
  168. if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
  169. T elm = _elm;
  170. this->grow();
  171. new (static_cast<void*>(this->end())) T(std::move(elm));
  172. } else {
  173. new (static_cast<void*>(this->end())) T(_elm);
  174. }
  175. this->set_end(this->end() + 1);
  176. }
  177. void push_back(T&& elm) {
  178. if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
  179. this->grow();
  180. }
  181. new (static_cast<void*>(this->end())) T(std::move(elm));
  182. this->set_end(this->end() + 1);
  183. }
  184. void pop_back() {
  185. this->set_end(this->end() - 1);
  186. this->end()->~T();
  187. }
  188. };
  189. template <typename T, bool is_pod>
  190. void SmallVectorTemplateBase<T, is_pod>::grow(size_t min_sz) {
  191. size_t cur_capacity = this->capacity();
  192. size_t cur_sz = this->size();
  193. size_t new_capacity = (cur_capacity + 2) * 2;
  194. if (new_capacity < min_sz) {
  195. new_capacity = min_sz;
  196. }
  197. T* elms = static_cast<T*>(malloc(new_capacity * sizeof(T)));
  198. this->uninitialized_move(this->begin(), this->end(), elms);
  199. this->destroy_range(this->begin(), this->end());
  200. if (!this->is_small()) {
  201. free(this->begin());
  202. }
  203. this->m_begin_ptr = elms;
  204. this->set_end(elms + cur_sz);
  205. this->m_capacity_ptr = this->begin() + new_capacity;
  206. }
  207. template <typename T>
  208. class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
  209. protected:
  210. SmallVectorTemplateBase(size_t size) : SmallVectorTemplateCommon<T>(size) {}
  211. static void destroy_range(T*, T*) {}
  212. template <typename It1, typename It2>
  213. static void uninitialized_move(It1 first, It1 last, It2 dest) {
  214. uninitialized_copy(first, last, dest);
  215. }
  216. template <typename It1, typename It2>
  217. static void uninitialized_copy(It1 first, It1 last, It2 dest) {
  218. std::uninitialized_copy(first, last, dest);
  219. }
  220. template <typename T1, typename T2>
  221. static void uninitialized_copy(
  222. T1* first, T1* last, T2* dest,
  223. typename std::enable_if<std::is_same<
  224. typename std::remove_const<T1>::type, T2>::value>::type* =
  225. nullptr) {
  226. if (first != last)
  227. memcpy(dest, first, (last - first) * sizeof(T));
  228. }
  229. void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); }
  230. public:
  231. void push_back(const T& _elm) {
  232. if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
  233. T elm = _elm;
  234. this->grow();
  235. memcpy(this->end(), &elm, sizeof(T));
  236. } else {
  237. memcpy(this->end(), &_elm, sizeof(T));
  238. }
  239. this->set_end(this->end() + 1);
  240. }
  241. void pop_back() { this->set_end(this->end() - 1); }
  242. };
  243. /*!
  244. * \brief the implementation class of SmallVector
  245. *
  246. * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N
  247. */
  248. template <typename T>
  249. class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> {
  250. using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>;
  251. public:
  252. using iterator = typename SuperClass::iterator;
  253. using const_iterator = typename SuperClass::const_iterator;
  254. using size_type = typename SuperClass::size_type;
  255. protected:
  256. explicit SmallVectorImpl(unsigned n)
  257. : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {}
  258. public:
  259. SmallVectorImpl(const SmallVectorImpl&) = delete;
  260. ~SmallVectorImpl() {
  261. this->destroy_range(this->begin(), this->end());
  262. if (!this->is_small())
  263. free(this->begin());
  264. }
  265. void clear() {
  266. this->destroy_range(this->begin(), this->end());
  267. this->m_end_ptr = this->m_begin_ptr;
  268. }
  269. void resize(size_type n) {
  270. if (n < this->size()) {
  271. this->destroy_range(this->begin() + n, this->end());
  272. this->set_end(this->begin() + n);
  273. } else if (n > this->size()) {
  274. if (this->capacity() < n)
  275. this->grow(n);
  276. for (auto it = this->end(), end = this->begin() + n; it != end; ++it)
  277. new (&*it) T();
  278. this->set_end(this->begin() + n);
  279. }
  280. }
  281. void resize(size_type n, const T& _nv) {
  282. T nv = _nv;
  283. if (n < this->size()) {
  284. this->destroy_range(this->begin() + n, this->end());
  285. this->set_end(this->begin() + n);
  286. } else if (n > this->size()) {
  287. if (this->capacity() < n)
  288. this->grow(n);
  289. std::uninitialized_fill(this->end(), this->begin() + n, nv);
  290. this->set_end(this->begin() + n);
  291. }
  292. }
  293. void reserve(size_type n) {
  294. if (this->capacity() < n) {
  295. this->grow(n);
  296. }
  297. }
  298. T pop_back_val() {
  299. T result = std::move(this->back());
  300. this->pop_back();
  301. return result;
  302. }
  303. void swap(SmallVectorImpl<T>& rhs);
  304. /// Add the specified range to the end of the SmallVector.
  305. template <
  306. typename in_iter,
  307. typename = typename std::enable_if<std::is_convertible<
  308. typename std::iterator_traits<in_iter>::iterator_category,
  309. std::input_iterator_tag>::value>::type>
  310. void append(in_iter in_start, in_iter in_end) {
  311. size_type num_inputs = std::distance(in_start, in_end);
  312. // Grow allocated space if needed.
  313. if (num_inputs > size_type(this->capacity_ptr() - this->end()))
  314. this->grow(this->size() + num_inputs);
  315. // Copy the new elements over.
  316. this->uninitialized_copy(in_start, in_end, this->end());
  317. this->set_end(this->end() + num_inputs);
  318. }
  319. /// Add the specified range to the end of the SmallVector.
  320. void append(size_type num_inputs, const T& _elm) {
  321. T elm = _elm;
  322. // Grow allocated space if needed.
  323. if (num_inputs > size_type(this->capacity_ptr() - this->end()))
  324. this->grow(this->size() + num_inputs);
  325. // Copy the new elements over.
  326. std::uninitialized_fill_n(this->end(), num_inputs, elm);
  327. this->set_end(this->end() + num_inputs);
  328. }
  329. void append(std::initializer_list<T> init_list) {
  330. append(init_list.begin(), init_list.end());
  331. }
  332. // FIXME: Consider assigning over existing elements, rather than clearing &
  333. // re-initializing them - for all assign(...) variants.
  334. void assign(size_type num_elms, const T& _elm) {
  335. T elm = _elm;
  336. clear();
  337. if (this->capacity() < num_elms)
  338. this->grow(num_elms);
  339. this->set_end(this->begin() + num_elms);
  340. std::uninitialized_fill(this->begin(), this->end(), elm);
  341. }
  342. template <
  343. typename in_iter,
  344. typename = typename std::enable_if<std::is_convertible<
  345. typename std::iterator_traits<in_iter>::iterator_category,
  346. std::input_iterator_tag>::value>::type>
  347. void assign(in_iter in_start, in_iter in_end) {
  348. clear();
  349. append(in_start, in_end);
  350. }
  351. void assign(std::initializer_list<T> init_list) {
  352. clear();
  353. append(init_list);
  354. }
  355. iterator erase(const_iterator cit) {
  356. // Just cast away constness because this is a non-const member function.
  357. iterator it = const_cast<iterator>(cit);
  358. iterator n = it;
  359. // Shift all elms down one.
  360. std::move(it + 1, this->end(), it);
  361. // Drop the last elm.
  362. this->pop_back();
  363. return (n);
  364. }
  365. iterator erase(const_iterator c_first, const_iterator c_last) {
  366. // Just cast away constness because this is a non-const member function.
  367. iterator first = const_cast<iterator>(c_first);
  368. iterator last = const_cast<iterator>(c_last);
  369. iterator n = first;
  370. // Shift all elms down.
  371. iterator it = std::move(last, this->end(), first);
  372. // Drop the last elms.
  373. this->destroy_range(it, this->end());
  374. this->set_end(it);
  375. return (n);
  376. }
  377. iterator insert(iterator it, T&& elm) {
  378. if (it == this->end()) { // Important special case for empty vector.
  379. this->push_back(std::move(elm));
  380. return this->end() - 1;
  381. }
  382. if (this->m_end_ptr >= this->m_capacity_ptr) {
  383. size_t elm_idx = it - this->begin();
  384. this->grow();
  385. it = this->begin() + elm_idx;
  386. }
  387. new (static_cast<void*>(this->end())) T(std::move(this->back()));
  388. // Push everything else over.
  389. std::move_backward(it, this->end() - 1, this->end());
  390. this->set_end(this->end() + 1);
  391. // If we just moved the element we're inserting, be sure to update
  392. // the reference.
  393. T* elm_ptr = &elm;
  394. if (it <= elm_ptr && elm_ptr < this->m_end_ptr)
  395. ++elm_ptr;
  396. *it = std::move(*elm_ptr);
  397. return it;
  398. }
  399. iterator insert(iterator it, const T& _elm) {
  400. if (it == this->end()) { // Important special case for empty vector.
  401. this->push_back(_elm);
  402. return this->end() - 1;
  403. }
  404. T elm = _elm;
  405. if (this->m_end_ptr >= this->m_capacity_ptr) {
  406. size_t elm_idx = it - this->begin();
  407. this->grow();
  408. it = this->begin() + elm_idx;
  409. }
  410. new (static_cast<void*>(this->end())) T(std::move(this->back()));
  411. // Push everything else over.
  412. std::move_backward(it, this->end() - 1, this->end());
  413. this->set_end(this->end() + 1);
  414. // If we just moved the element we're inserting, be sure to update
  415. // the reference.
  416. const T* elm_ptr = &elm;
  417. if (it <= elm_ptr && elm_ptr < this->m_end_ptr)
  418. ++elm_ptr;
  419. *it = *elm_ptr;
  420. return it;
  421. }
  422. iterator insert(iterator it, size_type num_to_insert, const T& _elm) {
  423. // Convert iterator to elm# to avoid invalidating iterator
  424. // when we reserve()
  425. size_t elm_idx = it - this->begin();
  426. if (it == this->end()) { // Important special case for empty vector.
  427. append(num_to_insert, _elm);
  428. return this->begin() + elm_idx;
  429. }
  430. T elm = _elm;
  431. // Ensure there is enough space.
  432. reserve(this->size() + num_to_insert);
  433. // Uninvalidate the iterator.
  434. it = this->begin() + elm_idx;
  435. // If there are more elements between the insertion point and
  436. // the end of the range than there are being inserted,
  437. // we can use a simple approach to insertion.
  438. // Since we already reserved space, we know that this won't
  439. // reallocate the vector.
  440. if (size_t(this->end() - it) >= num_to_insert) {
  441. T* old_end = this->end();
  442. append(std::move_iterator<iterator>(this->end() - num_to_insert),
  443. std::move_iterator<iterator>(this->end()));
  444. // Copy the existing elements that get replaced.
  445. std::move_backward(it, old_end - num_to_insert, old_end);
  446. std::fill_n(it, num_to_insert, elm);
  447. return it;
  448. }
  449. // Otherwise, we're inserting more elements than exist already,
  450. // and we're not inserting at the end.
  451. // Move over the elements that we're about to overwrite.
  452. T* old_end = this->end();
  453. this->set_end(this->end() + num_to_insert);
  454. size_t num_overwritten = old_end - it;
  455. this->uninitialized_move(it, old_end, this->end() - num_overwritten);
  456. // Replace the overwritten part.
  457. std::fill_n(it, num_overwritten, elm);
  458. // Insert the non-overwritten middle part.
  459. std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm);
  460. return it;
  461. }
  462. template <
  463. typename IterType,
  464. typename = typename std::enable_if<std::is_convertible<
  465. typename std::iterator_traits<IterType>::iterator_category,
  466. std::input_iterator_tag>::value>::type>
  467. iterator insert(iterator it, IterType from, IterType to) {
  468. // Convert iterator to elm# to avoid invalidating iterator
  469. // when we reserve()
  470. size_t elm_idx = it - this->begin();
  471. if (it == this->end()) { // Important special case for empty vector.
  472. append(from, to);
  473. return this->begin() + elm_idx;
  474. }
  475. size_t num_to_insert = std::distance(from, to);
  476. // Ensure there is enough space.
  477. reserve(this->size() + num_to_insert);
  478. // Uninvalidate the iterator.
  479. it = this->begin() + elm_idx;
  480. // If there are more elements between the insertion point and
  481. // the end of the range than there are being inserted,
  482. // we can use a simple approach to insertion.
  483. // Since we already reserved space, we know that this won't
  484. // reallocate the vector.
  485. if (size_t(this->end() - it) >= num_to_insert) {
  486. T* old_end = this->end();
  487. append(std::move_iterator<iterator>(this->end() - num_to_insert),
  488. std::move_iterator<iterator>(this->end()));
  489. // Copy the existing elements that get replaced.
  490. std::move_backward(it, old_end - num_to_insert, old_end);
  491. std::copy(from, to, it);
  492. return it;
  493. }
  494. // Otherwise, we're inserting more elements than exist already,
  495. // and we're not inserting at the end.
  496. // Move over the elements that we're about to overwrite.
  497. T* old_end = this->end();
  498. this->set_end(this->end() + num_to_insert);
  499. size_t num_overwritten = old_end - it;
  500. this->uninitialized_move(it, old_end, this->end() - num_overwritten);
  501. // Replace the overwritten part.
  502. for (T* iter = it; num_overwritten > 0; --num_overwritten) {
  503. *iter = *from;
  504. ++iter;
  505. ++from;
  506. }
  507. // Insert the non-overwritten middle part.
  508. this->uninitialized_copy(from, to, old_end);
  509. return it;
  510. }
  511. void insert(iterator it, std::initializer_list<T> init_list) {
  512. insert(it, init_list.begin(), init_list.end());
  513. }
  514. template <typename... ArgTypes>
  515. void emplace_back(ArgTypes&&... args) {
  516. if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) {
  517. this->grow();
  518. }
  519. new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...);
  520. this->set_end(this->end() + 1);
  521. }
  522. SmallVectorImpl& operator=(const SmallVectorImpl& rhs);
  523. SmallVectorImpl& operator=(SmallVectorImpl&& rhs);
  524. bool operator==(const SmallVectorImpl<T>& rhs) const {
  525. if (this->size() != rhs.size())
  526. return false;
  527. return std::equal(this->begin(), this->end(), rhs.begin());
  528. }
  529. bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); }
  530. bool operator<(const SmallVectorImpl<T>& rhs) const {
  531. return std::lexicographical_compare(
  532. this->begin(), this->end(), rhs.begin(), rhs.end());
  533. }
  534. };
  535. template <typename T>
  536. void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) {
  537. if (this == &rhs)
  538. return;
  539. // We can only avoid copying elements if neither vector is small.
  540. if (!this->is_small() && !rhs.is_small()) {
  541. std::swap(this->m_begin_ptr, rhs.m_begin_ptr);
  542. std::swap(this->m_end_ptr, rhs.m_end_ptr);
  543. std::swap(this->m_capacity_ptr, rhs.m_capacity_ptr);
  544. return;
  545. }
  546. if (rhs.size() > this->capacity())
  547. this->grow(rhs.size());
  548. if (this->size() > rhs.capacity())
  549. rhs.grow(this->size());
  550. // Swap the shared elements.
  551. size_t num_shared = this->size();
  552. if (num_shared > rhs.size())
  553. num_shared = rhs.size();
  554. for (size_type i = 0; i != num_shared; ++i)
  555. std::swap((*this)[i], rhs[i]);
  556. // Copy over the extra elms.
  557. if (this->size() > rhs.size()) {
  558. size_t elm_diff = this->size() - rhs.size();
  559. this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end());
  560. rhs.set_end(rhs.end() + elm_diff);
  561. this->destroy_range(this->begin() + num_shared, this->end());
  562. this->set_end(this->begin() + num_shared);
  563. } else if (rhs.size() > this->size()) {
  564. size_t elm_diff = rhs.size() - this->size();
  565. this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end());
  566. this->set_end(this->end() + elm_diff);
  567. this->destroy_range(rhs.begin() + num_shared, rhs.end());
  568. rhs.set_end(rhs.begin() + num_shared);
  569. }
  570. }
  571. template <typename T>
  572. SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) {
  573. if (this == &rhs)
  574. return *this;
  575. size_t rhs_sz = rhs.size();
  576. size_t cur_sz = this->size();
  577. if (cur_sz >= rhs_sz) {
  578. iterator new_end;
  579. if (rhs_sz) {
  580. new_end = std::copy(rhs.begin(), rhs.end(), this->begin());
  581. } else {
  582. new_end = this->begin();
  583. }
  584. this->destroy_range(new_end, this->end());
  585. this->set_end(new_end);
  586. return *this;
  587. }
  588. if (this->capacity() < rhs_sz) {
  589. // save time for no copy when growing
  590. this->destroy_range(this->begin(), this->end());
  591. this->set_end(this->begin());
  592. cur_sz = 0;
  593. this->grow(rhs_sz);
  594. } else if (cur_sz) {
  595. std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin());
  596. }
  597. std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
  598. this->set_end(this->begin() + rhs_sz);
  599. return *this;
  600. }
  601. template <typename T>
  602. SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) {
  603. // avoid self assignment
  604. if (this == &rhs)
  605. return *this;
  606. // copy ptr when rhs is small
  607. if (!rhs.is_small()) {
  608. this->destroy_range(this->begin(), this->end());
  609. if (!this->is_small())
  610. free(this->begin());
  611. this->m_begin_ptr = rhs.m_begin_ptr;
  612. this->m_end_ptr = rhs.m_end_ptr;
  613. this->m_capacity_ptr = rhs.m_capacity_ptr;
  614. rhs.reset_to_small();
  615. return *this;
  616. }
  617. size_t rhs_sz = rhs.size();
  618. size_t cur_sz = this->size();
  619. if (cur_sz >= rhs_sz) {
  620. iterator new_end = this->begin();
  621. if (rhs_sz) {
  622. new_end = std::move(rhs.begin(), rhs.end(), new_end);
  623. }
  624. this->destroy_range(new_end, this->end());
  625. this->set_end(new_end);
  626. rhs.clear();
  627. return *this;
  628. }
  629. if (this->capacity() < rhs_sz) {
  630. this->destroy_range(this->begin(), this->end());
  631. this->set_end(this->begin());
  632. cur_sz = 0;
  633. this->grow(rhs_sz);
  634. } else if (cur_sz) {
  635. std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin());
  636. }
  637. this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz);
  638. this->set_end(this->begin() + rhs_sz);
  639. rhs.clear();
  640. return *this;
  641. }
  642. template <typename T, unsigned N>
  643. struct SmallVectorStorage {
  644. typename SmallVectorTemplateCommon<T>::U inline_elms[N - 1];
  645. };
  646. template <typename T>
  647. struct SmallVectorStorage<T, 1> {};
  648. template <typename T>
  649. struct SmallVectorStorage<T, 0> {};
  650. /*!
  651. * \brief This is a 'vector' (really, a variable-sized array), optimized for the
  652. * case when the array is small.
  653. *
  654. * It contains some number of elements in-place,
  655. * which allows it to avoid heap allocation when the actual number of elements
  656. * is below that threshold. This allows normal "small" cases to be fast without
  657. * losing generality for large inputs.
  658. *
  659. * Note that this does not attempt to be exception safe.
  660. *
  661. * SmallVector<T, N>& can be converted to SmallVectorImpl<T>& to erase the
  662. * template param \p N; this is useful for function params.
  663. *
  664. * \tparam T emelment type
  665. * \tparam N number of elements to be stored in the class object
  666. */
  667. template <typename T, unsigned N = 4>
  668. class SmallVector : public SmallVectorImpl<T> {
  669. SmallVectorStorage<T, N> m_storage;
  670. public:
  671. SmallVector() : SmallVectorImpl<T>(N) {}
  672. explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) {
  673. this->assign(size, value);
  674. }
  675. template <
  676. typename IterType,
  677. typename = typename std::enable_if<std::is_convertible<
  678. typename std::iterator_traits<IterType>::iterator_category,
  679. std::input_iterator_tag>::value>::type>
  680. SmallVector(IterType first, IterType last) : SmallVectorImpl<T>(N) {
  681. this->append(first, last);
  682. }
  683. SmallVector(std::initializer_list<T> init_list) : SmallVectorImpl<T>(N) {
  684. this->assign(init_list);
  685. }
  686. SmallVector(const SmallVector& rhs) : SmallVectorImpl<T>(N) {
  687. if (!rhs.empty())
  688. SmallVectorImpl<T>::operator=(rhs);
  689. }
  690. ~SmallVector() {}
  691. const SmallVector& operator=(const SmallVector& rhs) {
  692. SmallVectorImpl<T>::operator=(rhs);
  693. return *this;
  694. }
  695. SmallVector(SmallVector&& rhs) : SmallVectorImpl<T>(N) {
  696. if (!rhs.empty())
  697. SmallVectorImpl<T>::operator=(std::move(rhs));
  698. }
  699. SmallVector(SmallVectorImpl<T>&& rhs) : SmallVectorImpl<T>(N) {
  700. if (!rhs.empty())
  701. SmallVectorImpl<T>::operator=(std::move(rhs));
  702. }
  703. const SmallVector& operator=(SmallVector&& rhs) {
  704. SmallVectorImpl<T>::operator=(std::move(rhs));
  705. return *this;
  706. }
  707. const SmallVector& operator=(SmallVectorImpl<T>&& rhs) {
  708. SmallVectorImpl<T>::operator=(std::move(rhs));
  709. return *this;
  710. }
  711. const SmallVector& operator=(std::initializer_list<T> init_list) {
  712. this->assign(init_list);
  713. return *this;
  714. }
  715. };
  716. template <typename T, unsigned n>
  717. static inline size_t capacity_in_bytes(const SmallVector<T, n>& vec) {
  718. return vec.capacity_in_bytes();
  719. }
  720. template <typename T>
  721. inline typename SmallVectorImpl<T>::const_iterator find(
  722. const SmallVectorImpl<T>& vec, const T& value) {
  723. return vec.find(vec.begin(), vec.end(), value);
  724. }
  725. } // end namespace megdnn
  726. #include "megdnn/internal/visibility_epilogue.h"
  727. namespace std {
  728. /// Implement std::swap in terms of SmallVector swap.
  729. template <typename T>
  730. inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) {
  731. lhs.swap(rhs);
  732. }
  733. /// Implement std::swap in terms of SmallVector swap.
  734. template <typename T, unsigned N>
  735. inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) {
  736. lhs.swap(rhs);
  737. }
  738. } // end namespace std
  739. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台