
tensor.inl 4.9 kB

/**
 * \file dnn/test/common/tensor.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./tensor.h"

#include "megdnn/basic_types.h"
#include "test/common/index.h"
#include "test/common/get_dtype_from_static_type.h"
#include "test/common/utils.h"

#include <memory>

namespace megdnn {
namespace test {

template <typename T, typename C>
Tensor<T, C>::Tensor(Handle *handle, TensorLayout layout):
    m_handle(handle),
    m_comparator(C())
{
    layout.dtype = get_dtype_from_static_type<T>();
    m_tensornd.raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
    m_tensornd.layout = layout;
}

template <typename T, typename C>
Tensor<T, C>::~Tensor()
{
    megdnn_free(m_handle, m_tensornd.raw_ptr);
}

template <typename T, typename C>
T *Tensor<T, C>::ptr()
{
    return m_tensornd.ptr<T>();
}

template <typename T, typename C>
const T *Tensor<T, C>::ptr() const
{
    return m_tensornd.ptr<T>();
}

template <typename T, typename C>
TensorLayout Tensor<T, C>::layout() const
{
    return m_tensornd.layout;
}

template <typename T, typename C> template <typename C_>
void Tensor<T, C>::check_with(const Tensor<T, C_> &rhs) const
{
    // compare layout
    ASSERT_TRUE(this->m_tensornd.layout.eq_layout(rhs.m_tensornd.layout))
            << "this->layout is " << this->m_tensornd.layout.to_string()
            << "; rhs.layout is " << rhs.m_tensornd.layout.to_string();
    // compare value
    auto n = m_tensornd.layout.total_nr_elems();
    auto p0 = this->ptr(), p1 = rhs.ptr();
    for (size_t linear_idx = 0; linear_idx < n; ++linear_idx) {
        auto index = Index(m_tensornd.layout, linear_idx);
        auto offset = index.positive_offset();
        ASSERT_TRUE(m_comparator.is_same(p0[offset], p1[offset]))
                << "Index is " << index.to_string()
                << "; layout is " << m_tensornd.layout.to_string()
                << "; this->ptr()[offset] is " << this->ptr()[offset]
                << "; rhs.ptr()[offset] is " << rhs.ptr()[offset];
    }
}
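
// Note (illustrative, not part of the original file): the comparator type C
// only has to provide an is_same(T, T) member, so approximate comparison can
// be plugged in for floating-point tests. A minimal sketch, assuming a fixed
// absolute tolerance and <cmath> for std::fabs:
//
//     struct EpsilonComparator {
//         bool is_same(float a, float b) const {
//             return std::fabs(a - b) < 1e-5f;
//         }
//     };
//
// Tensor<float, EpsilonComparator> would then compare element-wise with this
// tolerance in check_with().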

template <typename T, typename C>
SyncedTensor<T, C>::SyncedTensor(Handle *dev_handle, TensorLayout layout):
    m_handle_host(create_cpu_handle(2, false)),
    m_handle_dev(dev_handle),
    m_tensor_host(m_handle_host.get(), layout),
    m_tensor_dev(m_handle_dev, layout),
    m_sync_state(SyncState::UNINITED)
{
}

template <typename T, typename C>
const T *SyncedTensor<T, C>::ptr_host()
{
    ensure_host();
    return m_tensor_host.tensornd().template ptr<T>();
}

template <typename T, typename C>
const T *SyncedTensor<T, C>::ptr_dev()
{
    ensure_dev();
    return m_tensor_dev.tensornd().template ptr<T>();
}

template <typename T, typename C>
T *SyncedTensor<T, C>::ptr_mutable_host()
{
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd().template ptr<T>();
}

template <typename T, typename C>
T *SyncedTensor<T, C>::ptr_mutable_dev()
{
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd().template ptr<T>();
}

template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_host()
{
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd();
}

template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_dev()
{
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd();
}

template <typename T, typename C>
TensorLayout SyncedTensor<T, C>::layout() const
{
    return m_tensor_host.tensornd().layout;
}

template <typename T, typename C> template <typename C_>
void SyncedTensor<T, C>::check_with(SyncedTensor<T, C_> &rhs)
{
    this->ensure_host();
    rhs.ensure_host();
    this->m_tensor_host.check_with(rhs.m_tensor_host);
}
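
// The two helpers below implement the lazy synchronization: the side last
// written through a ptr_mutable_* accessor owns the freshest data, and
// ensure_host()/ensure_dev() copy across only when the other side is actually
// accessed, after which both copies are marked SYNCED.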

template <typename T, typename C>
void SyncedTensor<T, C>::ensure_host()
{
    if (m_sync_state == SyncState::HOST || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::DEV) {
        megdnn_memcpy_D2H(m_handle_dev,
                m_tensor_host.ptr(), m_tensor_dev.ptr(),
                m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}

template <typename T, typename C>
void SyncedTensor<T, C>::ensure_dev()
{
    if (m_sync_state == SyncState::DEV || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::HOST) {
        megdnn_memcpy_H2D(m_handle_dev,
                m_tensor_dev.ptr(), m_tensor_host.ptr(),
                m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
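
For context, a minimal usage sketch of these helpers follows. It is not taken from the repository: it assumes a valid device `Handle *handle` and a default comparator template argument declared in `tensor.h`.

    // Hypothetical test snippet: write on the host, then read on the device.
    SyncedTensor<float> t(handle,
            TensorLayout(TensorShape{16, 32}, dtype::Float32()));

    float *p = t.ptr_mutable_host();    // marks the host copy as freshest
    for (size_t i = 0; i < t.layout().total_nr_elems(); ++i)
        p[i] = static_cast<float>(i);

    const float *dev_p = t.ptr_dev();   // first device read triggers the
    (void)dev_p;                        // H2D copy inside ensure_dev()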

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build to choose from. If you want to run GPU programs, make sure the machine has a GPU device and the driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.