You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

opr_algo_proxy.h 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /**
  2. * \file dnn/test/common/opr_algo_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/basic_types.h"
  14. #include "src/common/opr_trait.h"
  15. #include "test/common/utils.h"
  16. namespace megdnn {
  17. namespace test {
  18. template <typename Opr, size_t Arity>
  19. struct AlgoProxy;
  20. #define DEF_ALGO_PROXY(arity) \
  21. template <typename Opr> \
  22. struct AlgoProxy<Opr, arity> { \
  23. static std::vector<typename Opr::AlgorithmInfo> get_all_algorithms_info_safe( \
  24. Opr* opr, const TensorLayoutArray& layouts) { \
  25. megdnn_assert(layouts.size() == arity); \
  26. return opr->get_all_algorithms_info_safe(LAYOUTS); \
  27. } \
  28. static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
  29. Opr* opr, const TensorLayoutArray& layouts) { \
  30. megdnn_assert(layouts.size() == arity); \
  31. return opr->get_algorithm_info_heuristic(LAYOUTS); \
  32. } \
  33. static size_t get_workspace_in_bytes( \
  34. Opr* opr, const TensorLayoutArray& layouts) { \
  35. megdnn_assert(layouts.size() == arity); \
  36. return opr->get_workspace_in_bytes(LAYOUTS); \
  37. } \
  38. static void exec( \
  39. Opr* opr, const TensorNDArray& tensors, Workspace workspace) { \
  40. megdnn_assert(tensors.size() == arity); \
  41. return opr->exec(TENSORS, workspace); \
  42. } \
  43. }
  44. #define LAYOUTS layouts[0], layouts[1]
  45. #define TENSORS tensors[0], tensors[1]
  46. DEF_ALGO_PROXY(2);
  47. #undef LAYOUTS
  48. #undef TENSORS
  49. #define LAYOUTS layouts[0], layouts[1], layouts[2]
  50. #define TENSORS tensors[0], tensors[1], tensors[2]
  51. DEF_ALGO_PROXY(3);
  52. #undef LAYOUTS
  53. #undef TENSORS
  54. #define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
  55. #define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
  56. DEF_ALGO_PROXY(5);
  57. #undef LAYOUTS
  58. #undef TENSORS
  59. #define LAYOUTS \
  60. layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5], \
  61. layouts[6], layouts[7]
  62. #define TENSORS \
  63. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5], \
  64. tensors[6], tensors[7]
  65. DEF_ALGO_PROXY(8);
  66. #undef LAYOUTS
  67. #undef TENSORS
  68. #undef DEF_ALGO_PROXY
  69. #define DEF_ALGO_PROXY(Opr, arity) \
  70. template <> \
  71. struct AlgoProxy<Opr, arity> { \
  72. static std::vector<typename Opr::AlgorithmInfo> get_all_algorithms_info_safe( \
  73. Opr* opr, const TensorLayoutArray& layouts) { \
  74. megdnn_assert(layouts.size() == arity); \
  75. return opr->get_all_algorithms_info_safe(LAYOUTS); \
  76. } \
  77. static typename Opr::AlgorithmInfo get_algorithm_info_heuristic( \
  78. Opr* opr, const TensorLayoutArray& layouts) { \
  79. megdnn_assert(layouts.size() == arity); \
  80. return opr->get_algorithm_info_heuristic(LAYOUTS); \
  81. } \
  82. static size_t get_workspace_in_bytes( \
  83. Opr* opr, const TensorLayoutArray& layouts, \
  84. const typename Opr::PreprocessedFilter* preprocessed_filter = \
  85. nullptr) { \
  86. megdnn_assert(layouts.size() == arity); \
  87. return opr->get_workspace_in_bytes(LAYOUTS, preprocessed_filter); \
  88. } \
  89. static void exec( \
  90. Opr* opr, const TensorNDArray& tensors, \
  91. const typename Opr::PreprocessedFilter* preprocessed_filter, \
  92. Workspace workspace) { \
  93. megdnn_assert(tensors.size() == arity); \
  94. return opr->exec(TENSORS, preprocessed_filter, workspace); \
  95. } \
  96. static void exec( \
  97. Opr* opr, const TensorNDArray& tensors, Workspace workspace) { \
  98. megdnn_assert(tensors.size() == arity); \
  99. return opr->exec(TENSORS, nullptr, workspace); \
  100. } \
  101. static size_t get_preprocess_workspace_in_bytes( \
  102. Opr* opr, const TensorLayoutArray& layouts) { \
  103. megdnn_assert(layouts.size() == arity); \
  104. return opr->get_preprocess_workspace_in_bytes(LAYOUTS); \
  105. } \
  106. static SmallVector<TensorLayout> deduce_preprocessed_filter_layout( \
  107. Opr* opr, const TensorLayoutArray& layouts) { \
  108. megdnn_assert(layouts.size() == arity); \
  109. return opr->deduce_preprocessed_filter_layout(LAYOUTS); \
  110. } \
  111. static void exec_preprocess( \
  112. Opr* opr, const TensorNDArray& tensors, \
  113. const TensorLayoutArray& layouts, \
  114. Opr::PreprocessedFilter* preprocessed_filter, \
  115. _megdnn_workspace workspace) { \
  116. megdnn_assert(layouts.size() == arity && tensors.size() == arity); \
  117. return opr->exec_preprocess( \
  118. PREPROCESS_ARGS, preprocessed_filter, workspace); \
  119. } \
  120. };
  121. #define LAYOUTS layouts[0], layouts[1], layouts[2]
  122. #define TENSORS tensors[0], tensors[1], tensors[2]
  123. #define PREPROCESS_ARGS layouts[0], tensors[1], layouts[2]
  124. DEF_ALGO_PROXY(ConvolutionForward, 3);
  125. #undef PREPROCESS_ARGS
  126. #undef LAYOUTS
  127. #undef TENSORS
  128. #define LAYOUTS layouts[0], layouts[1], layouts[2], layouts[3], layouts[4]
  129. #define TENSORS tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]
  130. #define PREPROCESS_ARGS layouts[0], tensors[1], tensors[2], layouts[3], layouts[4]
  131. DEF_ALGO_PROXY(ConvBias, 5);
  132. #undef PREPROCESS_ARGS
  133. #undef LAYOUTS
  134. #undef TENSORS
  135. #undef DEF_ALGO_PROXY
  136. template <typename Opr, size_t arity = OprTrait<Opr>::arity>
  137. struct OprAlgoProxyDefaultImpl : public AlgoProxy<Opr, arity> {};
  138. template <typename Opr>
  139. struct OprAlgoProxy : public OprAlgoProxyDefaultImpl<Opr> {};
  140. } // namespace test
  141. } // namespace megdnn
  142. // vim: syntax=cpp.doxygen