You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor.h 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. /**
  2. * \file dnn/test/common/tensor.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
#include <memory>
#include <vector>

#include <gtest/gtest.h>

#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "test/common/comparator.h"
  17. namespace megdnn {
  18. namespace test {
  19. /**
  20. * \brief A thin wrapper over megdnn::TensorND.
  21. *
  22. * dtype is determined by T; layout.dtype is ignored.
  23. */
  24. template <typename T = dt_float32, typename Comparator = DefaultComparator<T>>
  25. class Tensor {
  26. public:
  27. Tensor(Handle* handle, TensorLayout layout);
  28. ~Tensor();
  29. T* ptr();
  30. const T* ptr() const;
  31. TensorND tensornd() const { return m_tensornd; }
  32. TensorLayout layout() const;
  33. template <typename C>
  34. void check_with(const Tensor<T, C>& rhs) const;
  35. private:
  36. Handle* m_handle;
  37. TensorND m_tensornd;
  38. Comparator m_comparator;
  39. };
  40. /**
  41. * \brief A wrapper over host and device tensor.
  42. *
  43. * dtype is determined by T; layout.dtype is ignored.
  44. */
  45. template <typename T = dt_float32, typename Comparator = DefaultComparator<T>>
  46. class SyncedTensor {
  47. public:
  48. SyncedTensor(Handle* dev_handle, TensorLayout layout);
  49. SyncedTensor(Handle* dev_handle, const TensorShape& shape)
  50. : SyncedTensor(
  51. dev_handle,
  52. TensorLayout{shape, typename DTypeTrait<T>::dtype()}) {}
  53. ~SyncedTensor() = default;
  54. const T* ptr_host();
  55. const T* ptr_dev();
  56. T* ptr_mutable_host();
  57. T* ptr_mutable_dev();
  58. TensorND tensornd_host();
  59. TensorND tensornd_dev();
  60. TensorLayout layout() const;
  61. template <typename C>
  62. void check_with(SyncedTensor<T, C>& rhs);
  63. private:
  64. std::unique_ptr<Handle> m_handle_host;
  65. Handle* m_handle_dev;
  66. Tensor<T, Comparator> m_tensor_host, m_tensor_dev;
  67. enum class SyncState {
  68. HOST,
  69. DEV,
  70. SYNCED,
  71. UNINITED,
  72. } m_sync_state;
  73. void ensure_host();
  74. void ensure_dev();
  75. };
  76. //! make a device tensor on handle by value on host tensor
  77. std::shared_ptr<TensorND> make_tensor_h2d(Handle* handle, const TensorND& htensor);
  78. //! make a host tensor from device tensor on handle
  79. std::shared_ptr<TensorND> make_tensor_d2h(Handle* handle, const TensorND& dtensor);
  80. //! load tensors onto host from file (can be dumpped by megbrain tests)
  81. std::vector<std::shared_ptr<TensorND>> load_tensors(const char* fpath);
  82. void init_gaussian(
  83. SyncedTensor<dt_float32>& tensor, dt_float32 mean = 0.0f,
  84. dt_float32 stddev = 1.0f);
  85. } // namespace test
  86. } // namespace megdnn
  87. #include "test/common/tensor.inl"
  88. // vim: syntax=cpp.doxygen