You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

exec_proxy.h 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
/**
 * \file dnn/test/common/exec_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
  11. #pragma once
  12. #include "megdnn/basic_types.h"
  13. #include "test/common/workspace_wrapper.h"
  14. #include <cstddef>
  15. #include <vector>
  16. namespace megdnn {
  17. namespace test {
//! Arity-dispatched executor for megdnn operators in tests: selected on the
//! operator's tensor count (\p Arity) and on whether it needs a workspace
//! (\p has_workspace). Only the partial specializations below are defined;
//! the primary template is intentionally left undefined.
template <typename Opr, size_t Arity, bool has_workspace>
struct ExecProxy;
  20. template <typename Opr>
  21. struct ExecProxy<Opr, 8, true> {
  22. WorkspaceWrapper W;
  23. void exec(Opr* opr, const TensorNDArray& tensors) {
  24. if (!W.valid()) {
  25. W = WorkspaceWrapper(opr->handle(), 0);
  26. }
  27. W.update(opr->get_workspace_in_bytes(
  28. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  29. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  30. tensors[6].layout, tensors[7].layout));
  31. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  32. tensors[5], tensors[6], tensors[7], W.workspace());
  33. }
  34. };
  35. template <typename Opr>
  36. struct ExecProxy<Opr, 7, true> {
  37. WorkspaceWrapper W;
  38. void exec(Opr* opr, const TensorNDArray& tensors) {
  39. if (!W.valid()) {
  40. W = WorkspaceWrapper(opr->handle(), 0);
  41. }
  42. W.update(opr->get_workspace_in_bytes(
  43. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  44. tensors[3].layout, tensors[4].layout, tensors[5].layout,
  45. tensors[6].layout));
  46. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  47. tensors[5], tensors[6], W.workspace());
  48. }
  49. };
  50. template <typename Opr>
  51. struct ExecProxy<Opr, 6, true> {
  52. WorkspaceWrapper W;
  53. void exec(Opr* opr, const TensorNDArray& tensors) {
  54. if (!W.valid()) {
  55. W = WorkspaceWrapper(opr->handle(), 0);
  56. }
  57. W.update(opr->get_workspace_in_bytes(
  58. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  59. tensors[3].layout, tensors[4].layout, tensors[5].layout));
  60. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  61. tensors[5], W.workspace());
  62. }
  63. };
  64. template <typename Opr>
  65. struct ExecProxy<Opr, 5, true> {
  66. WorkspaceWrapper W;
  67. void exec(Opr* opr, const TensorNDArray& tensors) {
  68. if (!W.valid()) {
  69. W = WorkspaceWrapper(opr->handle(), 0);
  70. }
  71. W.update(opr->get_workspace_in_bytes(
  72. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  73. tensors[3].layout, tensors[4].layout));
  74. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
  75. W.workspace());
  76. }
  77. };
  78. template <typename Opr>
  79. struct ExecProxy<Opr, 4, true> {
  80. WorkspaceWrapper W;
  81. void exec(Opr* opr, const TensorNDArray& tensors) {
  82. if (!W.valid()) {
  83. W = WorkspaceWrapper(opr->handle(), 0);
  84. }
  85. W.update(opr->get_workspace_in_bytes(
  86. tensors[0].layout, tensors[1].layout, tensors[2].layout,
  87. tensors[3].layout));
  88. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
  89. W.workspace());
  90. }
  91. };
  92. template <typename Opr>
  93. struct ExecProxy<Opr, 3, true> {
  94. WorkspaceWrapper W;
  95. void exec(Opr* opr, const TensorNDArray& tensors) {
  96. if (!W.valid()) {
  97. W = WorkspaceWrapper(opr->handle(), 0);
  98. }
  99. W.update(opr->get_workspace_in_bytes(
  100. tensors[0].layout, tensors[1].layout, tensors[2].layout));
  101. opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
  102. }
  103. };
  104. template <typename Opr>
  105. struct ExecProxy<Opr, 2, true> {
  106. WorkspaceWrapper W;
  107. void exec(Opr* opr, const TensorNDArray& tensors) {
  108. if (!W.valid()) {
  109. W = WorkspaceWrapper(opr->handle(), 0);
  110. }
  111. W.update(opr->get_workspace_in_bytes(tensors[0].layout,
  112. tensors[1].layout));
  113. opr->exec(tensors[0], tensors[1], W.workspace());
  114. }
  115. };
  116. template <typename Opr>
  117. struct ExecProxy<Opr, 1, true> {
  118. WorkspaceWrapper W;
  119. void exec(Opr* opr, const TensorNDArray& tensors) {
  120. if (!W.valid()) {
  121. W = WorkspaceWrapper(opr->handle(), 0);
  122. }
  123. W.update(opr->get_workspace_in_bytes(tensors[0].layout));
  124. opr->exec(tensors[0], W.workspace());
  125. }
  126. };
  127. template <typename Opr>
  128. struct ExecProxy<Opr, 5, false> {
  129. void exec(Opr* opr, const TensorNDArray& tensors) {
  130. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
  131. }
  132. };
  133. template <typename Opr>
  134. struct ExecProxy<Opr, 4, false> {
  135. void exec(Opr* opr, const TensorNDArray& tensors) {
  136. opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
  137. }
  138. };
  139. template <typename Opr>
  140. struct ExecProxy<Opr, 3, false> {
  141. void exec(Opr* opr, const TensorNDArray& tensors) {
  142. opr->exec(tensors[0], tensors[1], tensors[2]);
  143. }
  144. };
  145. template <typename Opr>
  146. struct ExecProxy<Opr, 2, false> {
  147. void exec(Opr* opr, const TensorNDArray& tensors) {
  148. opr->exec(tensors[0], tensors[1]);
  149. }
  150. };
  151. } // namespace test
  152. } // namespace megdnn
  153. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台