You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

base.h 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. /**
  2. * \file dnn/include/megdnn/oprs/base.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
#include <cstring>
#include <type_traits>

#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/internal/visibility_prologue.h"
  16. namespace megdnn {
  17. class Handle;
/**
 * \brief base class for all operators
 *
 * This is a helper class. Users should not use OperatorBase directly.
 * Operators should be created by handle->create_opr<>().
 *
 * Each operator must provide the following constexpr values:
 *
 * * NR_INPUTS: number of input vars
 * * NR_OUTPUTS: number of output vars
 * * OPERATOR_TYPE: operator type as an enum
 *
 * If the operator has dynamic inputs or in_out param, the corresponding
 * NR_INPUTS is -1.
 *
 * For an operator whose NR_INPUTS >= 0 and NR_OUTPUTS >= 0, the operator must
 * also provide the following methods:
 *
 * * void exec(_megdnn_in inputs..., _megdnn_tensor_out outputs...,
 *      _megdnn_workspace workspace)
 * * void deduce_layout(const TensorLayout& inputs...,
 *      TensorLayout& outputs...)
 * * size_t get_workspace_in_bytes(const TensorLayout &inputs...,
 *      const TensorLayout &outputs)
 */
class OperatorBase {
public:
    explicit OperatorBase(Handle* handle) : m_handle(handle) {}
    virtual ~OperatorBase();

    //! get the handle from which this operator is created
    Handle* handle() const { return m_handle; }

    //! whether this opr guarantees that its exec() is thread-safe
    virtual bool is_thread_safe() const { return false; }

    /*!
     * \brief set the tracker to be used with MegcoreAsyncErrorInfo
     *
     * Most operators do not have async errors so this function has a
     * default empty implementation.
     */
    virtual void set_error_tracker(void*) {}

private:
    //! non-owning: the handle that created this operator
    Handle* m_handle;
};
  61. namespace detail {
/**
 * \brief AlgoSelectionStrategy is the advance information for selecting
 *      algo
 */
enum class AlgoSelectionStrategy {
    HEURISTIC = 0,  //!< heuristic to select the algos
    FAST_RUN = 1,   //!< NOTE(review): presumably profiling-based selection on a
                    //!< reduced search space — semantics defined by users of
                    //!< this enum; confirm against callers
    FULL_RUN = 2,   //!< NOTE(review): presumably exhaustive profiling of all
                    //!< candidate algos — confirm against callers
};
/**
 * \brief separate algo by datatype for Matmul and conv
 *
 * Values are distinct single-bit flags, so multiple categories can be
 * combined into one uint32_t mask.
 */
enum class AlgoDataType : uint32_t {
    FLOAT32 = 1 << 0,
    FLOAT16 = 1 << 1,
    QINT8X8X32 = 1 << 2,
    QUINT8X8X32 = 1 << 3,
    INT8X8X16 = 1 << 4,
    INT16X16X32 = 1 << 5,
};
  82. /*!
  83. * \brief Abstract representation of an algorithm for implementing
  84. * the operator
  85. */
  86. class Algorithm {
  87. public:
  88. static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1);
  89. /**
  90. * \brief Algorithm information, we can get real algo from
  91. * AlgorithmInfo::Info::Desc
  92. */
  93. struct Info {
  94. struct Desc {
  95. //! backend of the algo belonging to
  96. Handle::HandleType handle_type;
  97. //! indicate the real algo implementation
  98. uint32_t type = INVALID_ALGO_TYPE;
  99. //! serialized param of the algo type
  100. std::string param;
  101. bool valid() const { return type != INVALID_ALGO_TYPE; }
  102. void reset() { type = INVALID_ALGO_TYPE; }
  103. bool operator==(const Desc& rhs) const {
  104. return handle_type == rhs.handle_type && type == rhs.type &&
  105. param == rhs.param;
  106. }
  107. } desc;
  108. //! algorithm name
  109. std::string name;
  110. bool is_reproducible;
  111. bool valid() const { return desc.valid(); }
  112. void reset() { desc.reset(); }
  113. //! desc donate the algo
  114. bool operator==(const Info& rhs) const { return desc == rhs.desc; }
  115. };
  116. virtual ~Algorithm() = default;
  117. /**
  118. * \brief whether the execution result is
  119. * reproducible across multiple runs.
  120. */
  121. virtual bool is_reproducible() const = 0;
  122. virtual const char* name() const = 0;
  123. //! serialized param
  124. virtual std::string param() const { return {}; }
  125. virtual uint32_t type() const = 0;
  126. Handle::HandleType handle_type() const { return m_handle_type; }
  127. Info info() const {
  128. return {{handle_type(), type(), param()}, name(), is_reproducible()};
  129. }
  130. template <typename T>
  131. static void serialize_write_pod(const T& val, std::string& result) {
  132. result.append(reinterpret_cast<const char*>(&val), sizeof(T));
  133. }
  134. static void serialize_write_pod(const char* val, std::string& result) {
  135. result.append(val, strlen(val));
  136. }
  137. template <typename T>
  138. static T deserialize_read_pod(const std::string& data, size_t offset = 0) {
  139. T ret = *reinterpret_cast<const T*>(&data[offset]);
  140. return ret;
  141. }
  142. protected:
  143. Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
  144. };
