You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algo_base.h 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. /**
  2. * \file dnn/src/common/algo_base.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  #include <cstddef>
  #include <cstdint>
  #include <functional>
  #include <memory>
  #include <string>
  #include <tuple>
  #include <type_traits>
  #include <vector>

  #include "megdnn/oprs/base.h"
  #include "src/common/utils.h"
  18. namespace megdnn {
  19. #define MEGDNN_DECL_ALGO_TYPE(_type) \
  20. uint32_t type() const override { \
  21. return static_cast<std::underlying_type<AlgoType>::type>( \
  22. AlgoType::_type); \
  23. }
  24. #define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \
  25. static fallback::_opr::AlgoBase* get_algo_from_desc( \
  26. const AlgorithmDesc& desc)
  27. #define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \
  28. fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \
  29. const AlgorithmDesc& desc) { \
  30. megdnn_assert(algo_pack().all_algos_map().find(desc) != \
  31. algo_pack().all_algos_map().end()); \
  32. return algo_pack().all_algos_map().at(desc); \
  33. }
  34. #define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \
  35. _opr::Algorithm* _opr::get_algorithm_from_desc( \
  36. const AlgorithmDesc& desc) { \
  37. megdnn_assert(algo_pack().all_algos_map().find(desc) != \
  38. algo_pack().all_algos_map().end()); \
  39. return algo_pack().all_algos_map().at(desc); \
  40. }
  41. #define MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) \
  42. cb(AlgoAttribute::ACCURACY_DEPEND_ON_BATCH)
  43. /**
  44. * \brief construct algo from AlgorithmDesc
  45. */
  46. template <typename AlgoBase>
  47. class AlgoConstructMixin {
  48. private:
  49. std::vector<std::unique_ptr<AlgoBase>> m_refhold;
  50. protected:
  51. typename AlgoBase::Mapper m_all_algos_map;
  52. public:
  53. //! construct the algo which described by desc, and return the instance
  54. AlgoBase* construct_and_get_algo(
  55. const detail::Algorithm::Info::Desc& desc) {
  56. auto iter = m_all_algos_map.find(desc);
  57. if (iter != m_all_algos_map.end()) {
  58. return m_all_algos_map.at(desc);
  59. }
  60. std::string serialized_bin;
  61. AlgoBase::serialize_write_pod(desc.type, serialized_bin);
  62. serialized_bin += desc.param;
  63. m_refhold.emplace_back(AlgoBase::deserialize(serialized_bin));
  64. m_all_algos_map.emplace(desc, m_refhold.back().get());
  65. return m_refhold.back().get();
  66. }
  67. void clear() {
  68. m_all_algos_map.clear();
  69. m_refhold.clear();
  70. }
  71. const typename AlgoBase::Mapper& all_algos_map() const {
  72. return m_all_algos_map;
  73. }
  74. };
  75. template <std::size_t I = 0, typename Opr, typename... Tp>
  76. inline typename std::enable_if<I == sizeof...(Tp), void>::type
  77. set_sub_execution_policy(const Opr*, std::tuple<Tp...>&) {}
  78. template <std::size_t I = 0, typename Opr, typename... Tp>
  79. inline typename std::enable_if <
  80. I<sizeof...(Tp), void>::type set_sub_execution_policy(
  81. const Opr* opr, std::tuple<Tp...>& t) {
  82. std::get<I>(t)->execution_policy() = opr->execution_policy().sub_policy[I];
  83. set_sub_execution_policy<I + 1, Opr, Tp...>(opr, t);
  84. }
  85. template <typename Opr, typename... SubOpr>
  86. void set_execution_policy(const Opr* opr, SubOpr... sub_oprs) {
  87. if (opr->execution_policy().algo.valid() &&
  88. !opr->execution_policy().sub_policy.empty()) {
  89. megdnn_assert(opr->execution_policy().sub_policy.size() ==
  90. sizeof...(sub_oprs));
  91. auto&& sub = std::make_tuple(sub_oprs...);
  92. set_sub_execution_policy<0, Opr, SubOpr...>(opr, sub);
  93. }
  94. }
  95. } // namespace megdnn
  96. namespace std {
  97. template <>
  98. struct hash<megdnn::detail::Algorithm::Info::Desc> {
  99. std::size_t operator()(
  100. const megdnn::detail::Algorithm::Info::Desc& desc) const {
  101. return megdnn::hash_combine<size_t>(
  102. megdnn::hash_combine<size_t>(
  103. std::hash<std::string>()(desc.name),
  104. megdnn::hash_combine<size_t>(
  105. std::hash<std::string>()(desc.param),
  106. std::hash<uint32_t>()(desc.type))),
  107. std::hash<uint32_t>()(static_cast<uint32_t>(desc.handle_type)));
  108. }
  109. };
  110. } // namespace std
  111. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。