You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensor.inl 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. #include "./tensor.h"
  2. #include <memory>
  3. #include "megdnn/basic_types.h"
  4. #include "test/common/get_dtype_from_static_type.h"
  5. #include "test/common/index.h"
  6. #include "test/common/utils.h"
  7. namespace megdnn {
  8. namespace test {
  9. template <typename T, typename C>
  10. Tensor<T, C>::Tensor(Handle* handle, TensorLayout layout)
  11. : m_handle(handle), m_comparator(C()) {
  12. if (!layout.dtype.valid())
  13. layout.dtype = get_dtype_from_static_type<T>();
  14. auto raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
  15. m_tensornd = TensorND{raw_ptr, layout};
  16. }
  17. template <typename T, typename C>
  18. Tensor<T, C>::~Tensor() {
  19. megdnn_free(m_handle, m_tensornd.raw_ptr());
  20. }
  21. template <typename T, typename C>
  22. T* Tensor<T, C>::ptr() {
  23. return m_tensornd.ptr<T>();
  24. }
  25. template <typename T, typename C>
  26. const T* Tensor<T, C>::ptr() const {
  27. return m_tensornd.ptr<T>();
  28. }
  29. template <typename T, typename C>
  30. TensorLayout Tensor<T, C>::layout() const {
  31. return m_tensornd.layout;
  32. }
  33. template <typename T, typename C>
  34. template <typename C_>
  35. void Tensor<T, C>::check_with(const Tensor<T, C_>& rhs) const {
  36. // compare layout
  37. ASSERT_TRUE(this->m_tensornd.layout.eq_layout(rhs.m_tensornd.layout))
  38. << "this->layout is " << this->m_tensornd.layout.to_string()
  39. << "rhs.layout is " << rhs.m_tensornd.layout.to_string();
  40. // compare value
  41. auto n = m_tensornd.layout.total_nr_elems();
  42. auto p0 = this->ptr(), p1 = rhs.ptr();
  43. for (size_t linear_idx = 0; linear_idx < n; ++linear_idx) {
  44. auto index = Index(m_tensornd.layout, linear_idx);
  45. auto offset = index.positive_offset();
  46. ASSERT_TRUE(m_comparator.is_same(p0[offset], p1[offset]))
  47. << "Index is " << index.to_string() << "; layout is "
  48. << m_tensornd.layout.to_string() << "; this->ptr()[offset] is "
  49. << this->ptr()[offset] << "; rhs.ptr()[offset] is "
  50. << rhs.ptr()[offset];
  51. }
  52. }
  53. template <typename T, typename C>
  54. SyncedTensor<T, C>::SyncedTensor(Handle* dev_handle, TensorLayout layout)
  55. : m_handle_host(create_cpu_handle(2, false)),
  56. m_handle_dev(dev_handle),
  57. m_tensor_host(m_handle_host.get(), layout),
  58. m_tensor_dev(m_handle_dev, layout),
  59. m_sync_state(SyncState::UNINITED) {}
  60. template <typename T, typename C>
  61. const T* SyncedTensor<T, C>::ptr_host() {
  62. ensure_host();
  63. return m_tensor_host.tensornd().template ptr<T>();
  64. }
  65. template <typename T, typename C>
  66. const T* SyncedTensor<T, C>::ptr_dev() {
  67. ensure_dev();
  68. return m_tensor_dev.tensornd().template ptr<T>();
  69. }
  70. template <typename T, typename C>
  71. T* SyncedTensor<T, C>::ptr_mutable_host() {
  72. ensure_host();
  73. m_sync_state = SyncState::HOST;
  74. return m_tensor_host.tensornd().template ptr<T>();
  75. }
  76. template <typename T, typename C>
  77. T* SyncedTensor<T, C>::ptr_mutable_dev() {
  78. ensure_dev();
  79. m_sync_state = SyncState::DEV;
  80. return m_tensor_dev.tensornd().template ptr<T>();
  81. }
  82. template <typename T, typename C>
  83. TensorND SyncedTensor<T, C>::tensornd_host() {
  84. ensure_host();
  85. m_sync_state = SyncState::HOST;
  86. return m_tensor_host.tensornd();
  87. }
  88. template <typename T, typename C>
  89. TensorND SyncedTensor<T, C>::tensornd_dev() {
  90. ensure_dev();
  91. m_sync_state = SyncState::DEV;
  92. return m_tensor_dev.tensornd();
  93. }
  94. template <typename T, typename C>
  95. TensorLayout SyncedTensor<T, C>::layout() const {
  96. return m_tensor_host.tensornd().layout;
  97. }
  98. template <typename T, typename C>
  99. template <typename C_>
  100. void SyncedTensor<T, C>::check_with(SyncedTensor<T, C_>& rhs) {
  101. this->ensure_host();
  102. rhs.ensure_host();
  103. this->m_tensor_host.check_with(rhs.m_tensor_host);
  104. }
  105. template <typename T, typename C>
  106. void SyncedTensor<T, C>::ensure_host() {
  107. if (m_sync_state == SyncState::HOST || m_sync_state == SyncState::SYNCED) {
  108. return;
  109. }
  110. if (m_sync_state == SyncState::DEV) {
  111. megdnn_memcpy_D2H(
  112. m_handle_dev, m_tensor_host.ptr(), m_tensor_dev.ptr(),
  113. m_tensor_host.layout().span().dist_byte());
  114. }
  115. m_sync_state = SyncState::SYNCED;
  116. }
  117. template <typename T, typename C>
  118. void SyncedTensor<T, C>::ensure_dev() {
  119. if (m_sync_state == SyncState::DEV || m_sync_state == SyncState::SYNCED) {
  120. return;
  121. }
  122. if (m_sync_state == SyncState::HOST) {
  123. megdnn_memcpy_H2D(
  124. m_handle_dev, m_tensor_dev.ptr(), m_tensor_host.ptr(),
  125. m_tensor_host.layout().span().dist_byte());
  126. }
  127. m_sync_state = SyncState::SYNCED;
  128. }
  129. } // namespace test
  130. } // namespace megdnn
  131. // vim: syntax=cpp.doxygen