You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

local.cpp 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #include "test/arm_common/fixture.h"
  2. #include "test/common/benchmarker.h"
  3. #include "test/common/checker.h"
  4. #include "test/common/local.h"
  5. #include "test/common/task_record_check.h"
  6. #include "test/common/timer.h"
  7. namespace megdnn {
  8. namespace test {
  9. using Param = param::Convolution;
  10. TEST_F(ARM_COMMON, LOCAL_FORWARD) {
  11. auto args = local::get_args();
  12. Checker<LocalForward> checker(handle());
  13. for (auto&& arg : args) {
  14. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  15. }
  16. NormalRNG rng(10.f);
  17. checker.set_rng(0, &rng).set_rng(1, &rng);
  18. args = local::get_args_for_fp16();
  19. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  20. for (auto&& arg : args) {
  21. checker.set_dtype(0, dtype::Float16())
  22. .set_dtype(1, dtype::Float16())
  23. .set_dtype(2, dtype::Float16());
  24. checker.set_epsilon(1e-2);
  25. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  26. }
  27. #endif
  28. }
  29. TEST_F(ARM_COMMON, LOCAL_FORWARD_RECORD) {
  30. auto args = local::get_args();
  31. TaskRecordChecker<LocalForward> checker(0);
  32. for (auto&& arg : args) {
  33. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  34. }
  35. NormalRNG rng(10.f);
  36. checker.set_rng(0, &rng).set_rng(1, &rng);
  37. args = local::get_args_for_fp16();
  38. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  39. for (auto&& arg : args) {
  40. checker.set_dtype(0, dtype::Float16())
  41. .set_dtype(1, dtype::Float16())
  42. .set_dtype(2, dtype::Float16());
  43. checker.set_epsilon(1e-2);
  44. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  45. }
  46. #endif
  47. }
  48. #if MEGDNN_WITH_BENCHMARK
  49. TEST_F(ARM_COMMON, BENCHMARK_LOCAL_FORWARD) {
  50. auto run = [&](const TensorShapeArray& shapes, Param param) {
  51. Benchmarker<LocalForward> benchmarker(handle());
  52. size_t RUN = 50;
  53. benchmarker.set_dtype(0, dtype::Float32())
  54. .set_dtype(1, dtype::Float32())
  55. .set_dtype(2, dtype::Float32());
  56. auto tfloat32 =
  57. benchmarker.set_display(true).set_times(RUN).set_param(param).exec(
  58. shapes);
  59. int N = shapes[0][0];
  60. int IC = shapes[0][1];
  61. int IH = shapes[0][2];
  62. int IW = shapes[0][3];
  63. int OH = shapes[1][0];
  64. int OW = shapes[1][1];
  65. int FH = shapes[1][3];
  66. int FW = shapes[1][4];
  67. int OC = shapes[1][5];
  68. std::cout << "LOCAL FORWARD, src: {" << N << ", " << IC << ", " << IH << ", "
  69. << IW << "}" << std::endl;
  70. std::cout << "LOCAL FORWARD, filter: {" << OH << ", " << OW << ", " << IC
  71. << ", " << FH << ", " << FW << ", " << OC << "}" << std::endl;
  72. std::cout << "LOCAL FORWARD (f32), bandwidth: "
  73. << (1.f * N * OC * OH * OW * FH * FW * IC + 1.f * N * IC * IH * IW) *
  74. sizeof(float) * 1e-9 / (tfloat32 / RUN * 1e-3)
  75. << "GBPS" << std::endl;
  76. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  77. benchmarker.set_dtype(0, dtype::Float16())
  78. .set_dtype(1, dtype::Float16())
  79. .set_dtype(2, dtype::Float16());
  80. auto tfloat16 =
  81. benchmarker.set_display(true).set_times(RUN).set_param(param).exec(
  82. shapes);
  83. std::cout << "LOCAL FORWARD (f16), bandwidth: "
  84. << (1.f * N * OC * OH * OW * FH * FW * IC + 1.f * N * IC * IH * IW) *
  85. sizeof(dt_float16) * 1e-9 / (tfloat16 / RUN * 1e-3)
  86. << "GBPS" << std::endl;
  87. #endif
  88. };
  89. Param param;
  90. param.mode = param::Convolution::Mode::CONVOLUTION;
  91. param.pad_h = param.pad_w = 1;
  92. param.stride_h = param.stride_w = 1;
  93. run({{1, 4, 320, 256}, {320, 256, 4, 3, 3, 24}, {}}, param);
  94. param.stride_h = param.stride_w = 2;
  95. run({{1, 4, 320, 256}, {160, 128, 4, 3, 3, 24}, {}}, param);
  96. param.pad_h = param.pad_w = 2;
  97. param.stride_h = param.stride_w = 1;
  98. run({{1, 4, 64, 64}, {64, 64, 4, 5, 5, 24}, {}}, param);
  99. param.stride_h = param.stride_w = 2;
  100. run({{1, 4, 64, 64}, {32, 32, 4, 5, 5, 24}, {}}, param);
  101. param.pad_h = param.pad_w = 3;
  102. param.stride_h = param.stride_w = 1;
  103. run({{1, 4, 64, 64}, {64, 64, 4, 7, 7, 24}, {}}, param);
  104. param.stride_h = param.stride_w = 2;
  105. run({{1, 4, 64, 64}, {32, 32, 4, 7, 7, 24}, {}}, param);
  106. param.pad_h = param.pad_w = 1;
  107. param.stride_h = param.stride_w = 1;
  108. run({{2, 128, 8, 8}, {8, 8, 128, 3, 3, 128}, {}}, param);
  109. run({{1, 16, 64, 64}, {64, 64, 16, 3, 3, 16}, {}}, param);
  110. }
  111. #endif
  112. } // namespace test
  113. } // namespace megdnn
  114. // vim: syntax=cpp.doxygen