/*!
 * \brief define Algorithm and ExecutionPolicy for oprs that have
 *      multiple impl algos
 *
 * \tparam Opr the operator class
 * \tparam nargs number of arguments; -1 selects the base definition below
 */
template <class Opr, int nargs>
class MultiAlgoOpr;

//! base def
template <class Opr>
class MultiAlgoOpr<Opr, -1> {
public:
    using AlgorithmInfo = detail::Algorithm::Info;
    using AlgorithmDesc = detail::Algorithm::Info::Desc;
    using Algorithm = detail::Algorithm;

    /*!
     * \brief get a string representation for current algorithm set;
     *
     * get_all_algorithms() may return different algorithms only if
     * algorithm set name differs. This is used for checking cache
     * validity.
     */
    virtual const char* get_algorithm_set_name() const = 0;

    //! policy for executing the operator
    struct ExecutionPolicy {
        //! INVALID_ALGO_TYPE algo_type means using heuristic
        AlgorithmInfo algo;
    };

    ExecutionPolicy& execution_policy() { return m_execution_policy; }
    const ExecutionPolicy& execution_policy() const {
        return m_execution_policy;
    }

protected:
    // protected non-virtual dtor: instances are never deleted through a
    // pointer to this base
    ~MultiAlgoOpr() = default;

private:
    ExecutionPolicy m_execution_policy;
};
  183. //! specialize for nargs == 3
  184. template <class Opr>
  185. class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> {
  186. public:
  187. using Algorithm = detail::Algorithm;
  188. using AlgorithmInfo = detail::Algorithm::Info;
  189. //! get all possible algorithm decriptions for the specified layouts
  190. std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
  191. const TensorLayout& p1,
  192. const TensorLayout& p2) {
  193. std::vector<AlgorithmInfo> ret;
  194. for (auto&& algo : get_all_algorithms(p0, p1, p2)) {
  195. ret.emplace_back(algo->info());
  196. }
  197. return ret;
  198. }
  199. /**
  200. * \brief Returns the best algorithm information which indicate the
  201. * algorithm by heuristic.
  202. *
  203. * The selected algorithm should not use workspace more than
  204. * \p workspace_limit_in_bytes.
  205. */
  206. AlgorithmInfo get_algorithm_info_heuristic(
  207. const TensorLayout& p0, const TensorLayout& p1,
  208. const TensorLayout& p2,
  209. size_t workspace_limit_in_bytes =
  210. std::numeric_limits<size_t>::max(),
  211. bool reproducible = false) {
  212. return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes,
  213. reproducible)
  214. ->info();
  215. }
  216. protected:
  217. ~MultiAlgoOpr() = default;
  218. //! get all possible algorithms for the specified layouts
  219. virtual std::vector<Algorithm*> get_all_algorithms(
  220. const TensorLayout& p0, const TensorLayout& p1,
  221. const TensorLayout& p2) = 0;
  222. /**
  223. * \brief Returns the best algorithm by heuristic.
  224. *
  225. * The selected algorithm should not use workspace more than
  226. * \p workspace_limit_in_bytes.
  227. */
  228. virtual Algorithm* get_algorithm_heuristic(
  229. const TensorLayout& p0, const TensorLayout& p1,
  230. const TensorLayout& p2,
  231. size_t workspace_limit_in_bytes =
  232. std::numeric_limits<size_t>::max(),
  233. bool reproducible = false) = 0;
  234. };
  235. //! specializae for nargs == 4
  236. template <class Opr>
  237. class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> {
  238. public:
  239. using Algorithm = detail::Algorithm;
  240. using AlgorithmInfo = detail::Algorithm::Info;
  241. //! get all possible algorithm decriptions for the specified layouts
  242. std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
  243. const TensorLayout& p1,
  244. const TensorLayout& p2,
  245. const TensorLayout& p3) {
  246. std::vector<AlgorithmInfo> ret;
  247. for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) {
  248. ret.emplace_back(algo->info());
  249. }
  250. return ret;
  251. }
  252. /**
  253. * \brief Returns the best algorithm information which indicate the
  254. * algorithm by heuristic.
  255. *
  256. * The selected algorithm should not use workspace more than
  257. * \p workspace_limit_in_bytes.
  258. */
  259. AlgorithmInfo get_algorithm_info_heuristic(
  260. const TensorLayout& p0, const TensorLayout& p1,
  261. const TensorLayout& p2, const TensorLayout& p3,
  262. size_t workspace_limit_in_bytes =
  263. std::numeric_limits<size_t>::max(),
  264. bool reproducible = false) {
  265. return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes,
  266. reproducible)
  267. ->info();
  268. }
  269. protected:
  270. ~MultiAlgoOpr() = default;
  271. //! get all possible algorithms for the specified layouts
  272. virtual std::vector<Algorithm*> get_all_algorithms(
  273. const TensorLayout& p0, const TensorLayout& p1,
  274. const TensorLayout& p2, const TensorLayout& p3) = 0;
  275. /**
  276. * \brief Returns the best algorithm by heuristic.
  277. *
  278. * The selected algorithm should not use workspace more than
  279. * \p workspace_limit_in_bytes.
  280. */
  281. virtual Algorithm* get_algorithm_heuristic(
  282. const TensorLayout& p0, const TensorLayout& p1,
  283. const TensorLayout& p2, const TensorLayout& p3,
  284. size_t workspace_limit_in_bytes =
  285. std::numeric_limits<size_t>::max(),
  286. bool reproducible = false) = 0;
  287. };
  288. //! specializae for nargs == 5
  289. template <class Opr>
  290. class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> {
  291. public:
  292. using Algorithm = detail::Algorithm;
  293. using AlgorithmInfo = detail::Algorithm::Info;
  294. //! get all possible algorithm decriptions for the specified layouts
  295. std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0,
  296. const TensorLayout& p1,
  297. const TensorLayout& p2,
  298. const TensorLayout& p3,
  299. const TensorLayout& p4) {
  300. std::vector<AlgorithmInfo> ret;
  301. for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) {
  302. ret.emplace_back(algo->info());
  303. }
  304. return ret;
  305. }
  306. /**
  307. * \brief Returns the best algorithm information which indicate the
  308. * algorithm by heuristic.
  309. *
  310. * The selected algorithm should not use workspace more than
  311. * \p workspace_limit_in_bytes.
  312. */
  313. AlgorithmInfo get_algorithm_info_heuristic(
  314. const TensorLayout& p0, const TensorLayout& p1,
  315. const TensorLayout& p2, const TensorLayout& p3,
  316. const TensorLayout& p4,
  317. size_t workspace_limit_in_bytes =
  318. std::numeric_limits<size_t>::max(),
  319. bool reproducible = false) {
  320. return get_algorithm_heuristic(p0, p1, p2, p3, p4,
  321. workspace_limit_in_bytes, reproducible)
  322. ->info();
  323. }
  324. protected:
  325. ~MultiAlgoOpr() = default;
  326. //! get all possible algorithms for the specified layouts
  327. virtual std::vector<Algorithm*> get_all_algorithms(
  328. const TensorLayout& p0, const TensorLayout& p1,
  329. const TensorLayout& p2, const TensorLayout& p3,
  330. const TensorLayout& p4) = 0;
  331. /**
  332. * \brief Returns the best algorithm by heuristic.
  333. *
  334. * The selected algorithm should not use workspace more than
  335. * \p workspace_limit_in_bytes.
  336. */
  337. virtual Algorithm* get_algorithm_heuristic(
  338. const TensorLayout& p0, const TensorLayout& p1,
  339. const TensorLayout& p2, const TensorLayout& p3,
  340. const TensorLayout& p4,
  341. size_t workspace_limit_in_bytes =
  342. std::numeric_limits<size_t>::max(),
  343. bool reproducible = false) = 0;
  344. };
  345. //! specializae for nargs == 8
  346. template <class Opr>
  347. class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> {
  348. public:
  349. using Algorithm = detail::Algorithm;
  350. using AlgorithmInfo = detail::Algorithm::Info;
  351. //! get all possible algorithm decriptions for the specified layouts
  352. std::vector<AlgorithmInfo> get_all_algorithms_info(
  353. const TensorLayout& p0, const TensorLayout& p1,
  354. const TensorLayout& p2, const TensorLayout& p3,
  355. const TensorLayout& p4, const TensorLayout& p5,
  356. const TensorLayout& p6, const TensorLayout& p7) {
  357. std::vector<AlgorithmInfo> ret;
  358. for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) {
  359. ret.emplace_back(algo->info());
  360. }
  361. return ret;
  362. }
  363. /**
  364. * \brief Returns the best algorithm information which indicate the
  365. * algorithm by heuristic.
  366. *
  367. * The selected algorithm should not use workspace more than
  368. */
  369. AlgorithmInfo get_algorithm_info_heuristic(
  370. const TensorLayout& p0, const TensorLayout& p1,
  371. const TensorLayout& p2, const TensorLayout& p3,
  372. const TensorLayout& p4, const TensorLayout& p5,
  373. const TensorLayout& p6, const TensorLayout& p7,
  374. size_t workspace_limit_in_bytes =
  375. std::numeric_limits<size_t>::max(),
  376. bool reproducible = false) {
  377. return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7,
  378. workspace_limit_in_bytes, reproducible)
  379. ->info();
  380. }
  381. protected:
  382. ~MultiAlgoOpr() = default;
  383. //! get all possible algorithms for the specified layouts
  384. virtual std::vector<Algorithm*> get_all_algorithms(
  385. const TensorLayout& p0, const TensorLayout& p1,
  386. const TensorLayout& p2, const TensorLayout& p3,
  387. const TensorLayout& p4, const TensorLayout& p5,
  388. const TensorLayout& p6, const TensorLayout& p7) = 0;
  389. /**
  390. * \brief Returns the best algorithm by heuristic.
  391. *
  392. * The selected algorithm should not use workspace more than
  393. * \p workspace_limit_in_bytes.
  394. */
  395. virtual Algorithm* get_algorithm_heuristic(
  396. const TensorLayout& p0, const TensorLayout& p1,
  397. const TensorLayout& p2, const TensorLayout& p3,
  398. const TensorLayout& p4, const TensorLayout& p5,
  399. const TensorLayout& p6, const TensorLayout& p7,
  400. size_t workspace_limit_in_bytes =
  401. std::numeric_limits<size_t>::max(),
  402. bool reproducible = false) = 0;
  403. };
  404. } // namespace detail
  405. } // namespace megdnn
  406. #include "megdnn/internal/visibility_epilogue.h"
  407. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台