
local.cpp 4.6 kB

/**
 * \file dnn/test/arm_common/local.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #include "test/arm_common/fixture.h"
  12. #include "test/common/benchmarker.h"
  13. #include "test/common/checker.h"
  14. #include "test/common/local.h"
  15. #include "test/common/timer.h"
  16. namespace megdnn {
  17. namespace test {
  18. using Param = param::Convolution;
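
// Correctness test: run LocalForward on the predefined shape list through
// Checker, which compares each result against the reference implementation.
// When the toolchain provides FP16 vector arithmetic, the FP16 shape list is
// re-checked in Float16 with a relaxed epsilon (1e-2).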
TEST_F(ARM_COMMON, LOCAL_FORWARD) {
    auto args = local::get_args();
    Checker<LocalForward> checker(handle());
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs(
                {arg.sshape(), arg.fshape(), arg.dshape()});
    }

    NormalRNG rng(10.f);
    checker.set_rng(0, &rng).set_rng(1, &rng);
    args = local::get_args_for_fp16();
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    for (auto&& arg : args) {
        checker.set_dtype(0, dtype::Float16())
                .set_dtype(1, dtype::Float16())
                .set_dtype(2, dtype::Float16());
        checker.set_epsilon(1e-2);
        checker.set_param(arg.param).execs(
                {arg.sshape(), arg.fshape(), arg.dshape()});
    }
#endif
}

#if MEGDNN_WITH_BENCHMARK
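// Benchmark: time LocalForward over RUN iterations for each shape set and
// report the effective memory bandwidth in float32 (and, when available,
// float16).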
TEST_F(ARM_COMMON, BENCHMARK_LOCAL_FORWARD) {
    auto run = [&](const TensorShapeArray& shapes, Param param) {
        Benchmarker<LocalForward> benchmarker(handle());
        size_t RUN = 50;
        benchmarker.set_dtype(0, dtype::Float32())
                .set_dtype(1, dtype::Float32())
                .set_dtype(2, dtype::Float32());
        auto tfloat32 = benchmarker.set_display(true)
                                .set_times(RUN)
                                .set_param(param)
                                .exec(shapes);
        int N = shapes[0][0];
        int IC = shapes[0][1];
        int IH = shapes[0][2];
        int IW = shapes[0][3];
        int OH = shapes[1][0];
        int OW = shapes[1][1];
        int FH = shapes[1][3];
        int FW = shapes[1][4];
        int OC = shapes[1][5];
        std::cout << "LOCAL FORWARD, src: {" << N << ", " << IC << ", " << IH
                  << ", " << IW << "}" << std::endl;
        std::cout << "LOCAL FORWARD, filter: {" << OH << ", " << OW << ", "
                  << IC << ", " << FH << ", " << FW << ", " << OC << "}"
                  << std::endl;
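        // Bandwidth model: bytes moved per iteration are the per-output-pixel
        // filters (N * OC * OH * OW * FH * FW * IC elements) plus the input
        // tensor (N * IC * IH * IW elements), times the element size; dividing
        // by the per-iteration time in seconds yields GB/s.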
  66. std::cout << "LOCAL FORWARD (f32), bandwidth: "
  67. << (1.f * N * OC * OH * OW * FH * FW * IC +
  68. 1.f * N * IC * IH * IW) *
  69. sizeof(float) * 1e-9 / (tfloat32 / RUN * 1e-3)
  70. << "GBPS" << std::endl;
  71. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  72. benchmarker.set_dtype(0, dtype::Float16())
  73. .set_dtype(1, dtype::Float16())
  74. .set_dtype(2, dtype::Float16());
  75. auto tfloat16 = benchmarker.set_display(true)
  76. .set_times(RUN)
  77. .set_param(param)
  78. .exec(shapes);
  79. std::cout << "LOCAL FORWARD (f16), bandwidth: "
  80. << (1.f * N * OC * OH * OW * FH * FW * IC +
  81. 1.f * N * IC * IH * IW) *
  82. sizeof(dt_float16) * 1e-9 / (tfloat16 / RUN * 1e-3)
  83. << "GBPS" << std::endl;
  84. #endif
  85. };
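
    // All cases run in CONVOLUTION mode; the sweep covers 3x3, 5x5 and 7x7
    // filters with strides 1 and 2, plus two cases with larger channel counts.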
    Param param;
    param.mode = param::Convolution::Mode::CONVOLUTION;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    run({{1, 4, 320, 256}, {320, 256, 4, 3, 3, 24}, {}}, param);
    param.stride_h = param.stride_w = 2;
    run({{1, 4, 320, 256}, {160, 128, 4, 3, 3, 24}, {}}, param);

    param.pad_h = param.pad_w = 2;
    param.stride_h = param.stride_w = 1;
    run({{1, 4, 64, 64}, {64, 64, 4, 5, 5, 24}, {}}, param);
    param.stride_h = param.stride_w = 2;
    run({{1, 4, 64, 64}, {32, 32, 4, 5, 5, 24}, {}}, param);

    param.pad_h = param.pad_w = 3;
    param.stride_h = param.stride_w = 1;
    run({{1, 4, 64, 64}, {64, 64, 4, 7, 7, 24}, {}}, param);
    param.stride_h = param.stride_w = 2;
    run({{1, 4, 64, 64}, {32, 32, 4, 7, 7, 24}, {}}, param);

    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    run({{2, 128, 8, 8}, {8, 8, 128, 3, 3, 128}, {}}, param);
    run({{1, 16, 64, 64}, {64, 64, 16, 3, 3, 16}, {}}, param);
}
#endif

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.