You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

exec_proxy.h 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /**
  2. * \file dnn/test/common/exec_proxy.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/workspace_wrapper.h"
  14. #include <cstddef>
  15. #include <vector>
  16. namespace megdnn {
  17. namespace test {
  18. template <typename Opr, size_t Arity, bool has_workspace>
  19. struct ExecProxy;
  20. template <typename Opr>
  21. struct ExecProxy<Opr, 8, true> {
  22. WorkspaceWrapper W;
  23. void exec(Opr* opr, const TensorNDArray& tensors) {
  24. if (!W.valid()) {
  25. W = WorkspaceWrapper(opr->handle(), 0);
  26. }
  27. W.update(opr->get_workspace_in_bytes(
  28. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  29. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  30. tensors[6].layout, tensors[7].layout));
  31. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  32. tensors[5], tensors[6], tensors[7], W.workspace());
  33. }
  34. };
  35. template <typename Opr>
  36. struct ExecProxy<Opr, 6, true> {
  37. WorkspaceWrapper W;
  38. void exec(Opr* opr, const TensorNDArray& tensors) {
  39. if (!W.valid()) {
  40. W = WorkspaceWrapper(opr->handle(), 0);
  41. }
  42. W.update(opr->get_workspace_in_bytes(
  43. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  44. tensors[3].layout, tensors[4].layout, tensors[5].layout));
  45. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  46. tensors[5], W.workspace());
  47. }
  48. };
  49. template <typename Opr>
  50. struct ExecProxy<Opr, 5, true> {
  51. WorkspaceWrapper W;
  52. void exec(Opr* opr, const TensorNDArray& tensors) {
  53. if (!W.valid()) {
  54. W = WorkspaceWrapper(opr->handle(), 0);
  55. }
  56. W.update(opr->get_workspace_in_bytes(
  57. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  58. tensors[3].layout, tensors[4].layout));
  59. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  60. W.workspace());
  61. }
  62. };
  63. template <typename Opr>
  64. struct ExecProxy<Opr, 4, true> {
  65. WorkspaceWrapper W;
  66. void exec(Opr* opr, const TensorNDArray& tensors) {
  67. if (!W.valid()) {
  68. W = WorkspaceWrapper(opr->handle(), 0);
  69. }
  70. W.update(opr->get_workspace_in_bytes(
  71. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  72. tensors[3].layout));
  73. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  74. W.workspace());
  75. }
  76. };
  77. template <typename Opr>
  78. struct ExecProxy<Opr, 3, true> {
  79. WorkspaceWrapper W;
  80. void exec(Opr* opr, const TensorNDArray& tensors) {
  81. if (!W.valid()) {
  82. W = WorkspaceWrapper(opr->handle(), 0);
  83. }
  84. W.update(opr->get_workspace_in_bytes(
  85. tensors[0].layout, tensors[1].layout, tensors[2].layout));
  86. opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
  87. }
  88. };
  89. template <typename Opr>
  90. struct ExecProxy<Opr, 2, true> {
  91. WorkspaceWrapper W;
  92. void exec(Opr* opr, const TensorNDArray& tensors) {
  93. if (!W.valid()) {
  94. W = WorkspaceWrapper(opr->handle(), 0);
  95. }
  96. W.update(opr->get_workspace_in_bytes(tensors[0].layout,
  97. tensors[1].layout));
  98. opr->exec(tensors[0], tensors[1], W.workspace());
  99. }
  100. };
  101. template <typename Opr>
  102. struct ExecProxy<Opr, 1, true> {
  103. WorkspaceWrapper W;
  104. void exec(Opr* opr, const TensorNDArray& tensors) {
  105. if (!W.valid()) {
  106. W = WorkspaceWrapper(opr->handle(), 0);
  107. }
  108. W.update(opr->get_workspace_in_bytes(tensors[0].layout));
  109. opr->exec(tensors[0], W.workspace());
  110. }
  111. };
  112. template <typename Opr>
  113. struct ExecProxy<Opr, 5, false> {
  114. void exec(Opr* opr, const TensorNDArray& tensors) {
  115. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
  116. }
  117. };
  118. template <typename Opr>
  119. struct ExecProxy<Opr, 4, false> {
  120. void exec(Opr* opr, const TensorNDArray& tensors) {
  121. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
  122. }
  123. };
  124. template <typename Opr>
  125. struct ExecProxy<Opr, 3, false> {
  126. void exec(Opr* opr, const TensorNDArray& tensors) {
  127. opr->exec(tensors[0], tensors[1], tensors[2]);
  128. }
  129. };
  130. template <typename Opr>
  131. struct ExecProxy<Opr, 2, false> {
  132. void exec(Opr* opr, const TensorNDArray& tensors) {
  133. opr->exec(tensors[0], tensors[1]);
  134. }
  135. };
  136. template <typename Opr>
  137. struct ExecProxy<Opr, 7, true> {
  138. WorkspaceWrapper W;
  139. void exec(Opr* opr, const TensorNDArray& tensors) {
  140. if (!W.valid()) {
  141. W = WorkspaceWrapper(opr->handle(), 0);
  142. }
  143. W.update(opr->get_workspace_in_bytes(
  144. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  145. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  146. tensors[6].layout));
  147. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  148. tensors[5], tensors[6], W.workspace());
  149. }
  150. };
  151. } // namespace test
  152. } // namespace megdnn
  153. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台