You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

exec_proxy.h 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. /**
  2. * \file dnn/test/common/exec_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/workspace_wrapper.h"
  14. #include <cstddef>
  15. #include <vector>
  16. namespace megdnn {
  17. namespace test {
//! Invokes \c Opr::exec on a flat tensor array.
//! \tparam Arity total number of tensors (inputs + outputs) the operator takes.
//! \tparam has_workspace whether the operator's exec/get_workspace_in_bytes
//!     signatures carry a workspace argument. Only specializations below are
//!     defined; there is no generic implementation.
template <typename Opr, size_t Arity, bool has_workspace>
struct ExecProxy;
  20. template <typename Opr>
  21. struct ExecProxy<Opr, 9, true> {
  22. WorkspaceWrapper W;
  23. void exec(Opr* opr, const TensorNDArray& tensors) {
  24. if (!W.valid()) {
  25. W = WorkspaceWrapper(opr->handle(), 0);
  26. }
  27. W.update(opr->get_workspace_in_bytes(
  28. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  29. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  30. tensors[6].layout, tensors[7].layout, tensors[8].layout));
  31. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  32. tensors[5], tensors[6], tensors[7], tensors[8],
  33. W.workspace());
  34. }
  35. };
  36. template <typename Opr>
  37. struct ExecProxy<Opr, 8, true> {
  38. WorkspaceWrapper W;
  39. void exec(Opr* opr, const TensorNDArray& tensors) {
  40. if (!W.valid()) {
  41. W = WorkspaceWrapper(opr->handle(), 0);
  42. }
  43. W.update(opr->get_workspace_in_bytes(
  44. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  45. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  46. tensors[6].layout, tensors[7].layout));
  47. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  48. tensors[5], tensors[6], tensors[7], W.workspace());
  49. }
  50. };
  51. template <typename Opr>
  52. struct ExecProxy<Opr, 7, true> {
  53. WorkspaceWrapper W;
  54. void exec(Opr* opr, const TensorNDArray& tensors) {
  55. if (!W.valid()) {
  56. W = WorkspaceWrapper(opr->handle(), 0);
  57. }
  58. W.update(opr->get_workspace_in_bytes(
  59. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  60. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  61. tensors[6].layout));
  62. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  63. tensors[5], tensors[6], W.workspace());
  64. }
  65. };
  66. template <typename Opr>
  67. struct ExecProxy<Opr, 6, true> {
  68. WorkspaceWrapper W;
  69. void exec(Opr* opr, const TensorNDArray& tensors) {
  70. if (!W.valid()) {
  71. W = WorkspaceWrapper(opr->handle(), 0);
  72. }
  73. W.update(opr->get_workspace_in_bytes(
  74. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  75. tensors[3].layout, tensors[4].layout, tensors[5].layout));
  76. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  77. tensors[5], W.workspace());
  78. }
  79. };
  80. template <typename Opr>
  81. struct ExecProxy<Opr, 5, true> {
  82. WorkspaceWrapper W;
  83. void exec(Opr* opr, const TensorNDArray& tensors) {
  84. if (!W.valid()) {
  85. W = WorkspaceWrapper(opr->handle(), 0);
  86. }
  87. W.update(opr->get_workspace_in_bytes(
  88. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  89. tensors[3].layout, tensors[4].layout));
  90. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  91. W.workspace());
  92. }
  93. };
  94. template <typename Opr>
  95. struct ExecProxy<Opr, 4, true> {
  96. WorkspaceWrapper W;
  97. void exec(Opr* opr, const TensorNDArray& tensors) {
  98. if (!W.valid()) {
  99. W = WorkspaceWrapper(opr->handle(), 0);
  100. }
  101. W.update(opr->get_workspace_in_bytes(
  102. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  103. tensors[3].layout));
  104. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  105. W.workspace());
  106. }
  107. };
  108. template <typename Opr>
  109. struct ExecProxy<Opr, 3, true> {
  110. WorkspaceWrapper W;
  111. void exec(Opr* opr, const TensorNDArray& tensors) {
  112. if (!W.valid()) {
  113. W = WorkspaceWrapper(opr->handle(), 0);
  114. }
  115. W.update(opr->get_workspace_in_bytes(
  116. tensors[0].layout, tensors[1].layout, tensors[2].layout));
  117. opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
  118. }
  119. };
  120. template <typename Opr>
  121. struct ExecProxy<Opr, 2, true> {
  122. WorkspaceWrapper W;
  123. void exec(Opr* opr, const TensorNDArray& tensors) {
  124. if (!W.valid()) {
  125. W = WorkspaceWrapper(opr->handle(), 0);
  126. }
  127. W.update(opr->get_workspace_in_bytes(tensors[0].layout,
  128. tensors[1].layout));
  129. opr->exec(tensors[0], tensors[1], W.workspace());
  130. }
  131. };
  132. template <typename Opr>
  133. struct ExecProxy<Opr, 1, true> {
  134. WorkspaceWrapper W;
  135. void exec(Opr* opr, const TensorNDArray& tensors) {
  136. if (!W.valid()) {
  137. W = WorkspaceWrapper(opr->handle(), 0);
  138. }
  139. W.update(opr->get_workspace_in_bytes(tensors[0].layout));
  140. opr->exec(tensors[0], W.workspace());
  141. }
  142. };
  143. template <typename Opr>
  144. struct ExecProxy<Opr, 5, false> {
  145. void exec(Opr* opr, const TensorNDArray& tensors) {
  146. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
  147. }
  148. };
  149. template <typename Opr>
  150. struct ExecProxy<Opr, 4, false> {
  151. void exec(Opr* opr, const TensorNDArray& tensors) {
  152. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
  153. }
  154. };
  155. template <typename Opr>
  156. struct ExecProxy<Opr, 3, false> {
  157. void exec(Opr* opr, const TensorNDArray& tensors) {
  158. opr->exec(tensors[0], tensors[1], tensors[2]);
  159. }
  160. };
  161. template <typename Opr>
  162. struct ExecProxy<Opr, 2, false> {
  163. void exec(Opr* opr, const TensorNDArray& tensors) {
  164. opr->exec(tensors[0], tensors[1]);
  165. }
  166. };
  167. } // namespace test
  168. } // namespace megdnn
  169. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台