
exec_proxy.h 6.1 kB

/**
 * \file dnn/test/common/exec_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/basic_types.h"
#include "test/common/workspace_wrapper.h"

#include <cstddef>
#include <vector>

namespace megdnn {
namespace test {

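//! ExecProxy dispatches a fixed-arity TensorNDArray to Opr::exec().  Arity is
//! the total number of tensors (inputs plus outputs); when has_workspace is
//! true, the proxy queries get_workspace_in_bytes() with the tensor layouts,
//! keeps a lazily created WorkspaceWrapper up to date, and passes the
//! workspace as the trailing argument of exec().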
template <typename Opr, size_t Arity, bool has_workspace>
struct ExecProxy;

template <typename Opr>
struct ExecProxy<Opr, 9, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout,
                tensors[6].layout, tensors[7].layout, tensors[8].layout));
        opr->exec(
                tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
                tensors[6], tensors[7], tensors[8], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 8, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout,
                tensors[6].layout, tensors[7].layout));
        opr->exec(
                tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
                tensors[6], tensors[7], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 7, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout,
                tensors[6].layout));
        opr->exec(
                tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
                tensors[6], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 6, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout));
        opr->exec(
                tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], tensors[5],
                W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 5, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout));
        opr->exec(
                tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 4, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 3, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 2, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(tensors[0].layout, tensors[1].layout));
        opr->exec(tensors[0], tensors[1], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 1, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(tensors[0].layout));
        opr->exec(tensors[0], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 5, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 4, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 3, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 2, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1]);
    }
};

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen
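For orientation, here is a minimal, hypothetical sketch of how test code might drive an operator through this proxy. The run_with_proxy helper and the choice of arity are assumptions made for illustration only; they are not part of this header or of the MegDNN test framework.

#include "test/common/exec_proxy.h"

namespace megdnn {
namespace test {

// Hypothetical helper: run an operator that takes three tensors in total
// (e.g. two inputs and one output) and needs a workspace.
template <typename Opr>
void run_with_proxy(Opr* opr, const TensorNDArray& tensors) {
    // The proxy queries get_workspace_in_bytes() with the tensor layouts,
    // updates its workspace, and forwards tensors plus workspace to exec().
    ExecProxy<Opr, 3, true> proxy;
    proxy.exec(opr, tensors);
}

}  // namespace test
}  // namespace megdnn

Specializing on Arity and has_workspace lets the same test harness dispatch operators of any signature without per-operator glue code; keeping a proxy instance alive across calls also reuses its WorkspaceWrapper instead of constructing a new one for every invocation.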

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.