You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_algo_proxy.h 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "src/common/opr_trait.h"
  4. #include "test/common/utils.h"
  5. namespace megdnn {
  6. namespace test {
  /*!
   * \brief Test helper that calls an operator's fixed-arity API using the
   * elements of a TensorLayoutArray / TensorNDArray as positional arguments.
   *
   * The primary template is only declared; this header gives it no
   * definition. The usable versions are the specializations produced by the
   * DEF_ALGO_PROXY macros. An (Opr, Arity) pair without a specialization
   * therefore fails to compile as an incomplete type.
   */
  template <class Opr, size_t Arity>
  struct AlgoProxy;
  /*!
   * DEF_ALGO_PROXY(arity) emits the partial specialization
   * AlgoProxy<Opr, arity>. The specialization unpacks a TensorLayoutArray /
   * TensorNDArray into the positional arguments of Opr's methods.
   *
   * Before expanding it, #define LAYOUTS and TENSORS as the comma-separated
   * element lists, e.g. `layouts[0], layouts[1]`. They are substituted
   * verbatim into each forwarding call. Each wrapper checks the array length
   * against `arity` with megdnn_assert before the elements are used.
   */
  #define DEF_ALGO_PROXY(arity)                                               \
      template <typename Opr>                                                 \
      struct AlgoProxy<Opr, arity> {                                          \
          static auto get_all_algorithms_info_safe(                           \
                  Opr* opr, const TensorLayoutArray& layouts)                 \
                  -> std::vector<typename Opr::AlgorithmInfo> {               \
              megdnn_assert(layouts.size() == arity);                         \
              return opr->get_all_algorithms_info_safe(LAYOUTS);              \
          }                                                                   \
          static auto get_algorithm_info_heuristic(                           \
                  Opr* opr, const TensorLayoutArray& layouts)                 \
                  -> typename Opr::AlgorithmInfo {                            \
              megdnn_assert(layouts.size() == arity);                         \
              return opr->get_algorithm_info_heuristic(LAYOUTS);              \
          }                                                                   \
          static auto get_workspace_in_bytes(                                 \
                  Opr* opr, const TensorLayoutArray& layouts) -> size_t {     \
              megdnn_assert(layouts.size() == arity);                         \
              return opr->get_workspace_in_bytes(LAYOUTS);                    \
          }                                                                   \
          static void exec(                                                   \
                  Opr* opr, const TensorNDArray& tensors,                     \
                  Workspace workspace) {                                      \
              megdnn_assert(tensors.size() == arity);                         \
              opr->exec(TENSORS, workspace);                                  \
          }                                                                   \
      }
  33. #define LAYOUTS layouts[0], layouts[1]
  34. #define TENSORS tensors[0], tensors[1]
  35. DEF_ALGO_PROXY(2);
  36. #undef LAYOUTS
  37. #undef TENSORS
  38. #define LAYOUTS layouts[0], layouts[1], layouts[2]
  39. #define TENSORS tensors[0], tensors[1], tensors[2]
  40. DEF_ALGO_PROXY(3);
  41. #undef LAYOUTS
  42. #undef TENSORS
  43. #define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
  44. #define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
  45. DEF_ALGO_PROXY(5);
  46. #undef LAYOUTS
  47. #undef TENSORS
  48. #define LAYOUTS \
  49. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5], \
  50. layouts[6], layouts[7]
  51. #define TENSORS \
  52. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5], \
  53. tensors[6], tensors[7]
  54. DEF_ALGO_PROXY(8);
  55. #undef LAYOUTS
  56. #undef TENSORS
  57. #undef DEF_ALGO_PROXY
  /*!
   * DEF_ALGO_PROXY(Opr, arity) emits the full specialization
   * AlgoProxy<Opr, arity> for an operator whose API accepts an
   * `Opr::PreprocessedFilter`. Besides the algorithm and workspace queries,
   * it also forwards:
   *   - get_preprocess_workspace_in_bytes
   *   - deduce_preprocessed_filter_layout
   *   - exec_preprocess
   * It also provides an exec overload that passes `nullptr` as the
   * preprocessed filter.
   *
   * Before expanding it, #define LAYOUTS, TENSORS and PREPROCESS_ARGS as the
   * comma-separated argument lists that are substituted into the forwarding
   * calls. Each wrapper checks the array length(s) against `arity` with
   * megdnn_assert before the elements are used.
   *
   * The expansion deliberately ends at the closing brace, with no `;`.
   * Invocations are written as `DEF_ALGO_PROXY(...);`, so a trailing `;` here
   * would leave a stray empty declaration (`};;`) that e.g. clang's
   * -Wextra-semi reports.
   */
  #define DEF_ALGO_PROXY(Opr, arity)                                            \
      template <>                                                               \
      struct AlgoProxy<Opr, arity> {                                            \
          static std::vector<typename Opr::AlgorithmInfo>                       \
          get_all_algorithms_info_safe(                                         \
                  Opr* opr, const TensorLayoutArray& layouts) {                 \
              megdnn_assert(layouts.size() == arity);                           \
              return opr->get_all_algorithms_info_safe(LAYOUTS);                \
          }                                                                     \
          static typename Opr::AlgorithmInfo get_algorithm_info_heuristic(      \
                  Opr* opr, const TensorLayoutArray& layouts) {                 \
              megdnn_assert(layouts.size() == arity);                           \
              return opr->get_algorithm_info_heuristic(LAYOUTS);                \
          }                                                                     \
          static size_t get_workspace_in_bytes(                                 \
                  Opr* opr, const TensorLayoutArray& layouts,                   \
                  const typename Opr::PreprocessedFilter* preprocessed_filter = \
                          nullptr) {                                            \
              megdnn_assert(layouts.size() == arity);                           \
              return opr->get_workspace_in_bytes(LAYOUTS, preprocessed_filter); \
          }                                                                     \
          static void exec(                                                     \
                  Opr* opr, const TensorNDArray& tensors,                       \
                  const typename Opr::PreprocessedFilter* preprocessed_filter,  \
                  Workspace workspace) {                                        \
              megdnn_assert(tensors.size() == arity);                           \
              return opr->exec(TENSORS, preprocessed_filter, workspace);        \
          }                                                                     \
          /* Convenience overload: forwards nullptr as the filter. */           \
          static void exec(                                                     \
                  Opr* opr, const TensorNDArray& tensors, Workspace workspace) { \
              megdnn_assert(tensors.size() == arity);                           \
              return opr->exec(TENSORS, nullptr, workspace);                    \
          }                                                                     \
          static size_t get_preprocess_workspace_in_bytes(                      \
                  Opr* opr, const TensorLayoutArray& layouts) {                 \
              megdnn_assert(layouts.size() == arity);                           \
              return opr->get_preprocess_workspace_in_bytes(LAYOUTS);           \
          }                                                                     \
          static SmallVector<TensorLayout> deduce_preprocessed_filter_layout(   \
                  Opr* opr, const TensorLayoutArray& layouts) {                 \
              megdnn_assert(layouts.size() == arity);                           \
              return opr->deduce_preprocessed_filter_layout(LAYOUTS);           \
          }                                                                     \
          /* PREPROCESS_ARGS may draw from both arrays, so both are taken. */   \
          static void exec_preprocess(                                          \
                  Opr* opr, const TensorNDArray& tensors,                       \
                  const TensorLayoutArray& layouts,                             \
                  typename Opr::PreprocessedFilter* preprocessed_filter,        \
                  _megdnn_workspace workspace) {                                \
              megdnn_assert(layouts.size() == arity && tensors.size() == arity); \
              return opr->exec_preprocess(                                      \
                      PREPROCESS_ARGS, preprocessed_filter, workspace);         \
          }                                                                     \
      }
  110. #define LAYOUTS layouts[0], layouts[1], layouts[2]
  111. #define TENSORS tensors[0], tensors[1], tensors[2]
  112. #define PREPROCESS_ARGS layouts[0], tensors[1], layouts[2]
  113. DEF_ALGO_PROXY(ConvolutionForward, 3);
  114. #undef PREPROCESS_ARGS
  115. #undef LAYOUTS
  116. #undef TENSORS
  117. #define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
  118. #define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
  119. #define PREPROCESS_ARGS layouts[0], tensors[1], tensors[2], layouts[3], layouts[4]
  120. DEF_ALGO_PROXY(ConvBias, 5);
  121. #undef PREPROCESS_ARGS
  122. #undef LAYOUTS
  123. #undef TENSORS
  124. #undef DEF_ALGO_PROXY
  125. template <typename Opr, size_t arity = OprTrait<Opr>::arity>
  126. struct OprAlgoProxyDefaultImpl : public AlgoProxy<Opr, arity> {};
  127. template <typename Opr>
  128. struct OprAlgoProxy : public OprAlgoProxyDefaultImpl<Opr> {};
  129. } // namespace test
  130. } // namespace megdnn
  131. // vim: syntax=cpp.doxygen