You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

group_local.h 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
#pragma once
#include <cstddef>
#include <vector>

#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"
  5. namespace megdnn {
  6. namespace test {
  7. namespace group_local {
  8. struct TestArg {
  9. param::Convolution param;
  10. size_t n, ic, ih, iw, groups, ocpg, oh, ow, fh, fw;
  11. TestArg(param::Convolution param, size_t n, size_t ic, size_t ih, size_t iw,
  12. size_t groups, size_t ocpg, size_t oh, size_t ow, size_t fh, size_t fw)
  13. : param(param),
  14. n(n),
  15. ic(ic),
  16. ih(ih),
  17. iw(iw),
  18. groups(groups),
  19. ocpg(ocpg),
  20. oh(oh),
  21. ow(ow),
  22. fh(fh),
  23. fw(fw) {
  24. param.sparse = param::Convolution::Sparse::GROUP;
  25. }
  26. TensorShape sshape() const { return {n, ic, ih, iw}; }
  27. TensorShape fshape() const {
  28. size_t icpg = ic / groups;
  29. return {groups, oh, ow, icpg, fh, fw, ocpg};
  30. }
  31. TensorShape dshape() {
  32. size_t oc = ocpg * groups;
  33. return {n, oc, oh, ow};
  34. }
  35. };
  36. static inline std::vector<TestArg> get_args_for_fp16() {
  37. std::vector<TestArg> test_args;
  38. test_args.emplace_back(
  39. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1, 1, 1, 1},
  40. 64, 16, 8, 7, 4, 4, 8, 7, 3, 3);
  41. test_args.emplace_back(
  42. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0, 0, 1, 1},
  43. 15, 15, 7, 7, 5, 3, 5, 5, 3, 3);
  44. test_args.emplace_back(
  45. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1, 1, 1, 1},
  46. 15, 15, 5, 5, 5, 3, 5, 5, 3, 3);
  47. test_args.emplace_back(
  48. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0, 0, 2, 2},
  49. 15, 15, 7, 7, 5, 3, 3, 3, 3, 3);
  50. /*! \warning: this operator need reduce values along the axis of IC, so this
  51. * will results in large error in fp16 situation. so in the test cases, we
  52. * use small IC values.
  53. */
  54. // clang-format off
  55. for (size_t N: {1, 2})
  56. for (size_t OC: {16, 32, 48, 64})
  57. {
  58. test_args.emplace_back(
  59. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION,
  60. 0, 0, 1, 1},
  61. N, 16, 7, 7, 4, OC / 4, 5, 5, 3, 3);
  62. }
  63. // clang-format on
  64. return test_args;
  65. }
  66. static inline std::vector<TestArg> get_args() {
  67. std::vector<TestArg> test_args;
  68. test_args.emplace_back(
  69. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1, 1, 1, 1},
  70. 64, 16, 8, 7, 4, 4, 8, 7, 3, 3);
  71. test_args.emplace_back(
  72. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0, 0, 1, 1},
  73. 15, 15, 7, 7, 5, 3, 5, 5, 3, 3);
  74. test_args.emplace_back(
  75. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 1, 1, 1, 1},
  76. 15, 15, 5, 5, 5, 3, 5, 5, 3, 3);
  77. test_args.emplace_back(
  78. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION, 0, 0, 2, 2},
  79. 15, 15, 7, 7, 5, 3, 3, 3, 3, 3);
  80. // clang-format off
  81. for (size_t N: {1, 2})
  82. for (size_t OC: {16, 32, 48, 64})
  83. {
  84. test_args.emplace_back(
  85. param::Convolution{param::Convolution::Mode::CROSS_CORRELATION,
  86. 0, 0, 1, 1},
  87. N, 32, 7, 7, 4, OC / 4, 5, 5, 3, 3);
  88. }
  89. // clang-format on
  90. return test_args;
  91. }
  92. } // namespace group_local
  93. } // namespace test
  94. } // namespace megdnn
  95. // vim: syntax=cpp.doxygen