You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

group_local.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #include "megdnn/oprs/nn.h"
  2. #include "test/common/checker.h"
  3. #include "test/cuda/fixture.h"
  4. #if MEGDNN_WITH_BENCHMARK
  5. #include "test/common/benchmarker.h"
  6. #endif
  7. namespace megdnn {
  8. namespace test {
  9. TEST_F(CUDA, GROUP_LOCAL_FORWARD) {
  10. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  11. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  12. size_t SW, size_t group) {
  13. Checker<GroupLocal> checker(handle_cuda());
  14. GroupLocal::Param param;
  15. param.pad_h = PH;
  16. param.pad_w = PW;
  17. param.stride_h = SH;
  18. param.stride_w = SW;
  19. auto ICg = IC / group;
  20. auto OCg = OC / group;
  21. checker.set_param(param).exec(
  22. {{N, IC, IH, IW}, {group, OH, OW, ICg, FH, FW, OCg}, {}});
  23. };
  24. for (size_t IC = 1; IC <= 16; ++IC)
  25. for (size_t N = 1; N <= 8; ++N) {
  26. size_t group = 5;
  27. size_t H = 7, W = 7;
  28. size_t OC = 7;
  29. run(N, IC * group, H, W, 3, 3, OC * group, H, W, 1, 1, 1, 1, group);
  30. }
  31. for (size_t N : {2, 64}) {
  32. // normal case
  33. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  34. // padded case
  35. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  36. // strided case
  37. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  38. }
  39. }
  40. #if MEGDNN_WITH_BENCHMARK
  41. TEST_F(CUDA, BENCHMARK_GROUP_LOCAL_FORWARD) {
  42. Benchmarker<GroupLocalForward> B(handle_cuda());
  43. B.execs({{2, 352, 4, 4}, {22, 4, 4, 16, 3, 3, 16}, {}});
  44. B.execs({{2, 192, 8, 8}, {12, 8, 8, 16, 3, 3, 16}, {}});
  45. B.execs({{2, 176, 4, 4}, {11, 4, 4, 16, 3, 3, 16}, {}});
  46. }
  47. #endif
  48. TEST_F(CUDA, GROUP_LOCAL_BACKWARD_DATA) {
  49. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  50. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  51. size_t SW, size_t group) {
  52. Checker<GroupLocalBackwardData> checker(handle_cuda());
  53. GroupLocal::Param param;
  54. param.pad_h = PH;
  55. param.pad_w = PW;
  56. param.stride_h = SH;
  57. param.stride_w = SW;
  58. auto ICg = IC / group;
  59. auto OCg = OC / group;
  60. checker.set_param(param).exec({
  61. {group, OH, OW, ICg, FH, FW, OCg},
  62. {N, OC, OH, OW},
  63. {N, IC, IH, IW},
  64. });
  65. };
  66. for (size_t N : {64}) {
  67. // normal case
  68. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  69. // padded case
  70. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  71. // strided case
  72. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  73. }
  74. }
  75. TEST_F(CUDA, GROUP_LOCAL_BACKWARD_FILTER) {
  76. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  77. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  78. size_t SW, size_t group) {
  79. Checker<GroupLocalBackwardFilter> checker(handle_cuda());
  80. GroupLocal::Param param;
  81. param.pad_h = PH;
  82. param.pad_w = PW;
  83. param.stride_h = SH;
  84. param.stride_w = SW;
  85. auto ICg = IC / group;
  86. auto OCg = OC / group;
  87. checker.set_param(param).exec({
  88. {N, IC, IH, IW},
  89. {N, OC, OH, OW},
  90. {group, OH, OW, ICg, FH, FW, OCg},
  91. });
  92. };
  93. for (size_t N : {64}) {
  94. // normal case
  95. run(N, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 2);
  96. // padded case
  97. run(N, 32, 7, 7, 3, 3, 64, 7, 7, 1, 1, 1, 1, 2);
  98. // strided case
  99. run(N, 64, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 4);
  100. }
  101. }
  102. } // namespace test
  103. } // namespace megdnn
  104. // vim: syntax=cpp.doxygen