You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exec_proxy.h 7.3 kB


  1. #pragma once
  2. #include "megdnn/basic_types.h"
  3. #include "test/common/workspace_wrapper.h"
  4. #include <cstddef>
  5. #include <vector>
  6. namespace megdnn {
  7. namespace test {
  /*!
   * \brief Adapter that unpacks a TensorNDArray into the positional arguments
   *      of Opr::exec.
   *
   * Only the declaration lives here; behaviour is supplied by the partial
   * specializations on (Arity, has_workspace).
   *
   * \tparam Opr operator type whose exec() is invoked
   * \tparam Arity number of tensors forwarded to Opr::exec
   * \tparam has_workspace whether Opr::exec additionally receives a workspace
   */
  template <typename Opr, size_t Arity, bool has_workspace> struct ExecProxy;
  10. template <typename Opr>
  11. struct ExecProxy<Opr, 13, true> {
  12. WorkspaceWrapper W;
  13. void exec(Opr* opr, const TensorNDArray& tensors) {
  14. if (!W.valid()) {
  15. W = WorkspaceWrapper(opr->handle(), 0);
  16. }
  17. W.update(opr->get_workspace_in_bytes(
  18. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  19. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  20. tensors[6].layout, tensors[7].layout, tensors[8].layout,
  21. tensors[9].layout, tensors[10].layout, tensors[11].layout,
  22. tensors[12].layout));
  23. opr->exec(
  24. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  25. tensors[6], tensors[7], tensors[8], tensors[9], tensors[10],
  26. tensors[11], tensors[12], W.workspace());
  27. }
  28. };
  29. template <typename Opr>
  30. struct ExecProxy<Opr, 10, true> {
  31. WorkspaceWrapper W;
  32. void exec(Opr* opr, const TensorNDArray& tensors) {
  33. if (!W.valid()) {
  34. W = WorkspaceWrapper(opr->handle(), 0);
  35. }
  36. W.update(opr->get_workspace_in_bytes(
  37. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  38. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  39. tensors[6].layout, tensors[7].layout, tensors[8].layout,
  40. tensors[9].layout));
  41. opr->exec(
  42. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  43. tensors[6], tensors[7], tensors[8], tensors[9], W.workspace());
  44. }
  45. };
  46. template <typename Opr>
  47. struct ExecProxy<Opr, 9, true> {
  48. WorkspaceWrapper W;
  49. void exec(Opr* opr, const TensorNDArray& tensors) {
  50. if (!W.valid()) {
  51. W = WorkspaceWrapper(opr->handle(), 0);
  52. }
  53. W.update(opr->get_workspace_in_bytes(
  54. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  55. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  56. tensors[6].layout, tensors[7].layout, tensors[8].layout));
  57. opr->exec(
  58. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  59. tensors[6], tensors[7], tensors[8], W.workspace());
  60. }
  61. };
  62. template <typename Opr>
  63. struct ExecProxy<Opr, 8, true> {
  64. WorkspaceWrapper W;
  65. void exec(Opr* opr, const TensorNDArray& tensors) {
  66. if (!W.valid()) {
  67. W = WorkspaceWrapper(opr->handle(), 0);
  68. }
  69. W.update(opr->get_workspace_in_bytes(
  70. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  71. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  72. tensors[6].layout, tensors[7].layout));
  73. opr->exec(
  74. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  75. tensors[6], tensors[7], W.workspace());
  76. }
  77. };
  78. template <typename Opr>
  79. struct ExecProxy<Opr, 7, true> {
  80. WorkspaceWrapper W;
  81. void exec(Opr* opr, const TensorNDArray& tensors) {
  82. if (!W.valid()) {
  83. W = WorkspaceWrapper(opr->handle(), 0);
  84. }
  85. W.update(opr->get_workspace_in_bytes(
  86. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  87. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  88. tensors[6].layout));
  89. opr->exec(
  90. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  91. tensors[6], W.workspace());
  92. }
  93. };
  94. template <typename Opr>
  95. struct ExecProxy<Opr, 6, true> {
  96. WorkspaceWrapper W;
  97. void exec(Opr* opr, const TensorNDArray& tensors) {
  98. if (!W.valid()) {
  99. W = WorkspaceWrapper(opr->handle(), 0);
  100. }
  101. W.update(opr->get_workspace_in_bytes(
  102. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  103. tensors[3].layout, tensors[4].layout, tensors[5].layout));
  104. opr->exec(
  105. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  106. W.workspace());
  107. }
  108. };
  109. template <typename Opr>
  110. struct ExecProxy<Opr, 5, true> {
  111. WorkspaceWrapper W;
  112. void exec(Opr* opr, const TensorNDArray& tensors) {
  113. if (!W.valid()) {
  114. W = WorkspaceWrapper(opr->handle(), 0);
  115. }
  116. W.update(opr->get_workspace_in_bytes(
  117. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  118. tensors[3].layout, tensors[4].layout));
  119. opr->exec(
  120. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  121. W.workspace());
  122. }
  123. };
  124. template <typename Opr>
  125. struct ExecProxy<Opr, 4, true> {
  126. WorkspaceWrapper W;
  127. void exec(Opr* opr, const TensorNDArray& tensors) {
  128. if (!W.valid()) {
  129. W = WorkspaceWrapper(opr->handle(), 0);
  130. }
  131. W.update(opr->get_workspace_in_bytes(
  132. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  133. tensors[3].layout));
  134. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], W.workspace());
  135. }
  136. };
  137. template <typename Opr>
  138. struct ExecProxy<Opr, 3, true> {
  139. WorkspaceWrapper W;
  140. void exec(Opr* opr, const TensorNDArray& tensors) {
  141. if (!W.valid()) {
  142. W = WorkspaceWrapper(opr->handle(), 0);
  143. }
  144. W.update(opr->get_workspace_in_bytes(
  145. tensors[0].layout, tensors[1].layout, tensors[2].layout));
  146. opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
  147. }
  148. };
  149. template <typename Opr>
  150. struct ExecProxy<Opr, 2, true> {
  151. WorkspaceWrapper W;
  152. void exec(Opr* opr, const TensorNDArray& tensors) {
  153. if (!W.valid()) {
  154. W = WorkspaceWrapper(opr->handle(), 0);
  155. }
  156. W.update(opr->get_workspace_in_bytes(tensors[0].layout, tensors[1].layout));
  157. opr->exec(tensors[0], tensors[1], W.workspace());
  158. }
  159. };
  160. template <typename Opr>
  161. struct ExecProxy<Opr, 1, true> {
  162. WorkspaceWrapper W;
  163. void exec(Opr* opr, const TensorNDArray& tensors) {
  164. if (!W.valid()) {
  165. W = WorkspaceWrapper(opr->handle(), 0);
  166. }
  167. W.update(opr->get_workspace_in_bytes(tensors[0].layout));
  168. opr->exec(tensors[0], W.workspace());
  169. }
  170. };
  171. template <typename Opr>
  172. struct ExecProxy<Opr, 5, false> {
  173. void exec(Opr* opr, const TensorNDArray& tensors) {
  174. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
  175. }
  176. };
  177. template <typename Opr>
  178. struct ExecProxy<Opr, 4, false> {
  179. void exec(Opr* opr, const TensorNDArray& tensors) {
  180. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
  181. }
  182. };
  183. template <typename Opr>
  184. struct ExecProxy<Opr, 3, false> {
  185. void exec(Opr* opr, const TensorNDArray& tensors) {
  186. opr->exec(tensors[0], tensors[1], tensors[2]);
  187. }
  188. };
  189. template <typename Opr>
  190. struct ExecProxy<Opr, 2, false> {
  191. void exec(Opr* opr, const TensorNDArray& tensors) {
  192. opr->exec(tensors[0], tensors[1]);
  193. }
  194. };
  195. } // namespace test
  196. } // namespace megdnn
  197. // vim: syntax=cpp.doxygen