You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

group_local.cpp 1.2 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. #include "megdnn/oprs/nn.h"
  2. #include "test/common/checker.h"
  3. #include "test/cpu/fixture.h"
  4. namespace megdnn {
  5. namespace test {
  6. TEST_F(CPU, GROUP_LOCAL) {
  7. auto run = [&](size_t N, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW,
  8. size_t OC, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH,
  9. size_t SW, size_t group) {
  10. Checker<GroupLocal> checker(handle());
  11. GroupLocal::Param param;
  12. param.pad_h = PH;
  13. param.pad_w = PW;
  14. param.stride_h = SH;
  15. param.stride_w = SW;
  16. auto ICg = IC / group;
  17. auto OCg = OC / group;
  18. checker.set_param(param).exec(
  19. {{N, IC, IH, IW}, {group, OH, OW, ICg, FH, FW, OCg}, {}});
  20. };
  21. // simple groupped
  22. run(2, 6, 5, 5, 2, 2, 9, 4, 4, 0, 0, 1, 1, 3);
  23. // ungroupped
  24. run(1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1);
  25. // normal case
  26. run(2, 64, 7, 7, 3, 3, 32, 5, 5, 0, 0, 1, 1, 1);
  27. // padded and stridded case
  28. run(2, 32, 7, 7, 3, 3, 64, 9, 4, 2, 1, 1, 2, 4);
  29. // strided case with larger batch
  30. run(7, 32, 7, 7, 3, 3, 64, 3, 3, 0, 0, 2, 2, 8);
  31. }
  32. } // namespace test
  33. } // namespace megdnn
  34. // vim: syntax=cpp.doxygen