You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algo_chooser.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. /**
  2. * \file dnn/src/common/algo_chooser.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  #include <algorithm>
  #include <cstddef>
  #include <limits>
  #include <utility>
  #include <vector>
  #include "utils.h"
  17. namespace megdnn {
  /*!
   * \brief resolve the algorithm an operator should run.
   *
   * If the operator's execution policy carries a valid algorithm, that
   * user-configured choice wins. Otherwise the operator's heuristic is asked,
   * with an unlimited workspace (std::numeric_limits<size_t>::max()) and the
   * reproducibility flag set to false.
   *
   * \param opr  operator whose execution_policy() is consulted
   * \param args forwarded unchanged to get_algorithm_info_heuristic()
   * \return the algorithm that opr->get_algo_from_desc() maps the chosen
   *         descriptor to
   */
  template <class Opr, typename... Args>
  typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
      auto configured = opr->execution_policy().algo;
      typename Opr::AlgorithmInfo chosen;
      if (!configured.valid()) {
          // No explicit choice: fall back to the heuristic without any
          // workspace or reproducibility constraint.
          chosen = opr->get_algorithm_info_heuristic(
                  std::forward<Args>(args)...,
                  std::numeric_limits<size_t>::max(), false);
      } else {
          chosen = configured;
      }
      return opr->get_algo_from_desc(chosen.desc);
  }
  /*!
   * \brief resolve the algorithm an operator should run, constructing a
   * user-configured algorithm on demand. Used in OpenCL, whose algorithms
   * need to be constructed each time.
   *
   * If the execution policy carries a valid algorithm, it is built from its
   * descriptor through opr->algo_pack().construct_and_get_algo(). Otherwise
   * the heuristic is queried with an unlimited workspace and the
   * reproducibility flag set to false. Its result is looked up with
   * opr->get_algo_from_desc().
   *
   * \param opr  operator whose execution_policy() is consulted
   * \param args forwarded unchanged to get_algorithm_info_heuristic()
   * \return the selected (possibly freshly constructed) algorithm
   */
  template <class Opr, typename... Args>
  typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) {
      auto configured = opr->execution_policy().algo;
      if (!configured.valid()) {
          typename Opr::AlgorithmInfo heuristic;
          heuristic = opr->get_algorithm_info_heuristic(
                  std::forward<Args>(args)...,
                  std::numeric_limits<size_t>::max(), false);
          return opr->get_algo_from_desc(heuristic.desc);
      }
      // Explicitly configured: build the algorithm from its descriptor.
      return opr->algo_pack().construct_and_get_algo(configured.desc);
  }
  /*!
   * \brief get all algorithms from Opr::algo_pack() that are available for
   * the given size arguments.
   *
   * Each entry of Opr::algo_pack().all_algos is kept when its
   * is_available(args) returns true; the pack's order is preserved.
   *
   * \param args size arguments every candidate is checked against
   * \return the available algorithms; an empty result is rejected with
   *         megdnn_assert, whose message includes args.to_string()
   */
  template <class Opr>
  std::vector<typename Opr::Algorithm*> get_all_algorithms(
          const typename Opr::AlgoBase::SizeArgs& args) {
      std::vector<typename Opr::Algorithm*> ret;
      ret.reserve(Opr::algo_pack().all_algos.size());
      for (auto i : Opr::algo_pack().all_algos) {
          if (i->is_available(args)) {
              ret.push_back(i);
          }
      }
      // This template is shared by every operator type, so the diagnostic
      // must not claim the operator is a convolution.
      megdnn_assert(!ret.empty(), "no algorithm for %s",
                    args.to_string().c_str());
      return ret;
  }
  /*!
   * \brief filter a single algorithm by a reproducibility requirement.
   *
   * \param algo         candidate algorithm
   * \param reproducible whether the caller requires a reproducible algorithm
   * \return \p algo when no reproducibility is required, or when it is
   *         required and algo->is_reproducible() holds; nullptr otherwise.
   *         \p algo is only dereferenced when \p reproducible is true.
   */
  template <typename Opr>
  typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo,
                                                 bool reproducible) {
      // Short-circuit: without a reproducibility requirement the algorithm is
      // accepted as-is and is_reproducible() is never called.
      const bool acceptable = !reproducible || algo->is_reproducible();
      if (!acceptable) {
          return nullptr;
      }
      return algo;
  }
  85. template <typename Opr>
  86. typename Opr::Algorithm* get_reproducible_algo(
  87. const std::vector<typename Opr::AlgoBase*>& algos,
  88. const typename Opr::AlgoBase::SizeArgs& args,
  89. size_t workspace_limit_in_bytes, const char* name) {
  90. size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
  91. bool available_but_limited_by_workspace = false;
  92. bool available_but_not_reproducible = false;
  93. for (auto i : algos) {
  94. if (i->is_available_reproducible(args, true,
  95. workspace_limit_in_bytes)) {
  96. return i;
  97. }
  98. if (i->is_available_reproducible(args)) {
  99. if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
  100. available_but_limited_by_workspace = true;
  101. min_workspace_limit_in_bytes =
  102. std::min(min_workspace_limit_in_bytes,
  103. i->get_workspace_in_bytes(args));
  104. }
  105. }
  106. if (i->is_available(args)) {
  107. if (!i->is_reproducible())
  108. available_but_not_reproducible = true;
  109. }
  110. }
  111. MEGDNN_MARK_USED_VAR(name);
  112. if (available_but_limited_by_workspace) {
  113. megdnn_throw(megdnn_mangle(ssprintf(
  114. "no reproducible %s algorithm: %s workspace limit %zu is "
  115. "less than mini workspace limit %zu",
  116. name, args.to_string().c_str(), workspace_limit_in_bytes,
  117. min_workspace_limit_in_bytes)));
  118. } else if (available_but_not_reproducible) {
  119. megdnn_throw(
  120. megdnn_mangle(ssprintf("no reproducible %s algorithm", name)));
  121. } else {
  122. megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name)));
  123. }
  124. }
  125. template <typename Opr>
  126. typename Opr::Algorithm* get_usable_algo(
  127. const std::vector<typename Opr::AlgoBase*>& algos,
  128. const typename Opr::AlgoBase::SizeArgs& args,
  129. size_t workspace_limit_in_bytes, const char* name) {
  130. size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
  131. bool available_but_limited_by_workspace = false;
  132. for (auto i : algos) {
  133. if (i->is_available_wk(args, workspace_limit_in_bytes)) {
  134. return i;
  135. }
  136. if (i->is_available(args)) {
  137. available_but_limited_by_workspace = true;
  138. min_workspace_limit_in_bytes =
  139. std::min(min_workspace_limit_in_bytes,
  140. i->get_workspace_in_bytes(args));
  141. }
  142. }
  143. MEGDNN_MARK_USED_VAR(name);
  144. if (available_but_limited_by_workspace) {
  145. megdnn_throw(megdnn_mangle(ssprintf(
  146. "no usable %s algorithm: %s workspace limit %zu is "
  147. "less than mini workspace limit %zu",
  148. name, args.to_string().c_str(), workspace_limit_in_bytes,
  149. min_workspace_limit_in_bytes)));
  150. } else {
  151. megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name)));
  152. }
  153. }
  154. } // namespace megdnn
  155. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。