You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmarker.h 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "megdnn/tensor_format.h"
  4. #include "src/rocm/utils.h"
  5. #include "test/common/opr_proxy.h"
  6. #include "test/common/rng.h"
  7. #include "test/rocm/fixture.h"
  8. #include "hip_header.h"
  9. #include <map>
  10. namespace megdnn {
  11. namespace test {
  12. template <typename Opr>
  13. class ROCMBenchmarker {
  14. public:
  15. using Param = typename Opr::Param;
  16. ROCMBenchmarker(Handle* handle_rocm, Handle* handle_naive);
  17. const Handle* handle() const { return m_handle_rocm; }
  18. /*!
  19. * \brief benchmark opr on current param/dtype/rng config
  20. * \returns elapsed time in ms
  21. *
  22. * ROCMBenchmarker would construct TensorLayout vectors from shapes and
  23. * dtypes and call exec(TensorLayoutArray &).
  24. */
  25. float exec(const TensorShapeArray& shapes);
  26. float exec(TensorLayoutArray layouts);
  27. //! disabiguate overloaded exec
  28. float execs(const TensorShapeArray& shapes) { return exec(make_layouts(shapes)); }
  29. float execl(const TensorLayoutArray& layouts) { return exec(layouts); }
  30. ROCMBenchmarker& set_param(Param param) {
  31. m_param = param;
  32. return *this;
  33. }
  34. ROCMBenchmarker& set_dtype(size_t idx, DType dtype) {
  35. m_dtype[idx] = dtype;
  36. return *this;
  37. }
  38. ROCMBenchmarker& set_rng(size_t idx, RNG* rng) {
  39. m_rng[idx] = rng;
  40. return *this;
  41. }
  42. ROCMBenchmarker& set_proxy(const OprProxy<Opr>& proxy) {
  43. m_proxy = proxy;
  44. return *this;
  45. }
  46. ROCMBenchmarker& set_display(bool display) {
  47. m_display = display;
  48. return *this;
  49. }
  50. ROCMBenchmarker& set_fmt(size_t idx, TensorFormat fmt) {
  51. m_fmt[idx] = fmt;
  52. return *this;
  53. }
  54. TensorLayoutArray make_layouts(const TensorShapeArray& shapes) {
  55. TensorLayoutArray layouts(shapes.size());
  56. for (size_t i = 0; i < shapes.size(); ++i) {
  57. DType dt =
  58. (m_dtype.find(i) != m_dtype.end() ? m_dtype[i] : dtype::Float32());
  59. TensorFormat fmt =
  60. (m_fmt.find(i) != m_fmt.end() ? m_fmt[i]
  61. : DefaultTensorFormat::make());
  62. layouts[i] = TensorLayout(shapes[i], dt, fmt);
  63. }
  64. return layouts;
  65. }
  66. private:
  67. class ROCMTimer {
  68. private:
  69. bool m_started, m_stopped;
  70. hipEvent_t m_event_start, m_event_end;
  71. hipStream_t m_stream;
  72. public:
  73. ROCMTimer() = delete;
  74. ROCMTimer(hipStream_t strm) : m_stream{strm} {
  75. hip_check(hipEventCreate(&m_event_start));
  76. hip_check(hipEventCreate(&m_event_end));
  77. reset();
  78. }
  79. ~ROCMTimer() {
  80. hip_check(hipEventDestroy(m_event_start));
  81. hip_check(hipEventDestroy(m_event_end));
  82. }
  83. void start() {
  84. megdnn_assert(!m_started);
  85. megdnn_assert(!m_stopped);
  86. m_started = true;
  87. hip_check(hipEventRecord(m_event_start, m_stream));
  88. }
  89. void stop() {
  90. megdnn_assert(m_started);
  91. megdnn_assert(!m_stopped);
  92. m_stopped = true;
  93. hip_check(hipEventRecord(m_event_end, m_stream));
  94. }
  95. float get_time_in_ms() const {
  96. megdnn_assert(m_started);
  97. megdnn_assert(m_stopped);
  98. hip_check(hipEventSynchronize(m_event_end));
  99. float ms;
  100. hip_check(hipEventElapsedTime(&ms, m_event_start, m_event_end));
  101. return ms;
  102. }
  103. void reset() {
  104. m_started = false;
  105. m_stopped = false;
  106. }
  107. };
  108. bool m_display = true;
  109. Handle* m_handle_naive;
  110. Handle* m_handle_rocm;
  111. std::unique_ptr<RNG> m_default_rng;
  112. std::map<size_t, RNG*> m_rng;
  113. std::map<size_t, DType> m_dtype;
  114. std::map<size_t, TensorFormat> m_fmt;
  115. Param m_param;
  116. OprProxy<Opr> m_proxy;
  117. ROCMTimer m_device_timer;
  118. };
  119. } // namespace test
  120. } // namespace megdnn
  121. #include "test/rocm/benchmarker.inl"
  122. // vim: syntax=cpp.doxygen