
dnn/test/common/exec_proxy.h

/**
 * \file dnn/test/common/exec_proxy.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/basic_types.h"
#include "test/common/workspace_wrapper.h"

#include <cstddef>
#include <vector>

namespace megdnn {
namespace test {

//! ExecProxy forwards `Arity` tensors from a TensorNDArray to `opr->exec()`.
//! The `has_workspace == true` specializations first query
//! `get_workspace_in_bytes()` with the tensor layouts, grow a cached
//! WorkspaceWrapper accordingly, and pass the workspace as the trailing
//! argument of exec().
template <typename Opr, size_t Arity, bool has_workspace>
struct ExecProxy;

template <typename Opr>
struct ExecProxy<Opr, 8, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout,
                tensors[6].layout, tensors[7].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  tensors[5], tensors[6], tensors[7], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 5, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 4, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3],
                  W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 3, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 2, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(tensors[0].layout,
                                             tensors[1].layout));
        opr->exec(tensors[0], tensors[1], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 1, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(tensors[0].layout));
        opr->exec(tensors[0], W.workspace());
    }
};

template <typename Opr>
struct ExecProxy<Opr, 5, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 4, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 3, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1], tensors[2]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 2, false> {
    void exec(Opr* opr, const TensorNDArray& tensors) {
        opr->exec(tensors[0], tensors[1]);
    }
};

template <typename Opr>
struct ExecProxy<Opr, 7, true> {
    WorkspaceWrapper W;
    void exec(Opr* opr, const TensorNDArray& tensors) {
        if (!W.valid()) {
            W = WorkspaceWrapper(opr->handle(), 0);
        }
        W.update(opr->get_workspace_in_bytes(
                tensors[0].layout, tensors[1].layout, tensors[2].layout,
                tensors[3].layout, tensors[4].layout, tensors[5].layout,
                tensors[6].layout));
        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
                  tensors[5], tensors[6], W.workspace());
    }
};

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen
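
As a rough usage sketch (not part of the header above), a test typically instantiates ExecProxy with the operator type and arity, then hands it a TensorNDArray. The helper below and its setup assumptions are illustrative only: `run_with_proxy` is a hypothetical name, and it assumes a valid `megdnn::Handle*` plus tensors whose layouts the operator accepts.

#include "megdnn/handle.h"
#include "test/common/exec_proxy.h"

// Hypothetical helper: drive an arity-2 operator that needs a workspace.
// Assumes `tensors` holds one input and one output TensorND with layouts
// that the operator accepts.
template <typename Opr>
void run_with_proxy(megdnn::Handle* handle,
                    const megdnn::TensorNDArray& tensors) {
    auto opr = handle->create_operator<Opr>();
    megdnn::test::ExecProxy<Opr, 2, true> proxy;
    // The proxy sizes its cached workspace via get_workspace_in_bytes(),
    // then forwards tensors[0], tensors[1] and the workspace to opr->exec().
    proxy.exec(opr.get(), tensors);
}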

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose. To run GPU programs, make sure the machine has a GPU and the driver installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.