You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

local.cpp 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /**
  2. * \file dnn/test/arm_common/local.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "test/arm_common/fixture.h"
  12. #include "test/common/benchmarker.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/local.h"
  15. #include "test/common/task_record_check.h"
  16. #include "test/common/timer.h"
  17. namespace megdnn {
  18. namespace test {
  19. using Param = param::Convolution;
  20. TEST_F(ARM_COMMON, LOCAL_FORWARD) {
  21. auto args = local::get_args();
  22. Checker<LocalForward> checker(handle());
  23. for (auto&& arg : args) {
  24. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  25. }
  26. NormalRNG rng(10.f);
  27. checker.set_rng(0, &rng).set_rng(1, &rng);
  28. args = local::get_args_for_fp16();
  29. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  30. for (auto&& arg : args) {
  31. checker.set_dtype(0, dtype::Float16())
  32. .set_dtype(1, dtype::Float16())
  33. .set_dtype(2, dtype::Float16());
  34. checker.set_epsilon(1e-2);
  35. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  36. }
  37. #endif
  38. }
  39. TEST_F(ARM_COMMON, LOCAL_FORWARD_RECORD) {
  40. auto args = local::get_args();
  41. TaskRecordChecker<LocalForward> checker(0);
  42. for (auto&& arg : args) {
  43. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  44. }
  45. NormalRNG rng(10.f);
  46. checker.set_rng(0, &rng).set_rng(1, &rng);
  47. args = local::get_args_for_fp16();
  48. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  49. for (auto&& arg : args) {
  50. checker.set_dtype(0, dtype::Float16())
  51. .set_dtype(1, dtype::Float16())
  52. .set_dtype(2, dtype::Float16());
  53. checker.set_epsilon(1e-2);
  54. checker.set_param(arg.param).execs({arg.sshape(), arg.fshape(), arg.dshape()});
  55. }
  56. #endif
  57. }
  58. #if MEGDNN_WITH_BENCHMARK
  59. TEST_F(ARM_COMMON, BENCHMARK_LOCAL_FORWARD) {
  60. auto run = [&](const TensorShapeArray& shapes, Param param) {
  61. Benchmarker<LocalForward> benchmarker(handle());
  62. size_t RUN = 50;
  63. benchmarker.set_dtype(0, dtype::Float32())
  64. .set_dtype(1, dtype::Float32())
  65. .set_dtype(2, dtype::Float32());
  66. auto tfloat32 =
  67. benchmarker.set_display(true).set_times(RUN).set_param(param).exec(
  68. shapes);
  69. int N = shapes[0][0];
  70. int IC = shapes[0][1];
  71. int IH = shapes[0][2];
  72. int IW = shapes[0][3];
  73. int OH = shapes[1][0];
  74. int OW = shapes[1][1];
  75. int FH = shapes[1][3];
  76. int FW = shapes[1][4];
  77. int OC = shapes[1][5];
  78. std::cout << "LOCAL FORWARD, src: {" << N << ", " << IC << ", " << IH << ", "
  79. << IW << "}" << std::endl;
  80. std::cout << "LOCAL FORWARD, filter: {" << OH << ", " << OW << ", " << IC
  81. << ", " << FH << ", " << FW << ", " << OC << "}" << std::endl;
  82. std::cout << "LOCAL FORWARD (f32), bandwidth: "
  83. << (1.f * N * OC * OH * OW * FH * FW * IC + 1.f * N * IC * IH * IW) *
  84. sizeof(float) * 1e-9 / (tfloat32 / RUN * 1e-3)
  85. << "GBPS" << std::endl;
  86. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  87. benchmarker.set_dtype(0, dtype::Float16())
  88. .set_dtype(1, dtype::Float16())
  89. .set_dtype(2, dtype::Float16());
  90. auto tfloat16 =
  91. benchmarker.set_display(true).set_times(RUN).set_param(param).exec(
  92. shapes);
  93. std::cout << "LOCAL FORWARD (f16), bandwidth: "
  94. << (1.f * N * OC * OH * OW * FH * FW * IC + 1.f * N * IC * IH * IW) *
  95. sizeof(dt_float16) * 1e-9 / (tfloat16 / RUN * 1e-3)
  96. << "GBPS" << std::endl;
  97. #endif
  98. };
  99. Param param;
  100. param.mode = param::Convolution::Mode::CONVOLUTION;
  101. param.pad_h = param.pad_w = 1;
  102. param.stride_h = param.stride_w = 1;
  103. run({{1, 4, 320, 256}, {320, 256, 4, 3, 3, 24}, {}}, param);
  104. param.stride_h = param.stride_w = 2;
  105. run({{1, 4, 320, 256}, {160, 128, 4, 3, 3, 24}, {}}, param);
  106. param.pad_h = param.pad_w = 2;
  107. param.stride_h = param.stride_w = 1;
  108. run({{1, 4, 64, 64}, {64, 64, 4, 5, 5, 24}, {}}, param);
  109. param.stride_h = param.stride_w = 2;
  110. run({{1, 4, 64, 64}, {32, 32, 4, 5, 5, 24}, {}}, param);
  111. param.pad_h = param.pad_w = 3;
  112. param.stride_h = param.stride_w = 1;
  113. run({{1, 4, 64, 64}, {64, 64, 4, 7, 7, 24}, {}}, param);
  114. param.stride_h = param.stride_w = 2;
  115. run({{1, 4, 64, 64}, {32, 32, 4, 7, 7, 24}, {}}, param);
  116. param.pad_h = param.pad_w = 1;
  117. param.stride_h = param.stride_w = 1;
  118. run({{2, 128, 8, 8}, {8, 8, 128, 3, 3, 128}, {}}, param);
  119. run({{1, 16, 64, 64}, {64, 64, 16, 3, 3, 16}, {}}, param);
  120. }
  121. #endif
  122. } // namespace test
  123. } // namespace megdnn
  124. // vim: syntax=cpp.doxygen