You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_local.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. /**
  2. * \file dnn/test/cuda/group_local.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/oprs/nn.h"
  12. #include "test/common/checker.h"
  13. #include "test/cuda/fixture.h"
  14. #if MEGDNN_WITH_BENCHMARK
  15. #include "test/common/benchmarker.h"
  16. #endif
  17. namespace megdnn {
  18. namespace test {
  19. TEST_F(CUDA, GROUP_LOCAL_FORWARD) {
  20. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  21. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  22. size_t SW, size_t group) {
  23. Checker<GroupLocal> checker(handle_cuda());
  24. GroupLocal::Param param;
  25. param.pad_h = PH;
  26. param.pad_w = PW;
  27. param.stride_h = SH;
  28. param.stride_w = SW;
  29. auto ICg = IC / group;
  30. auto OCg = OC / group;
  31. checker.set_param(param).exec(
  32. {{N, IC, IH, IW}, {group, OH, OW, ICg, FH, FW, OCg}, {}});
  33. };
  34. for (size_t IC = 1; IC <= 16; ++IC)
  35. for (size_t N = 1; N <= 8; ++N) {
  36. size_t group = 5;
  37. size_t H = 7, W = 7;
  38. size_t OC = 7;
  39. run(N, IC * group, H, W, 3, 3, OC * group, H, W, 1, 1, 1, 1, group);
  40. }
  41. for (size_t N : {2, 64}) {
  42. // normal case
  43. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  44. // padded case
  45. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  46. // strided case
  47. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  48. }
  49. }
  50. #if MEGDNN_WITH_BENCHMARK
  51. TEST_F(CUDA, BENCHMARK_GROUP_LOCAL_FORWARD) {
  52. Benchmarker<GroupLocalForward> B(handle_cuda());
  53. B.execs({{2, 352, 4, 4}, {22, 4, 4, 16, 3, 3, 16}, {}});
  54. B.execs({{2, 192, 8, 8}, {12, 8, 8, 16, 3, 3, 16}, {}});
  55. B.execs({{2, 176, 4, 4}, {11, 4, 4, 16, 3, 3, 16}, {}});
  56. }
  57. #endif
  58. TEST_F(CUDA, GROUP_LOCAL_BACKWARD_DATA) {
  59. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  60. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  61. size_t SW, size_t group) {
  62. Checker<GroupLocalBackwardData> checker(handle_cuda());
  63. GroupLocal::Param param;
  64. param.pad_h = PH;
  65. param.pad_w = PW;
  66. param.stride_h = SH;
  67. param.stride_w = SW;
  68. auto ICg = IC / group;
  69. auto OCg = OC / group;
  70. checker.set_param(param).exec({
  71. {group, OH, OW, ICg, FH, FW, OCg},
  72. {N, OC, OH, OW},
  73. {N, IC, IH, IW},
  74. });
  75. };
  76. for (size_t N : {64}) {
  77. // normal case
  78. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  79. // padded case
  80. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  81. // strided case
  82. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  83. }
  84. }
  85. TEST_F(CUDA, GROUP_LOCAL_BACKWARD_FILTER) {
  86. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  87. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  88. size_t SW, size_t group) {
  89. Checker<GroupLocalBackwardFilter> checker(handle_cuda());
  90. GroupLocal::Param param;
  91. param.pad_h = PH;
  92. param.pad_w = PW;
  93. param.stride_h = SH;
  94. param.stride_w = SW;
  95. auto ICg = IC / group;
  96. auto OCg = OC / group;
  97. checker.set_param(param).exec({
  98. {N, IC, IH, IW},
  99. {N, OC, OH, OW},
  100. {group, OH, OW, ICg, FH, FW, OCg},
  101. });
  102. };
  103. for (size_t N : {64}) {
  104. // normal case
  105. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  106. // padded case
  107. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  108. // strided case
  109. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  110. }
  111. }
  112. } // namespace test
  113. } // namespace megdnn
  114. // vim: syntax=cpp.doxygen