You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

benchmark.cpp 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /**
  2. * \file dnn/test/rocm/benchmark.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "hcc_detail/hcc_defs_prologue.h"
  12. #include "test/rocm/fixture.h"
  13. #include "megdnn/oprs.h"
  14. #include "src/rocm/utils.h"
  15. #include "test/common/benchmarker.h"
  16. #include "test/common/tensor.h"
  17. #include "test/common/timer.h"
  18. #include "test/common/workspace_wrapper.h"
  19. #include "test/rocm/benchmarker.h"
  20. namespace megdnn {
  21. namespace test {
  22. #if MEGDNN_WITH_BENCHMARK
  23. TEST_F(ROCM, REDUCE_BENCHMARK) {
  24. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  25. auto benchmarker =
  26. ROCMBenchmarker<ReduceForward>(handle_rocm(), handle_naive(false));
  27. auto run = [&](size_t A, size_t B, size_t C) {
  28. auto dtype = dtype::Float32();
  29. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  30. benchmarker.set_display(true);
  31. ReduceForward::Param param;
  32. param.axis = 1;
  33. benchmarker.set_param(param);
  34. // warm up
  35. benchmarker.execs({{A, B, C}, {}});
  36. // do actual benchmark
  37. auto time_ms = benchmarker.execs({{A, B, C}, {}});
  38. time_ms = benchmarker.execs({{A, B, C}, {}});
  39. auto io = (double)(A * B * C + A * C) * dtype.size();
  40. auto gbps = io / (time_ms * 1e6);
  41. printf("io %.2fGB, flops %.3fGB/s\n", io / 1e9, gbps);
  42. };
  43. run(65536, 64, 1);
  44. run(1, 268435455, 1);
  45. run(256, 1048575, 1);
  46. run(1, 1048575, 256);
  47. run(256, 4095, 256);
  48. }
  49. TEST_F(ROCM, BATCHED_MATRIX_MUL_BENCHMARK) {
  50. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  51. auto benchmarker = ROCMBenchmarker<BatchedMatrixMulForward>(
  52. handle_rocm(), handle_naive(false));
  53. auto run = [&](size_t b, size_t m, size_t n, size_t k) {
  54. auto dtype = dtype::Float32();
  55. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  56. benchmarker.set_display(true);
  57. // warm up
  58. benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  59. // do actual benchmark
  60. auto time_ms = benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  61. time_ms = benchmarker.execs({{b, m, k}, {b, k, n}, {}});
  62. double flo = 2.0 * b * m * n * k;
  63. double flops = flo / (time_ms * 1e9);
  64. printf("mxnxk=%zux%zux%zu flo %.2fGB, flops %.3fTFLOPS\n", m, n, k, flo / 1e9,
  65. flops);
  66. };
  67. run(32, 128, 128, 128);
  68. run(32, 256, 256, 256);
  69. run(32, 512, 512, 512);
  70. run(32, 1024, 1024, 1024);
  71. run(32, 4096, 4096, 4096);
  72. //! resnet50 fwd
  73. run(32, 12544, 1024, 256);
  74. run(32, 12544, 1024, 512);
  75. run(32, 12544, 256, 1024);
  76. run(32, 12544, 256, 512);
  77. run(32, 12544, 64, 147);
  78. run(32, 196, 256, 2304);
  79. run(32, 3025, 64, 576);
  80. run(32, 3136, 2048, 1024);
  81. run(32, 3136, 2048, 512);
  82. run(32, 3136, 512, 1024);
  83. run(32, 3136, 512, 2048);
  84. run(32, 3136, 64, 576);
  85. run(32, 49, 512, 4608);
  86. run(32, 50176, 128, 256);
  87. run(32, 50176, 512, 256);
  88. run(32, 784, 128, 1152);
  89. //! resnet50 bwdwrw
  90. run(32, 147, 64, 12544);
  91. }
  92. TEST_F(ROCM, MATRIX_MUL_BENCHMARK) {
  93. megdnn::rocm::enable_miopen_algo_search(handle_rocm(), true);
  94. auto benchmarker =
  95. ROCMBenchmarker<MatrixMulForward>(handle_rocm(), handle_naive(false));
  96. auto run = [&](size_t m, size_t n, size_t k) {
  97. auto dtype = dtype::Float32();
  98. benchmarker.set_dtype(0, dtype).set_dtype(1, dtype);
  99. benchmarker.set_display(true);
  100. // warm up
  101. benchmarker.execs({{m, k}, {k, n}, {}});
  102. // do actual benchmark
  103. auto time_ms = benchmarker.execs({{m, k}, {k, n}, {}});
  104. time_ms = benchmarker.execs({{m, k}, {k, n}, {}});
  105. double flo = 2.0 * m * n * k;
  106. double flops = flo / (time_ms * 1e9);
  107. printf("mxnxk=%zux%zux%zu flo %.2fGB, flops %.3fTFLOPS\n", m, n, k, flo / 1e9,
  108. flops);
  109. };
  110. run(128, 128, 128);
  111. run(256, 256, 256);
  112. run(512, 512, 512);
  113. run(1024, 1024, 1024);
  114. run(4096, 4096, 4096);
  115. //! resnet50 fwd
  116. run(12544, 1024, 256);
  117. run(12544, 1024, 512);
  118. run(12544, 256, 1024);
  119. run(12544, 256, 512);
  120. run(12544, 64, 147);
  121. run(196, 256, 2304);
  122. run(3025, 64, 576);
  123. run(3136, 2048, 1024);
  124. run(3136, 2048, 512);
  125. run(3136, 512, 1024);
  126. run(3136, 512, 2048);
  127. run(3136, 64, 576);
  128. run(49, 512, 4608);
  129. run(50176, 128, 256);
  130. run(50176, 512, 256);
  131. run(784, 128, 1152);
  132. //! resnet50 bwdwrw
  133. run(147, 64, 12544);
  134. }
  135. #endif
  136. } // namespace test
  137. } // namespace megdnn
  138. // vim: syntax=cpp.doxygen