You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.inl 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /**
  2. * \file dnn/test/common/tensor.inl
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./tensor.h"
  12. #include <memory>
  13. #include "megdnn/basic_types.h"
  14. #include "test/common/get_dtype_from_static_type.h"
  15. #include "test/common/index.h"
  16. #include "test/common/utils.h"
  17. namespace megdnn {
  18. namespace test {
  19. template <typename T, typename C>
  20. Tensor<T, C>::Tensor(Handle* handle, TensorLayout layout)
  21. : m_handle(handle), m_comparator(C()) {
  22. if (!layout.dtype.valid())
  23. layout.dtype = get_dtype_from_static_type<T>();
  24. auto raw_ptr = megdnn_malloc(m_handle, layout.span().dist_byte());
  25. m_tensornd = TensorND{raw_ptr, layout};
  26. }
  27. template <typename T, typename C>
  28. Tensor<T, C>::~Tensor() {
  29. megdnn_free(m_handle, m_tensornd.raw_ptr());
  30. }
  31. template <typename T, typename C>
  32. T* Tensor<T, C>::ptr() {
  33. return m_tensornd.ptr<T>();
  34. }
  35. template <typename T, typename C>
  36. const T* Tensor<T, C>::ptr() const {
  37. return m_tensornd.ptr<T>();
  38. }
  39. template <typename T, typename C>
  40. TensorLayout Tensor<T, C>::layout() const {
  41. return m_tensornd.layout;
  42. }
  43. template <typename T, typename C>
  44. template <typename C_>
  45. void Tensor<T, C>::check_with(const Tensor<T, C_>& rhs) const {
  46. // compare layout
  47. ASSERT_TRUE(this->m_tensornd.layout.eq_layout(rhs.m_tensornd.layout))
  48. << "this->layout is " << this->m_tensornd.layout.to_string()
  49. << "rhs.layout is " << rhs.m_tensornd.layout.to_string();
  50. // compare value
  51. auto n = m_tensornd.layout.total_nr_elems();
  52. auto p0 = this->ptr(), p1 = rhs.ptr();
  53. for (size_t linear_idx = 0; linear_idx < n; ++linear_idx) {
  54. auto index = Index(m_tensornd.layout, linear_idx);
  55. auto offset = index.positive_offset();
  56. ASSERT_TRUE(m_comparator.is_same(p0[offset], p1[offset]))
  57. << "Index is " << index.to_string() << "; layout is "
  58. << m_tensornd.layout.to_string() << "; this->ptr()[offset] is "
  59. << this->ptr()[offset] << "; rhs.ptr()[offset] is "
  60. << rhs.ptr()[offset];
  61. }
  62. }
  63. template <typename T, typename C>
  64. SyncedTensor<T, C>::SyncedTensor(Handle* dev_handle, TensorLayout layout)
  65. : m_handle_host(create_cpu_handle(2, false)),
  66. m_handle_dev(dev_handle),
  67. m_tensor_host(m_handle_host.get(), layout),
  68. m_tensor_dev(m_handle_dev, layout),
  69. m_sync_state(SyncState::UNINITED) {}
  70. template <typename T, typename C>
  71. const T* SyncedTensor<T, C>::ptr_host() {
  72. ensure_host();
  73. return m_tensor_host.tensornd().template ptr<T>();
  74. }
  75. template <typename T, typename C>
  76. const T* SyncedTensor<T, C>::ptr_dev() {
  77. ensure_dev();
  78. return m_tensor_dev.tensornd().template ptr<T>();
  79. }
  80. template <typename T, typename C>
  81. T* SyncedTensor<T, C>::ptr_mutable_host() {
  82. ensure_host();
  83. m_sync_state = SyncState::HOST;
  84. return m_tensor_host.tensornd().template ptr<T>();
  85. }
  86. template <typename T, typename C>
  87. T* SyncedTensor<T, C>::ptr_mutable_dev() {
  88. ensure_dev();
  89. m_sync_state = SyncState::DEV;
  90. return m_tensor_dev.tensornd().template ptr<T>();
  91. }
  92. template <typename T, typename C>
  93. TensorND SyncedTensor<T, C>::tensornd_host() {
  94. ensure_host();
  95. m_sync_state = SyncState::HOST;
  96. return m_tensor_host.tensornd();
  97. }
  98. template <typename T, typename C>
  99. TensorND SyncedTensor<T, C>::tensornd_dev() {
  100. ensure_dev();
  101. m_sync_state = SyncState::DEV;
  102. return m_tensor_dev.tensornd();
  103. }
  104. template <typename T, typename C>
  105. TensorLayout SyncedTensor<T, C>::layout() const {
  106. return m_tensor_host.tensornd().layout;
  107. }
  108. template <typename T, typename C>
  109. template <typename C_>
  110. void SyncedTensor<T, C>::check_with(SyncedTensor<T, C_>& rhs) {
  111. this->ensure_host();
  112. rhs.ensure_host();
  113. this->m_tensor_host.check_with(rhs.m_tensor_host);
  114. }
  115. template <typename T, typename C>
  116. void SyncedTensor<T, C>::ensure_host() {
  117. if (m_sync_state == SyncState::HOST || m_sync_state == SyncState::SYNCED) {
  118. return;
  119. }
  120. if (m_sync_state == SyncState::DEV) {
  121. megdnn_memcpy_D2H(
  122. m_handle_dev, m_tensor_host.ptr(), m_tensor_dev.ptr(),
  123. m_tensor_host.layout().span().dist_byte());
  124. }
  125. m_sync_state = SyncState::SYNCED;
  126. }
  127. template <typename T, typename C>
  128. void SyncedTensor<T, C>::ensure_dev() {
  129. if (m_sync_state == SyncState::DEV || m_sync_state == SyncState::SYNCED) {
  130. return;
  131. }
  132. if (m_sync_state == SyncState::HOST) {
  133. megdnn_memcpy_H2D(
  134. m_handle_dev, m_tensor_dev.ptr(), m_tensor_host.ptr(),
  135. m_tensor_host.layout().span().dist_byte());
  136. }
  137. m_sync_state = SyncState::SYNCED;
  138. }
  139. } // namespace test
  140. } // namespace megdnn
  141. // vim: syntax=cpp.doxygen