You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmark.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #include "hcc_detail/hcc_defs_prologue.h"
  2. #include "test/rocm/fixture.h"
  3. #include "megdnn/oprs.h"
  4. #include "src/rocm/utils.h"
  5. #include "test/common/benchmarker.h"
  6. #include "test/common/tensor.h"
  7. #include "test/common/timer.h"
  8. #include "test/common/workspace_wrapper.h"
  9. #include "test/rocm/benchmarker.h"
  10. namespace megdnn {
  11. namespace test {
  12. #if MEGDNN_WITH_BENCHMARK
  13. TEST_F(ROCM, REDUCE_BENCHMARK) {
  14. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  15. auto benchmarker =
  16. ROCMBenchmarker<ReduceForward>(handle_rocm(), handle_naive(false));
  17. auto run = [&](size_t A, size_t B, size_t C) {
  18. auto dtype = dtype::Float32();
  19. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  20. benchmarker.set_display(true);
  21. ReduceForward::Param param;
  22. param.axis = 1;
  23. benchmarker.set_param(param);
  24. // warm up
  25. benchmarker.execs({{A, B, C}, {}});
  26. // do actual benchmark
  27. auto time_ms = benchmarker.execs({{A, B, C}, {}});
  28. time_ms = benchmarker.execs({{A, B, C}, {}});
  29. auto io = (double)(A * B * C + A * C) * dtype.size();
  30. auto gbps = io / (time_ms * 1e6);
  31. printf("io %.2fGB, flops %.3fGB/s\n", io / 1e9, gbps);
  32. };
  33. run(65536, 64, 1);
  34. run(1, 268435455, 1);
  35. run(256, 1048575, 1);
  36. run(1, 1048575, 256);
  37. run(256, 4095, 256);
  38. }
  39. TEST_F(ROCM, BATCHED_MATRIX_MUL_BENCHMARK) {
  40. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  41. auto benchmarker = ROCMBenchmarker<BatchedMatrixMulForward>(
  42. handle_rocm(), handle_naive(false));
  43. auto run = [&](size_t b, size_t m, size_t n, size_t k) {
  44. auto dtype = dtype::Float32();
  45. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  46. benchmarker.set_display(true);
  47. // warm up
  48. benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  49. // do actual benchmark
  50. auto time_ms = benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  51. time_ms = benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  52. double flo = 2.0 * b * m * n * k;
  53. double flops = flo / (time_ms * 1e9);
  54. printf("mxnxk=%zux%zux%zu flo %.2fGB, flops %.3fTFLOPS\n", m, n, k, flo / 1e9,
  55. flops);
  56. };
  57. run(32, 128, 128, 128);
  58. run(32, 256, 256, 256);
  59. run(32, 512, 512, 512);
  60. run(32, 1024, 1024, 1024);
  61. run(32, 4096, 4096, 4096);
  62. //! resnet50 fwd
  63. run(32, 12544, 1024, 256);
  64. run(32, 12544, 1024, 512);
  65. run(32, 12544, 256, 1024);
  66. run(32, 12544, 256, 512);
  67. run(32, 12544, 64, 147);
  68. run(32, 196, 256, 2304);
  69. run(32, 3025, 64, 576);
  70. run(32, 3136, 2048, 1024);
  71. run(32, 3136, 2048, 512);
  72. run(32, 3136, 512, 1024);
  73. run(32, 3136, 512, 2048);
  74. run(32, 3136, 64, 576);
  75. run(32, 49, 512, 4608);
  76. run(32, 50176, 128, 256);
  77. run(32, 50176, 512, 256);
  78. run(32, 784, 128, 1152);
  79. //! resnet50 bwdwrw
  80. run(32, 147, 64, 12544);
  81. }
  82. TEST_F(ROCM, MATRIX_MUL_BENCHMARK) {
  83. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  84. auto benchmarker =
  85. ROCMBenchmarker<MatrixMulForward>(handle_rocm(), handle_naive(false));
  86. auto run = [&](size_t m, size_t n, size_t k) {
  87. auto dtype = dtype::Float32();
  88. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  89. benchmarker.set_display(true);
  90. // warm up
  91. benchmarker.execs({{m, k}, {k, n}, {}});
  92. // do actual benchmark
  93. auto time_ms = benchmarker.execs({{m, k}, {k, n}, {}});
  94. time_ms = benchmarker.execs({{m, k}, {k, n}, {}});
  95. double flo = 2.0 * m * n * k;
  96. double flops = flo / (time_ms * 1e9);
  97. printf("mxnxk=%zux%zux%zu flo %.2fGB, flops %.3fTFLOPS\n", m, n, k, flo / 1e9,
  98. flops);
  99. };
  100. run(128, 128, 128);
  101. run(256, 256, 256);
  102. run(512, 512, 512);
  103. run(1024, 1024, 1024);
  104. run(4096, 4096, 4096);
  105. //! resnet50 fwd
  106. run(12544, 1024, 256);
  107. run(12544, 1024, 512);
  108. run(12544, 256, 1024);
  109. run(12544, 256, 512);
  110. run(12544, 64, 147);
  111. run(196, 256, 2304);
  112. run(3025, 64, 576);
  113. run(3136, 2048, 1024);
  114. run(3136, 2048, 512);
  115. run(3136, 512, 1024);
  116. run(3136, 512, 2048);
  117. run(3136, 64, 576);
  118. run(49, 512, 4608);
  119. run(50176, 128, 256);
  120. run(50176, 512, 256);
  121. run(784, 128, 1152);
  122. //! resnet50 bwdwrw
  123. run(147, 64, 12544);
  124. }
  125. #endif
  126. } // namespace test
  127. } // namespace megdnn
  128. // vim: syntax=cpp.doxygen