You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

deduce_layout_proxy.h 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "test/common/utils.h"
  4. namespace megdnn {
  5. namespace test {
  /*!
   * \brief Primary template of the layout-deduction dispatcher (declaration
   *      only).
   *
   * No generic definition is provided; behaviour comes from the partial
   * specializations in this header, which are selected by \p Arity and
   * \p can_deduce_layout.
   *
   * \tparam Opr operator type; the forwarding specializations call
   *      Opr::deduce_layout on it
   * \tparam Arity number of layouts the specialization expects in the
   *      TensorLayoutArray it receives
   * \tparam can_deduce_layout true selects a specialization that forwards to
   *      Opr::deduce_layout, false selects a no-op
   */
  template <typename Opr, size_t Arity, bool can_deduce_layout>
  struct DeduceLayoutProxy;
  8. template <typename Opr, size_t Arity>
  9. struct DeduceLayoutProxy<Opr, Arity, false> {
  10. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  11. };
  12. template <typename Opr>
  13. struct DeduceLayoutProxy<Opr, 2, true> {
  14. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  15. megdnn_assert(layouts.size() == 2);
  16. opr->deduce_layout(layouts[0], layouts[1]);
  17. }
  18. };
  19. template <typename Opr>
  20. struct DeduceLayoutProxy<Opr, 3, true> {
  21. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  22. megdnn_assert(layouts.size() == 3);
  23. opr->deduce_layout(layouts[0], layouts[1], layouts[2]);
  24. }
  25. };
  26. template <typename Opr>
  27. struct DeduceLayoutProxy<Opr, 4, true> {
  28. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  29. megdnn_assert(layouts.size() == 4);
  30. opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3]);
  31. }
  32. };
  33. template <typename Opr>
  34. struct DeduceLayoutProxy<Opr, 5, true> {
  35. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  36. megdnn_assert(layouts.size() == 5);
  37. opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]);
  38. }
  39. };
  40. template <typename Opr>
  41. struct DeduceLayoutProxy<Opr, 6, true> {
  42. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  43. megdnn_assert(layouts.size() == 6);
  44. opr->deduce_layout(
  45. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]);
  46. }
  47. };
  48. template <typename Opr>
  49. struct DeduceLayoutProxy<Opr, 5, false> {
  50. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  51. };
  52. template <typename Opr>
  53. struct DeduceLayoutProxy<Opr, 6, false> {
  54. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  55. };
  56. template <typename Opr>
  57. struct DeduceLayoutProxy<Opr, 7, false> {
  58. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  59. };
  60. template <typename Opr>
  61. struct DeduceLayoutProxy<Opr, 7, true> {
  62. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  63. megdnn_assert(layouts.size() == 7);
  64. opr->deduce_layout(
  65. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  66. layouts[6]);
  67. }
  68. };
  69. template <typename Opr>
  70. struct DeduceLayoutProxy<Opr, 8, true> {
  71. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  72. megdnn_assert(layouts.size() == 8);
  73. opr->deduce_layout(
  74. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  75. layouts[6], layouts[7]);
  76. }
  77. };
  78. template <typename Opr>
  79. struct DeduceLayoutProxy<Opr, 9, true> {
  80. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  81. megdnn_assert(layouts.size() == 9);
  82. opr->deduce_layout(
  83. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  84. layouts[6], layouts[7], layouts[8]);
  85. }
  86. };
  87. template <typename Opr>
  88. struct DeduceLayoutProxy<Opr, 10, true> {
  89. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  90. megdnn_assert(layouts.size() == 10);
  91. opr->deduce_layout(
  92. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  93. layouts[6], layouts[7], layouts[8], layouts[9]);
  94. }
  95. };
  96. template <typename Opr>
  97. struct DeduceLayoutProxy<Opr, 10, false> {
  98. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  99. };
  100. template <typename Opr>
  101. struct DeduceLayoutProxy<Opr, 13, true> {
  102. static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
  103. megdnn_assert(layouts.size() == 13);
  104. opr->deduce_layout(
  105. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5],
  106. layouts[6], layouts[7], layouts[8], layouts[9], layouts[10],
  107. layouts[11], layouts[12]);
  108. }
  109. };
  110. template <typename Opr>
  111. struct DeduceLayoutProxy<Opr, 13, false> {
  112. static void deduce_layout(Opr*, TensorLayoutArray&) {}
  113. };
  114. } // namespace test
  115. } // namespace megdnn
// vim: syntax=cpp.doxygen