You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.h 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #pragma once
#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "test/common/comparator.h"
  7. namespace megdnn {
  8. namespace test {
  9. /**
  10. * \brief A thin wrapper over megdnn::TensorND.
  11. *
  12. * dtype is determined by T; layout.dtype is ignored.
  13. */
  14. template <typename T = dt_float32, typename Comparator = DefaultComparator<T>>
  15. class Tensor {
  16. public:
  17. Tensor(Handle* handle, TensorLayout layout);
  18. ~Tensor();
  19. T* ptr();
  20. const T* ptr() const;
  21. TensorND tensornd() const { return m_tensornd; }
  22. TensorLayout layout() const;
  23. template <typename C>
  24. void check_with(const Tensor<T, C>& rhs) const;
  25. private:
  26. Handle* m_handle;
  27. TensorND m_tensornd;
  28. Comparator m_comparator;
  29. };
  30. /**
  31. * \brief A wrapper over host and device tensor.
  32. *
  33. * dtype is determined by T; layout.dtype is ignored.
  34. */
  35. template <typename T = dt_float32, typename Comparator = DefaultComparator<T>>
  36. class SyncedTensor {
  37. public:
  38. SyncedTensor(Handle* dev_handle, TensorLayout layout);
  39. SyncedTensor(Handle* dev_handle, const TensorShape& shape)
  40. : SyncedTensor(
  41. dev_handle,
  42. TensorLayout{shape, typename DTypeTrait<T>::dtype()}) {}
  43. ~SyncedTensor() = default;
  44. const T* ptr_host();
  45. const T* ptr_dev();
  46. T* ptr_mutable_host();
  47. T* ptr_mutable_dev();
  48. TensorND tensornd_host();
  49. TensorND tensornd_dev();
  50. TensorLayout layout() const;
  51. template <typename C>
  52. void check_with(SyncedTensor<T, C>& rhs);
  53. private:
  54. std::unique_ptr<Handle> m_handle_host;
  55. Handle* m_handle_dev;
  56. Tensor<T, Comparator> m_tensor_host, m_tensor_dev;
  57. enum class SyncState {
  58. HOST,
  59. DEV,
  60. SYNCED,
  61. UNINITED,
  62. } m_sync_state;
  63. void ensure_host();
  64. void ensure_dev();
  65. };
  66. //! make a device tensor on handle by value on host tensor
  67. std::shared_ptr<TensorND> make_tensor_h2d(Handle* handle, const TensorND& htensor);
  68. //! make a host tensor from device tensor on handle
  69. std::shared_ptr<TensorND> make_tensor_d2h(Handle* handle, const TensorND& dtensor);
  70. //! load tensors onto host from file (can be dumpped by megbrain tests)
  71. std::vector<std::shared_ptr<TensorND>> load_tensors(const char* fpath);
  72. void init_gaussian(
  73. SyncedTensor<dt_float32>& tensor, dt_float32 mean = 0.0f,
  74. dt_float32 stddev = 1.0f);
  75. } // namespace test
  76. } // namespace megdnn
  77. #include "test/common/tensor.inl"
  78. // vim: syntax=cpp.doxygen