
tensor.cpp

/**
 * \file src/core/impl/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/tensor.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/param_defs.h"

#include "megdnn/oprs.h"

#include <thread>
#include <cmath>
#include <cstring>

using namespace mgb;

namespace {
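// The overloads below handle copies where at least one side has a
// non-contiguous layout: when src and dst share a device type, a megdnn
// Relayout op performs the strided copy directly; otherwise the data is staged
// through a contiguous temporary tensor on one side before a final
// copy_from_fixlayout.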
//! implement non-contiguous d2d copy
void noncont_tensor_copy(
        const DeviceTensorND& dest, const DeviceTensorND& src, bool contig_dest,
        bool contig_src) {
    auto src_cn = src.comp_node();
    auto dst_cn = dest.comp_node();
    if (src_cn.device_type() == dst_cn.device_type()) {
        // perform relayout op for better performance when src and dst are
        // placed on comp nodes with the same device type
        auto&& src_env = CompNodeEnv::from_comp_node(src.comp_node());
        auto relayout = opr::intl::get_megdnn_global_opr<megdnn::Relayout>(dst_cn);
        dst_cn.activate();
        relayout->exec(
                const_cast<DeviceTensorND&>(src).as_megdnn(), dest.as_megdnn(),
                MegDNNHandle::get(src_env).handle());
    } else {
        if (contig_src) {
            mgb_assert(!contig_dest);
            DeviceTensorND tmp{dst_cn};
            tmp.copy_from(src);
            dest.copy_from_fixlayout(tmp);
            return;
        }
        DeviceTensorND tmp;
        tmp.copy_from(src);
        dest.copy_from_fixlayout(tmp);
    }
}

//! implement non-contiguous h2h copy
void noncont_tensor_copy(
        const HostTensorND& dest, const HostTensorND& src, bool, bool) {
    auto opr =
            opr::intl::get_megdnn_global_opr<megdnn::Relayout>(CompNode::default_cpu());
    opr->exec(const_cast<HostTensorND&>(src).as_megdnn(), dest.as_megdnn());
}

//! implement non-contiguous d2h copy
void noncont_tensor_copy(
        const HostTensorND& dest, const DeviceTensorND& src, bool contig_dest,
        bool contig_src) {
    if (contig_src) {
        mgb_assert(!contig_dest);
        HostTensorND tmp;
        tmp.copy_from(src).sync();
        dest.copy_from_fixlayout(tmp);  // sync not needed for h2h copy
        return;
    }
    DeviceTensorND tmp;
    tmp.copy_from(src);
    dest.copy_from_fixlayout(tmp);
}

//! implement non-contiguous h2d copy
void noncont_tensor_copy(
        const DeviceTensorND& dest, const HostTensorND& src, bool contig_dest,
        bool contig_src) {
    if (contig_src) {
        mgb_assert(!contig_dest);
        DeviceTensorND tmp;
        // no need to sync because device free is async-safe with respect to
        // host thread
        tmp.copy_from(src);
        dest.copy_from_fixlayout(tmp);
        return;
    }
    HostTensorND tmp;
    tmp.copy_from(src);
    dest.copy_from_fixlayout(tmp).sync();
}
}  // anonymous namespace
/* ============= Slice and SubTensorSpec ============= */

SubTensorSpec SubTensorSpec::make_from_offset_elem(
        const TensorLayout& layout, ptrdiff_t offset_elem) {
    mgb_assert(layout.ndim && layout.dtype.valid());
    return {layout, offset_elem};
}
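// Illustrative example for Slice::apply below (values chosen only for
// demonstration): with a contiguous 1-D layout of shape (5) and stride (1),
// Slice{1, None, 2} takes the positive-step branch, giving begin = 1 and
// end = 5; the sliced shape becomes (5 - 1 + 2 - 1) / 2 = 2, the stride is
// scaled to 2, and the element offset is 1, i.e. the view covers elements
// 1 and 3 of the original tensor.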
SubTensorSpec Slice::apply(TensorLayout layout, int axis) const {
    mgb_assert(layout.ndim > 0 && layout.dtype.valid());
    if (axis == megdnn::param::OptionalAxisV1::INVALID_AXIS) {
        axis = 0;
        layout = layout.collapse_contiguous();
        mgb_assert(
                layout.ndim == 1,
                "apply Slice with axis==INVALID_AXIS on non-contig layout");
    }
    // axis in [-ndim, ndim) is available
    if (axis < 0)
        axis += layout.ndim;
    mgb_assert(
            axis >= 0 && static_cast<size_t>(axis) < layout.ndim,
            "invalid axis: %d; ndim=%zu", axis, layout.ndim);

    ptrdiff_t size_ax = layout.shape[axis];
    ptrdiff_t begin, end, step = m_step.val_with_default(1);
    mgb_assert(step, "Slice step can not be zero");

    auto tostr = [](const Maybe<ptrdiff_t>& v) -> std::string {
        if (!v.valid())
            return "None";
        return std::to_string(v.val());
    };
    auto mod_size = [size_ax](ptrdiff_t v) -> ptrdiff_t {
        if (size_ax == 0)
            return 0;
        return v < 0 ? v + size_ax : v;
    };
    MGB_MARK_USED_VAR(tostr);

#define CHECK(cond)                                                               \
    if (m_is_scalar_idx) {                                                        \
        mgb_assert(                                                               \
                cond, "index out of bound: layout=%s; request index=%s, axis=%d", \
                layout.to_string().c_str(), tostr(m_begin).c_str(), axis);        \
    } else {                                                                      \
        mgb_assert(                                                               \
                cond,                                                             \
                "index out of bound: layout=%s; request begin=%s end=%s step=%s " \
                "axis=%d",                                                        \
                layout.to_string().c_str(), tostr(m_begin).c_str(),               \
                tostr(m_end).c_str(), tostr(m_step).c_str(), axis);               \
    }

    if (step > 0) {
        begin = mod_size(m_begin.val_with_default(0));
        end = mod_size(m_end.val_with_default(size_ax));
        if (!m_is_scalar_idx) {
            end = std::min(end, size_ax);
            begin = std::min(begin, end);
        }
        CHECK(begin >= 0 && end >= begin && end <= size_ax)
    } else {
        begin = mod_size(m_begin.val_with_default(size_ax - 1));
        end = m_end.valid() ? mod_size(m_end.val()) : -1;
        if (!m_is_scalar_idx) {
            begin = std::min(begin, std::max<ptrdiff_t>(size_ax - 1, 0));
            end = std::min(end, begin);
        }
        CHECK(step < 0 && begin >= 0 && end <= begin && begin < size_ax && end >= -1)
    }
    auto step_abs = std::abs(step);
    layout.shape[axis] = (std::abs(end - begin) + step_abs - 1) / step_abs;
    auto orig_stride = layout.stride[axis];
    layout.stride[axis] *= step;

    // make stride as contiguous as possible
    if (layout.shape[axis] != 1 && axis)
        --axis;
    if (layout.shape[axis] == 1) {
        auto stride = layout.stride[axis] =
                axis + 1 < static_cast<int>(layout.ndim)
                        ? layout.stride[axis + 1] * layout.shape[axis + 1]
                        : 1;
        for (int i = axis - 1; i >= 0; --i) {
            if (layout.shape[i] == 1) {
                layout.stride[i] = stride;
            } else {
                break;
            }
        }
    }

    auto offset_elem = layout.is_empty() ? 0 : orig_stride * begin;
    return SubTensorSpec::make_from_offset_elem(layout, offset_elem);
#undef CHECK
}
void SubTensorSpec::merge_with(const SubTensorSpec& rhs) {
    mgb_assert(
            m_layout.dtype.valid() && m_layout.dtype == rhs.m_layout.dtype &&
            rhs.m_layout.ndim);
    m_offset_elem += rhs.m_offset_elem;
    m_layout = rhs.m_layout;
}

/* ===================== TensorStorage ===================== */

class mgb::HostTensorStorageTrait {
public:
    static void* alloc(CompNode node, size_t size) { return node.alloc_host(size); }
    static void free(CompNode node, void* data) { node.free_host(data); }
};

class mgb::DeviceTensorStorageTrait {
public:
    static void* alloc(CompNode node, size_t size) { return node.alloc_device(size); }
    static void free(CompNode node, void* data) { node.free_device(data); }
};
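// Copy assignment shares the underlying buffer (the RawStorage shared pointer)
// with rhs rather than copying bytes: rhs.ptr() is called first so that any
// pending lazy allocation on rhs is materialized before m_data and m_ref_ptr
// are shared.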
template <class Trait>
TensorStorage<Trait>& TensorStorage<Trait>::operator=(const TensorStorage& rhs) {
    if (rhs.m_size > rhs.m_capacity) {
        rhs.ptr();
    }
    m_allow_realloc = rhs.m_allow_realloc;
    m_comp_node = rhs.m_comp_node;
    m_size = rhs.m_size;
    m_capacity = rhs.m_capacity;
    m_offset = rhs.m_offset;
    m_data = rhs.m_data;
    m_ref_ptr = rhs.m_ref_ptr;
    return *this;
}
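// ensure_size() only records the requested size; the actual (re)allocation is
// deferred to apply_lazy_and_get_ptr(), which runs on the next ptr() access.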
template <class Trait>
TensorStorage<Trait>& TensorStorage<Trait>::ensure_size(size_t sz) {
    if (sz > m_size) {
        mgb_throw_if(
                !m_allow_realloc || m_offset, MegBrainError,
                "can not grow a tensor that does not allow realloc");
        check_comp_node_valid();
    }
    m_size = sz;
    return *this;
}
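// sub() returns a storage that shares the underlying buffer; `offset` is
// measured in bytes relative to this storage (TensorND::sub passes
// SubTensorSpec::offset_byte() here).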
template <class Trait>
TensorStorage<Trait> TensorStorage<Trait>::sub(ptrdiff_t offset) const {
    ptr();  // apply lazy resize
    ptrdiff_t toff = offset + m_offset;
    if (offset == static_cast<ptrdiff_t>(m_size)) {
        return {false, m_comp_node, 0, 0, 0, RawStorage{}};
    }
    mgb_assert(
            toff >= 0 && offset < static_cast<ptrdiff_t>(m_size),
            "bad subtensor: offset=%td m_offset=%zu m_size=%zu", offset, m_offset,
            m_size);
    return {false,
            m_comp_node,
            m_size - offset,
            m_capacity - offset,
            static_cast<size_t>(toff),
            m_data,
            m_ref_ptr};
}

template <class Trait>
dt_byte* TensorStorage<Trait>::apply_lazy_and_get_ptr() {
    check_comp_node_valid();
    if (m_size > m_capacity) {
        mgb_assert(m_allow_realloc && !m_offset);
        m_data.reset();  // free old ptr
        m_capacity = 0;  // to be exception safe
        auto ptr = static_cast<dt_byte*>(Trait::alloc(m_comp_node, m_size));
        mgb_throw_if(!ptr, SystemError, "failed to allocate memory");
        CompNode cn = m_comp_node;
        m_data.reset(ptr, [cn](void* p) { Trait::free(cn, p); });
        m_ref_ptr = std::make_shared<void*>(static_cast<void*>(nullptr));
        m_capacity = m_size;
    }
    *m_ref_ptr = static_cast<void*>(m_data.get());
    return m_data.get() + m_offset;
}

template <class Trait>
TensorStorage<Trait>& TensorStorage<Trait>::comp_node(
        CompNode node, bool allow_mem_node_change) {
    mgb_assert(node.valid());
    if (m_comp_node.valid() && node.mem_node() != m_comp_node.mem_node()) {
        mgb_assert(allow_mem_node_change);
        m_allow_realloc = true;
        m_size = m_capacity = m_offset = 0;
        m_data.reset();
    }
    m_comp_node = node;
    return *this;
}

template <class Trait>
void TensorStorage<Trait>::reset(CompNode node, size_t size, RawStorage data) {
    mgb_assert(m_allow_realloc);
    m_comp_node = node;
    m_size = size;
    m_capacity = size;
    m_offset = 0;
    m_data = std::move(data);
    m_ref_ptr = std::make_shared<void*>(static_cast<void*>(m_data.get()));
}

template <class Trait>
void TensorStorage<Trait>::only_reset_raw_storage(
        CompNode node, size_t size, RawStorage data, size_t offset) {
    mgb_assert(m_allow_realloc);
    m_comp_node = node;
    m_size = size;
    m_capacity = size;
    m_offset = offset;
    m_data = std::move(data);
    *m_ref_ptr = static_cast<void*>(m_data.get());
}

template <class Trait>
template <class RTrait, typename>
TensorStorage<Trait> TensorStorage<Trait>::make_proxy(
        const TensorStorage<RTrait>& src) {
    mgb_assert(
            src.comp_node().mem_node() == CompNode::default_cpu().mem_node(),
            "proxy source should be on CPU; got %s",
            src.comp_node().to_string().c_str());
    src.ptr();
    return {true, src.m_comp_node, src.m_size, src.m_capacity,
            src.m_offset, src.m_data, src.m_ref_ptr};
}
template <class Trait>
void TensorStorage<Trait>::on_invalid_comp_node() {
    mgb_throw(
            MegBrainError,
            "trying to access TensorStorage with invalid "
            "comp node");
}
namespace mgb {

// host to host
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
        const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    memcpy(ptr(), src.ptr(), size);
}

// device to host
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<HostTensorStorageTrait>::copy_from(
        const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
    bool need_sync = false;
    mgb_assert(size <= this->size() && size <= src.size());
    if (m_comp_node != src.comp_node()) {
        auto default_cpu = CompNode::default_cpu();
        if (src.comp_node() != default_cpu) {
            mgb_assert(
                    m_comp_node == default_cpu,
                    "inconsistent D2H copy:"
                    " copy from device to host using different comp nodes:"
                    " device_node=%s host_node=%s",
                    src.comp_node().to_string().c_str(),
                    m_comp_node.to_string().c_str());
            // copy_from() should use m_comp_node, and default_cpu is
            // synchronous with current thread, so this copy has no
            // synchronizing ambiguity and we only need to sync on host
            need_sync = true;
        }
    }
    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
    src.comp_node().copy_to_host_ref(dst_ptr, src_ptr, size);
    if (need_sync)
        src.comp_node().sync();
}

// host to device
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
        const TensorStorage<HostTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
    megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
    m_comp_node.copy_to_device_ref(dst_ptr, src_ptr, size);
}
// device to device
template <>
template <>
MGE_WIN_DECLSPEC_FUC void TensorStorage<DeviceTensorStorageTrait>::copy_from(
        const TensorStorage<DeviceTensorStorageTrait>& src, size_t size) const {
    mgb_assert(size <= this->size() && size <= src.size());
    if (src.comp_node().device_type() == CompNode::DeviceType::CPU &&
        comp_node().device_type() == CompNode::DeviceType::CUDA) {
        // the current thread (i.e. the cuda dispatcher thread) should wait for
        // all operations on src's comp_node to finish, otherwise a race
        // condition might occur between the worker thread of src's comp_node
        // and the thread responsible for copying pageable memory in \p src to
        // a pinned buffer, refer to
        // https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html
        //
        // Note: it is highly recommended to copy tensors from cpu to cuda with
        // asynchronous dispatching (see graph option async_exec_level),
        // otherwise the main thread might be blocked by the worker thread
        // corresponding to src's comp_node, resulting in bad performance
        //
        // TODO: consider using cudaMallocHost or cudaHostRegister
        // to pin the memory of the src tensor, so it does not require
        // synchronization and is more efficient
        src.comp_node().sync();
        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
        comp_node().copy_to_device_ref(dst_ptr, src_ptr, size);
    } else {
        megdnn::RefPtr src_ptr(src.get_ref_ptr(), src.offset(), false);
        megdnn::RefPtr dst_ptr(get_ref_ptr(), offset(), false);
        src.comp_node().peer_copy_to_ref(m_comp_node, dst_ptr, src_ptr, size);
    }
}
// proxy host to device
template TensorStorage<DeviceTensorStorageTrait> TensorStorage<
        DeviceTensorStorageTrait>::
        make_proxy<HostTensorStorageTrait, void>(
                const TensorStorage<HostTensorStorageTrait>&);

// proxy device to host
template TensorStorage<HostTensorStorageTrait> TensorStorage<HostTensorStorageTrait>::
        make_proxy<DeviceTensorStorageTrait, void>(
                const TensorStorage<DeviceTensorStorageTrait>&);

}  // namespace mgb
/* ===================== TensorND ===================== */

// ctor def {
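// The DEF macro expands to the out-of-class constructor header
// `template <class TensorStorage> TensorND<TensorStorage>::TensorND`, so each
// DEF(...) line below defines one constructor overload of the TensorND
// template (instantiated later for both host and device storages).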
#define DEF                        \
    template <class TensorStorage> \
    TensorND<TensorStorage>::TensorND

DEF() = default;

DEF(CompNode node) : m_storage{node} {}

DEF(DType dtype) : m_layout{dtype} {}

DEF(CompNode node, DType dtype) : m_storage{node}, m_layout{dtype} {}

//! allocate contiguous from given comp node, shape and dtype
DEF(CompNode node, const TensorShape& shape, DType dtype)
        : m_storage{node}, m_layout{dtype} {
    resize(shape);
}

DEF(CompNode node, const TensorShape& shape, DType dtype, TensorFormat format)
        : m_storage{node}, m_layout{dtype, format} {
    resize(shape);
}

//! allocate contiguous from given comp node and layout (strides not
//! used)
DEF(CompNode node, const TensorLayout& layout)
        : TensorND(node, layout, layout.dtype, layout.format) {
    mgb_assert(
            layout.is_contiguous() || layout.is_empty(),
            "non-contiguous layout used for initializing a tensor: %s",
            layout.to_string().c_str());
}

#undef DEF
// ctor def }
// def {
#define DEF(name, ret)                                    \
    template <class TensorStorage>                        \
    typename TensorND<TensorStorage>::ChainReturnType ret \
    TensorND<TensorStorage>::name

DEF(resize, &)(const TensorShape& shape) {
    mgb_assert(m_layout.dtype.valid());
    m_layout.init_contiguous_stride(shape);
    m_storage.ensure_size(m_layout.span().dist_byte());
    return static_cast<ChainReturnType&>(*this);
}
DEF(reset, &)(TensorStorage storage, const TensorLayout& layout) {
    //! The storage to be reset must either satisfy the layout or be empty.
    //! Empty storage is used after weight preprocessing to save memory and
    //! to check the layout when running
    mgb_assert(!layout.ndim || storage.valid_span(layout.span()) || storage.empty());
    m_storage = std::move(storage);
    m_layout = layout;
    return static_cast<ChainReturnType&>(*this);
}

DEF(only_reset_raw_storage, &)(TensorStorage storage) {
    //! The storage to be reset must either satisfy the layout or be empty.
    //! Empty storage is used after weight preprocessing to save memory and
    //! to check the layout when running
    mgb_assert(storage.valid_span(m_layout.span()) || storage.empty());
    m_storage.only_reset_raw_storage(
            storage.comp_node(), storage.size(), storage.raw_storage(),
            storage.offset());
    return static_cast<ChainReturnType&>(*this);
}
DEF(comp_node, &)(CompNode comp_node, bool allow_mem_node_change) {
    auto orig_cn = m_storage.comp_node_allow_invalid();
    m_storage.comp_node(comp_node, allow_mem_node_change);
    if (orig_cn.valid() && orig_cn.mem_node() != comp_node.mem_node()) {
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

DEF(storage, &)(const TensorStorage& storage) {
    if (m_storage.empty() || storage.empty() || m_storage.ptr() != storage.ptr()) {
        m_storage = storage;
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

DEF(dtype, &)(DType dtype) {
    if (m_layout.dtype != dtype) {
        m_layout.modify_dtype_inplace(dtype);
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

DEF(format, &)(TensorFormat format) {
    if (m_layout.format != format) {
        m_layout.format = format;
        m_layout.ndim = 0;
    }
    return static_cast<ChainReturnType&>(*this);
}

DEF(operator[], )(std::initializer_list<Slice> slice) const {
    auto subspec = SubTensorSpec::make_from_offset_elem(m_layout, 0);
    size_t axis = 0;
    for (auto&& i : slice) {
        subspec.merge_with(i.apply(subspec.layout(), axis));
        axis++;
    }
    return sub(subspec);
}

DEF(sub, )(const SubTensorSpec& spec) const {
    mgb_assert(
            spec.layout().dtype == dtype() && spec.layout().format == format(),
            "invalid subtensor spec: sub_layout=%s self=%s",
            spec.layout().to_string().c_str(), m_layout.to_string().c_str());
    ChainReturnType rst;
    rst.reset(m_storage.sub(spec.offset_byte()), spec.layout());
    return rst;
}

#undef DEF
// def }
/* ===================== TensorND::copy_from ===================== */

namespace {

/**
 * \brief determine whether to check overlap of two tensors.
 * \return true : when HostStorage || (DeviceStorage && SUPPORT_UNIFIED_ADDRESS)
 * \note when both support unified address, we can treat them both as CPU
 * addresses, so the overlap check should be done
 */
template <typename TensorStorage, typename RStorage>
inline bool should_check_overlap(
        const TensorND<TensorStorage>& dst, const TensorND<RStorage>& src) {
    return true;
}

template <>
inline bool should_check_overlap<HostTensorStorage, DeviceTensorStorage>(
        const HostTensorND& dst, const DeviceTensorND& src) {
    return src.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}

template <>
inline bool should_check_overlap<DeviceTensorStorage, HostTensorStorage>(
        const DeviceTensorND& dst, const HostTensorND& src) {
    return dst.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
}
/**
 * \brief D2D tensor copy should check overlap when
 * 1. they are on the same mem node; note that the addresses must be logically
 *    comparable, e.g. the raw addresses allocated on enflame are not comparable.
 * 2. they both support unified address, so they can be treated as CPU addresses.
 */
template <>
inline bool should_check_overlap<DeviceTensorStorage, DeviceTensorStorage>(
        const DeviceTensorND& dst, const DeviceTensorND& src) {
    bool is_same_memnode = dst.comp_node().mem_node() == src.comp_node().mem_node();
    bool unified_address =
            src.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS) &&
            dst.comp_node().contain_flag(CompNode::Flag::SUPPORT_UNIFIED_ADDRESS);
    return is_same_memnode || unified_address;
}
/**
 * \brief check overlap of two tensors. throw exception when overlapped
 */
inline void check_overlapped(
        const dt_byte* dst_min, const dt_byte* dst_max, const dt_byte* src_min,
        const dt_byte* src_max) {
    mgb_throw_if(
            src_min < dst_max && dst_min < src_max, TensorCopyOverlapError,
            "could not perform copy between overlapped tensors");
}

}  // namespace
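// copy_from() below re-initializes the destination with a contiguous layout
// matching src's shape (growing the storage if needed), whereas
// copy_from_fixlayout() keeps the destination layout and only requires the
// shapes to match; non-contiguous cases fall back to the noncont_tensor_copy
// helpers defined at the top of this file.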
template <class TensorStorage>
template <class RStorage>
typename TensorND<TensorStorage>::ChainReturnType& TensorND<TensorStorage>::copy_from(
        const TensorND<RStorage>& src) {
    if (!m_storage.comp_node_valid())
        m_storage.comp_node(src.comp_node());

    if (m_layout.dtype.valid())
        m_layout.dtype.assert_is(src.dtype());
    else
        m_layout.dtype = src.dtype();

    m_layout = TensorLayout(src.shape(), m_layout.dtype);
    size_t size_bytes = m_layout.span().dist_byte();
    m_storage.ensure_size(size_bytes);
    if (!size_bytes) {
        return static_cast<ChainReturnType&>(*this);
    }
    // requirement:
    //   default case, physical contiguous
    //   lowbit aligned, logical contiguous
    if (src.layout().is_physical_contiguous() ||
        (src.layout().format.is_lowbit_aligned() && src.layout().is_contiguous())) {
        if (should_check_overlap(*this, src)) {
            check_overlapped(
                    m_storage.ptr(), m_storage.ptr() + size_bytes, src.storage().ptr(),
                    src.storage().ptr() + size_bytes);
        }
        m_storage.copy_from(src.storage(), size_bytes);
        return static_cast<ChainReturnType&>(*this);
    }
    return const_cast<ChainReturnType&>(copy_from_fixlayout(src));
}

template <class TensorStorage>
template <class RStorage>
const typename TensorND<TensorStorage>::ChainReturnType& TensorND<
        TensorStorage>::copy_from_fixlayout(const TensorND<RStorage>& src) const {
    dtype().assert_is(src.dtype());
    mgb_assert(
            m_layout.eq_shape(src.layout()),
            "shape differs in copy_from_fixlayout: %s vs %s",
            static_cast<const TensorShape&>(m_layout).to_string().c_str(),
            static_cast<const TensorShape&>(src.layout()).to_string().c_str());

    if (src.empty()) {
        return static_cast<const ChainReturnType&>(*this);
    }

    mgb_assert(
            m_layout.is_non_overlapping_strong(),
            "copy dest must have non-overlapping layout");

    TensorLayout::Span src_span = src.layout().span(), dst_span = layout().span();

    if (should_check_overlap(*this, src)) {
        check_overlapped(
                this->raw_ptr() + dst_span.low_byte,
                this->raw_ptr() + dst_span.high_byte, src.raw_ptr() + src_span.low_byte,
                src.raw_ptr() + src_span.high_byte);
    }

    bool self_contig =
                 m_layout.is_physical_contiguous() ||
                 (m_layout.format.is_lowbit_aligned() && m_layout.is_contiguous()),
         src_contig = src.layout().is_physical_contiguous() ||
                      (src.layout().format.is_lowbit_aligned() &&
                       src.layout().is_contiguous());
    if (self_contig && src_contig) {
        if ((m_layout.format.is_default() && src.layout().format.is_default()) ||
            (m_layout.format.is_lowbit_aligned() &&
             src.layout().format.is_lowbit_aligned())) {
            mgb_assert(
                    src_span.low_byte == 0 && dst_span.low_byte == 0 &&
                    src_span.high_byte == dst_span.high_byte);
            m_storage.copy_from(src.storage(), src_span.high_byte);
        } else {
            mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0);
            m_storage.copy_from(
                    src.storage(), std::min(src_span.high_byte, dst_span.high_byte));
        }
        return static_cast<const ChainReturnType&>(*this);
    }
    noncont_tensor_copy(*this, src, self_contig, src_contig);
    return static_cast<const ChainReturnType&>(*this);
}
/* =================== misc =================== */
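// dev_tensor_memset() fills the whole span of `tensor` with the byte value
// `val`, using the backend-specific memset of the underlying device API or a
// memset closure dispatched on the CPU comp node; e.g. given some
// DeviceTensorND t, dev_tensor_memset(t, 0) zero-fills t on its own comp node.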
void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) {
    auto&& env = CompNodeEnv::from_comp_node(tensor.comp_node());
    env.activate();
    size_t size = tensor.layout().span().dist_byte();
    switch (env.property().type) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA:
            MGB_CUDA_CHECK(cudaMemsetAsync(
                    tensor.raw_ptr(), val, size, env.cuda_env().stream));
            break;
#endif
#if MGB_ATLAS
        case CompNode::DeviceType::ATLAS:
#if MGB_USE_ATLAS_ASYNC_API
            MGB_ATLAS_CHECK(aclrtMemsetAsync(
                    tensor.raw_ptr(), -1, val, size, env.atlas_env().stream));
#else
            MGB_ATLAS_CHECK(aclrtMemset(tensor.raw_ptr(), -1, val, size));
#endif
            break;
#endif
#if MGB_CAMBRICON
        case CompNode::DeviceType::CAMBRICON:
            MGB_CNRT_CHECK(cnrtSyncQueue(env.cnrt_env().queue));
            MGB_CNRT_CHECK(cnrtMemset(tensor.raw_ptr(), val, size));
            break;
#endif
        case CompNode::DeviceType::CPU: {
            auto fill = [tensor, size, val]() {
                std::memset(tensor.as_megdnn().raw_ptr(), val, size);
            };
            env.cpu_env().dispatch(fill);
        } break;
        default:
            mgb_throw(
                    MegBrainError, "unhandled comp node in dev_tensor_memset: %s",
                    tensor.comp_node().to_string().c_str());
    }
}
namespace mgb {

template class TensorStorage<HostTensorStorageTrait>;
template class TensorStorage<DeviceTensorStorageTrait>;
template class TensorND<TensorStorage<HostTensorStorageTrait>>;
template class TensorND<TensorStorage<DeviceTensorStorageTrait>>;

/* ===== copy_from related ===== */
#define HT_RAW TensorND<HostTensorStorage>
#define DT_RAW TensorND<DeviceTensorStorage>
#define HT(f) f<HostTensorStorage>(const HT_RAW&)
#define DT(f) f<DeviceTensorStorage>(const DT_RAW&)
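// INST(f, c) explicitly instantiates f for all four host/device source and
// destination combinations; the second argument injects the `const`
// qualifiers needed by copy_from_fixlayout (const return reference and const
// member function).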
#define INST(f, c)                              \
    template c HostTensorND& HT_RAW::HT(f) c;   \
    template c HostTensorND& HT_RAW::DT(f) c;   \
    template c DeviceTensorND& DT_RAW::HT(f) c; \
    template c DeviceTensorND& DT_RAW::DT(f) c

INST(copy_from, );
INST(copy_from_fixlayout, const);

#undef INST
#undef DT
#undef HT
#undef DT_RAW
#undef HT_RAW

}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}