You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exec_proxy.h 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. /**
  2. * \file dnn/test/common/exec_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/workspace_wrapper.h"
  14. #include <cstddef>
  15. #include <vector>
  16. namespace megdnn {
  17. namespace test {
  18. template <typename Opr, size_t Arity, bool has_workspace>
  19. struct ExecProxy;
  20. template <typename Opr>
  21. struct ExecProxy<Opr, 13, true> {
  22. WorkspaceWrapper W;
  23. void exec(Opr* opr, const TensorNDArray& tensors) {
  24. if (!W.valid()) {
  25. W = WorkspaceWrapper(opr->handle(), 0);
  26. }
  27. W.update(opr->get_workspace_in_bytes(
  28. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  29. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  30. tensors[6].layout, tensors[7].layout, tensors[8].layout,
  31. tensors[9].layout, tensors[10].layout, tensors[11].layout,
  32. tensors[12].layout));
  33. opr->exec(
  34. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  35. tensors[6], tensors[7], tensors[8], tensors[9], tensors[10],
  36. tensors[11], tensors[12], W.workspace());
  37. }
  38. };
  39. template <typename Opr>
  40. struct ExecProxy<Opr, 10, true> {
  41. WorkspaceWrapper W;
  42. void exec(Opr* opr, const TensorNDArray& tensors) {
  43. if (!W.valid()) {
  44. W = WorkspaceWrapper(opr->handle(), 0);
  45. }
  46. W.update(opr->get_workspace_in_bytes(
  47. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  48. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  49. tensors[6].layout, tensors[7].layout, tensors[8].layout,
  50. tensors[9].layout));
  51. opr->exec(
  52. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  53. tensors[6], tensors[7], tensors[8], tensors[9], W.workspace());
  54. }
  55. };
  56. template <typename Opr>
  57. struct ExecProxy<Opr, 9, true> {
  58. WorkspaceWrapper W;
  59. void exec(Opr* opr, const TensorNDArray& tensors) {
  60. if (!W.valid()) {
  61. W = WorkspaceWrapper(opr->handle(), 0);
  62. }
  63. W.update(opr->get_workspace_in_bytes(
  64. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  65. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  66. tensors[6].layout, tensors[7].layout, tensors[8].layout));
  67. opr->exec(
  68. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  69. tensors[6], tensors[7], tensors[8], W.workspace());
  70. }
  71. };
  72. template <typename Opr>
  73. struct ExecProxy<Opr, 8, true> {
  74. WorkspaceWrapper W;
  75. void exec(Opr* opr, const TensorNDArray& tensors) {
  76. if (!W.valid()) {
  77. W = WorkspaceWrapper(opr->handle(), 0);
  78. }
  79. W.update(opr->get_workspace_in_bytes(
  80. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  81. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  82. tensors[6].layout, tensors[7].layout));
  83. opr->exec(
  84. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  85. tensors[6], tensors[7], W.workspace());
  86. }
  87. };
  88. template <typename Opr>
  89. struct ExecProxy<Opr, 7, true> {
  90. WorkspaceWrapper W;
  91. void exec(Opr* opr, const TensorNDArray& tensors) {
  92. if (!W.valid()) {
  93. W = WorkspaceWrapper(opr->handle(), 0);
  94. }
  95. W.update(opr->get_workspace_in_bytes(
  96. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  97. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  98. tensors[6].layout));
  99. opr->exec(
  100. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  101. tensors[6], W.workspace());
  102. }
  103. };
  104. template <typename Opr>
  105. struct ExecProxy<Opr, 6, true> {
  106. WorkspaceWrapper W;
  107. void exec(Opr* opr, const TensorNDArray& tensors) {
  108. if (!W.valid()) {
  109. W = WorkspaceWrapper(opr->handle(), 0);
  110. }
  111. W.update(opr->get_workspace_in_bytes(
  112. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  113. tensors[3].layout, tensors[4].layout, tensors[5].layout));
  114. opr->exec(
  115. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
  116. W.workspace());
  117. }
  118. };
  119. template <typename Opr>
  120. struct ExecProxy<Opr, 5, true> {
  121. WorkspaceWrapper W;
  122. void exec(Opr* opr, const TensorNDArray& tensors) {
  123. if (!W.valid()) {
  124. W = WorkspaceWrapper(opr->handle(), 0);
  125. }
  126. W.update(opr->get_workspace_in_bytes(
  127. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  128. tensors[3].layout, tensors[4].layout));
  129. opr->exec(
  130. tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  131. W.workspace());
  132. }
  133. };
  134. template <typename Opr>
  135. struct ExecProxy<Opr, 4, true> {
  136. WorkspaceWrapper W;
  137. void exec(Opr* opr, const TensorNDArray& tensors) {
  138. if (!W.valid()) {
  139. W = WorkspaceWrapper(opr->handle(), 0);
  140. }
  141. W.update(opr->get_workspace_in_bytes(
  142. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  143. tensors[3].layout));
  144. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], W.workspace());
  145. }
  146. };
  147. template <typename Opr>
  148. struct ExecProxy<Opr, 3, true> {
  149. WorkspaceWrapper W;
  150. void exec(Opr* opr, const TensorNDArray& tensors) {
  151. if (!W.valid()) {
  152. W = WorkspaceWrapper(opr->handle(), 0);
  153. }
  154. W.update(opr->get_workspace_in_bytes(
  155. tensors[0].layout, tensors[1].layout, tensors[2].layout));
  156. opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
  157. }
  158. };
  159. template <typename Opr>
  160. struct ExecProxy<Opr, 2, true> {
  161. WorkspaceWrapper W;
  162. void exec(Opr* opr, const TensorNDArray& tensors) {
  163. if (!W.valid()) {
  164. W = WorkspaceWrapper(opr->handle(), 0);
  165. }
  166. W.update(opr->get_workspace_in_bytes(tensors[0].layout, tensors[1].layout));
  167. opr->exec(tensors[0], tensors[1], W.workspace());
  168. }
  169. };
  170. template <typename Opr>
  171. struct ExecProxy<Opr, 1, true> {
  172. WorkspaceWrapper W;
  173. void exec(Opr* opr, const TensorNDArray& tensors) {
  174. if (!W.valid()) {
  175. W = WorkspaceWrapper(opr->handle(), 0);
  176. }
  177. W.update(opr->get_workspace_in_bytes(tensors[0].layout));
  178. opr->exec(tensors[0], W.workspace());
  179. }
  180. };
  181. template <typename Opr>
  182. struct ExecProxy<Opr, 5, false> {
  183. void exec(Opr* opr, const TensorNDArray& tensors) {
  184. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
  185. }
  186. };
  187. template <typename Opr>
  188. struct ExecProxy<Opr, 4, false> {
  189. void exec(Opr* opr, const TensorNDArray& tensors) {
  190. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
  191. }
  192. };
  193. template <typename Opr>
  194. struct ExecProxy<Opr, 3, false> {
  195. void exec(Opr* opr, const TensorNDArray& tensors) {
  196. opr->exec(tensors[0], tensors[1], tensors[2]);
  197. }
  198. };
  199. template <typename Opr>
  200. struct ExecProxy<Opr, 2, false> {
  201. void exec(Opr* opr, const TensorNDArray& tensors) {
  202. opr->exec(tensors[0], tensors[1]);
  203. }
  204. };
  205. } // namespace test
  206. } // namespace megdnn
  207. // vim: syntax=cpp.doxygen