
tensor.inl

/**
 * \file dnn/test/common/tensor.inl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./tensor.h"

#include "megdnn/basic_types.h"
#include "test/common/index.h"
#include "test/common/get_dtype_from_static_type.h"
#include "test/common/utils.h"

#include <memory>

namespace megdnn {
namespace test {
template <typename T, typename C>
Tensor<T, C>::Tensor(Handle* handle, TensorLayout layout)
        : m_handle(handle), m_comparator(C()) {
    if (!layout.dtype.valid())
        layout.dtype = get_dtype_from_static_type<T>();
    m_tensornd.raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
    m_tensornd.layout = layout;
}
template <typename T, typename C>
Tensor<T, C>::~Tensor() {
    megdnn_free(m_handle, m_tensornd.raw_ptr);
}

template <typename T, typename C>
T* Tensor<T, C>::ptr() {
    return m_tensornd.ptr<T>();
}

template <typename T, typename C>
const T* Tensor<T, C>::ptr() const {
    return m_tensornd.ptr<T>();
}

template <typename T, typename C>
TensorLayout Tensor<T, C>::layout() const {
    return m_tensornd.layout;
}
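// check_with() compares two tensors element by element: it walks the full
// linear index range and converts each linear index into an in-memory
// offset via Index::positive_offset(), so tensors with non-contiguous
// layouts are still compared at the correct locations.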
template <typename T, typename C>
template <typename C_>
void Tensor<T, C>::check_with(const Tensor<T, C_>& rhs) const {
    // compare layout
    ASSERT_TRUE(this->m_tensornd.layout.eq_layout(rhs.m_tensornd.layout))
            << "this->layout is " << this->m_tensornd.layout.to_string()
            << "; rhs.layout is " << rhs.m_tensornd.layout.to_string();
    // compare value
    auto n = m_tensornd.layout.total_nr_elems();
    auto p0 = this->ptr(), p1 = rhs.ptr();
    for (size_t linear_idx = 0; linear_idx < n; ++linear_idx) {
        auto index = Index(m_tensornd.layout, linear_idx);
        auto offset = index.positive_offset();
        ASSERT_TRUE(m_comparator.is_same(p0[offset], p1[offset]))
                << "Index is " << index.to_string() << "; layout is "
                << m_tensornd.layout.to_string() << "; this->ptr()[offset] is "
                << this->ptr()[offset] << "; rhs.ptr()[offset] is "
                << rhs.ptr()[offset];
    }
}
template <typename T, typename C>
SyncedTensor<T, C>::SyncedTensor(Handle* dev_handle, TensorLayout layout)
        : m_handle_host(create_cpu_handle(2, false)),
          m_handle_dev(dev_handle),
          m_tensor_host(m_handle_host.get(), layout),
          m_tensor_dev(m_handle_dev, layout),
          m_sync_state(SyncState::UNINITED) {}
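// Sync-state machine implemented by the methods below:
//   UNINITED: neither copy holds meaningful data, so ensure_*() skips the memcpy.
//   HOST:     the host copy is authoritative; ensure_dev() copies H2D.
//   DEV:      the device copy is authoritative; ensure_host() copies D2H.
//   SYNCED:   both copies agree; no transfer is needed.
// The ptr_mutable_*() and tensornd_*() accessors mark the corresponding side
// as authoritative again, since the caller may write through them.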
template <typename T, typename C>
const T* SyncedTensor<T, C>::ptr_host() {
    ensure_host();
    return m_tensor_host.tensornd().template ptr<T>();
}

template <typename T, typename C>
const T* SyncedTensor<T, C>::ptr_dev() {
    ensure_dev();
    return m_tensor_dev.tensornd().template ptr<T>();
}

template <typename T, typename C>
T* SyncedTensor<T, C>::ptr_mutable_host() {
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd().template ptr<T>();
}

template <typename T, typename C>
T* SyncedTensor<T, C>::ptr_mutable_dev() {
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd().template ptr<T>();
}

template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_host() {
    ensure_host();
    m_sync_state = SyncState::HOST;
    return m_tensor_host.tensornd();
}

template <typename T, typename C>
TensorND SyncedTensor<T, C>::tensornd_dev() {
    ensure_dev();
    m_sync_state = SyncState::DEV;
    return m_tensor_dev.tensornd();
}
template <typename T, typename C>
TensorLayout SyncedTensor<T, C>::layout() const {
    return m_tensor_host.tensornd().layout;
}

template <typename T, typename C>
template <typename C_>
void SyncedTensor<T, C>::check_with(SyncedTensor<T, C_>& rhs) {
    this->ensure_host();
    rhs.ensure_host();
    this->m_tensor_host.check_with(rhs.m_tensor_host);
}
template <typename T, typename C>
void SyncedTensor<T, C>::ensure_host() {
    if (m_sync_state == SyncState::HOST || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::DEV) {
        megdnn_memcpy_D2H(m_handle_dev, m_tensor_host.ptr(), m_tensor_dev.ptr(),
                          m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}

template <typename T, typename C>
void SyncedTensor<T, C>::ensure_dev() {
    if (m_sync_state == SyncState::DEV || m_sync_state == SyncState::SYNCED) {
        return;
    }
    if (m_sync_state == SyncState::HOST) {
        megdnn_memcpy_H2D(m_handle_dev, m_tensor_dev.ptr(), m_tensor_host.ptr(),
                          m_tensor_host.layout().span().dist_byte());
    }
    m_sync_state = SyncState::SYNCED;
}
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
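As a rough usage sketch (not part of the file above): the comparator parameter C only needs an is_same(lhs, rhs) member, which Tensor::check_with calls per element, and SyncedTensor copies data between host and device on demand. The names ExactComparator and sketch below are hypothetical, and the Handle would come from the surrounding test fixture; this assumes the declarations in ./tensor.h match the definitions shown here.

#include "test/common/tensor.h"

using namespace megdnn;
using namespace test;

// Hypothetical comparator: C only needs is_same(), per Tensor::check_with.
struct ExactComparator {
    bool is_same(float lhs, float rhs) const { return lhs == rhs; }
};

void sketch(Handle* dev_handle) {
    TensorLayout layout({4, 4}, dtype::Float32());
    SyncedTensor<float, ExactComparator> t0(dev_handle, layout);
    SyncedTensor<float, ExactComparator> t1(dev_handle, layout);

    // Writing through ptr_mutable_host() marks the host copy authoritative.
    float* p0 = t0.ptr_mutable_host();
    float* p1 = t1.ptr_mutable_host();
    for (size_t i = 0; i < layout.total_nr_elems(); ++i)
        p0[i] = p1[i] = static_cast<float>(i);

    // tensornd_dev() triggers the H2D copy before a device kernel touches it.
    TensorND dev = t0.tensornd_dev();
    (void)dev;  // a device operator would consume `dev` here

    // check_with() syncs both tensors back to the host, asserts the layouts
    // match, then compares every element with ExactComparator::is_same.
    t0.check_with(t1);
}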

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.