You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algo_chooser.h 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /**
  2. * \file dnn/src/common/algo_chooser.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include <cstddef>
  13. #include <limits>
  14. #include <utility>
  15. #include <vector>
  16. #include "utils.h"
  17. namespace megdnn {
  /*!
   * \brief pick the algorithm an operator should use
   *
   * If opr->execution_policy().algorithm is set (evaluates to true), that
   * algorithm is returned. Otherwise opr->get_algorithm_heuristic() is called
   * with the forwarded \p args followed by two fixed trailing arguments:
   * std::numeric_limits<size_t>::max() and false.
   * NOTE(review): the two trailing arguments are presumably the workspace
   * limit and the "reproducible" flag -- confirm against the declaration of
   * get_algorithm_heuristic().
   *
   * \param opr operator providing execution_policy() and
   *      get_algorithm_heuristic()
   * \param args leading arguments, perfectly forwarded to the heuristic
   * \return the selected algorithm, static_cast to Opr::AlgoBase*
   */
  template <class Opr, typename... Args>
  typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) {
      using AlgoBasePtr = typename Opr::AlgoBase*;

      // A user-configured algorithm takes precedence over the heuristic.
      if (auto configured = opr->execution_policy().algorithm) {
          typename Opr::Algorithm* chosen = configured;
          return static_cast<AlgoBasePtr>(chosen);
      }

      typename Opr::Algorithm* chosen = opr->get_algorithm_heuristic(
              std::forward<Args>(args)..., std::numeric_limits<size_t>::max(),
              false);
      return static_cast<AlgoBasePtr>(chosen);
  }
  /*!
   * \brief collect every entry of Opr::algo_pack().all_algos whose
   *      is_available(args) returns true
   *
   * The result keeps the iteration order of all_algos. megdnn_assert is
   * invoked with the condition that the result is non-empty; the message
   * includes args.to_string().
   *
   * \param args size arguments passed to each algorithm's is_available()
   * \return the algorithms that reported themselves available
   */
  template <class Opr>
  std::vector<typename Opr::Algorithm*> get_all_algorithms(
          const typename Opr::AlgoBase::SizeArgs& args) {
      using AlgoPtr = typename Opr::Algorithm*;

      std::vector<AlgoPtr> available;
      // Reserve for the upper bound so push_back never reallocates.
      available.reserve(Opr::algo_pack().all_algos.size());

      for (auto candidate : Opr::algo_pack().all_algos) {
          if (!candidate->is_available(args)) {
              continue;
          }
          available.push_back(candidate);
      }

      megdnn_assert(!available.empty(), "no conv algorithm for %s",
                    args.to_string().c_str());
      return available;
  }
  /*!
   * \brief filter a single algorithm by an optional reproducibility requirement
   *
   * \param algo candidate algorithm; it is dereferenced only when
   *      \p reproducible is true
   * \param reproducible whether the caller requires a reproducible algorithm
   * \return \p algo when \p reproducible is false or algo->is_reproducible()
   *      is true; nullptr otherwise
   */
  template <typename Opr>
  typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo,
                                                 bool reproducible) {
      // Short-circuit: is_reproducible() is queried only when it matters.
      const bool acceptable = !reproducible || algo->is_reproducible();
      if (!acceptable) {
          return nullptr;
      }
      return algo;
  }
  67. template <typename Opr>
  68. typename Opr::Algorithm* get_reproducible_algo(
  69. const std::vector<typename Opr::AlgoBase*>& algos,
  70. const typename Opr::AlgoBase::SizeArgs& args,
  71. size_t workspace_limit_in_bytes, const char* name) {
  72. size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
  73. bool available_but_limited_by_workspace = false;
  74. bool available_but_not_reproducible = false;
  75. for (auto i : algos) {
  76. if (i->is_available_reproducible(args, true,
  77. workspace_limit_in_bytes)) {
  78. return i;
  79. }
  80. if (i->is_available_reproducible(args)) {
  81. if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) {
  82. available_but_limited_by_workspace = true;
  83. min_workspace_limit_in_bytes =
  84. std::min(min_workspace_limit_in_bytes,
  85. i->get_workspace_in_bytes(args));
  86. }
  87. }
  88. if (i->is_available(args)) {
  89. if (!i->is_reproducible())
  90. available_but_not_reproducible = true;
  91. }
  92. }
  93. MEGDNN_MARK_USED_VAR(name);
  94. if (available_but_limited_by_workspace) {
  95. megdnn_throw(megdnn_mangle(ssprintf(
  96. "no reproducible %s algorithm: %s workspace limit %zu is "
  97. "less than mini workspace limit %zu",
  98. name, args.to_string().c_str(), workspace_limit_in_bytes,
  99. min_workspace_limit_in_bytes)));
  100. } else if (available_but_not_reproducible) {
  101. megdnn_throw(
  102. megdnn_mangle(ssprintf("no reproducible %s algorithm", name)));
  103. } else {
  104. megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name)));
  105. }
  106. }
  107. template <typename Opr>
  108. typename Opr::Algorithm* get_usable_algo(
  109. const std::vector<typename Opr::AlgoBase*>& algos,
  110. const typename Opr::AlgoBase::SizeArgs& args,
  111. size_t workspace_limit_in_bytes, const char* name) {
  112. size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max();
  113. bool available_but_limited_by_workspace = false;
  114. for (auto i : algos) {
  115. if (i->is_available_wk(args, workspace_limit_in_bytes)) {
  116. return i;
  117. }
  118. if (i->is_available(args)) {
  119. available_but_limited_by_workspace = true;
  120. min_workspace_limit_in_bytes =
  121. std::min(min_workspace_limit_in_bytes,
  122. i->get_workspace_in_bytes(args));
  123. }
  124. }
  125. MEGDNN_MARK_USED_VAR(name);
  126. if (available_but_limited_by_workspace) {
  127. megdnn_throw(megdnn_mangle(ssprintf(
  128. "no usable %s algorithm: %s workspace limit %zu is "
  129. "less than mini workspace limit %zu",
  130. name, args.to_string().c_str(), workspace_limit_in_bytes,
  131. min_workspace_limit_in_bytes)));
  132. } else {
  133. megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name)));
  134. }
  135. }
  136. } // namespace megdnn
  137. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台