
algo_chooser.cpp 39 kB

  1. /**
  2. * \file src/opr/impl/search_policy/algo_chooser.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megbrain/opr/search_policy/algo_chooser.h"
  13. #include <limits>
  14. #include <unordered_set>
  15. #include "megbrain/opr/dnn/convolution.h"
  16. #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
  17. #include "megbrain/opr/search_policy/algo_chooser_helper.h"
  18. #include "megbrain/opr/search_policy/profiler.h"
  19. #include "../internal/invoke.h"
  20. #include "../internal/megdnn_opr_wrapper.inl"
  21. #include "./workspace_need_limit_getter.inl"
  22. //! TODO: some megdnn::opr types have to be known here when midout.h is produced;
  23. //! fix this if there is a more graceful way.
  24. #include "megdnn/opr_param_defs.h"
  25. #include "megdnn/oprs.h"
  26. #include "megdnn/oprs/base.h"
  27. #include "midout.h"
  28. MIDOUT_DECL(megbrain_opr_algo_chooser)
  29. #define MIDOUT_B(...) MIDOUT_BEGIN(megbrain_opr_algo_chooser, __VA_ARGS__) {
  30. #define MIDOUT_E \
  31. } \
  32. MIDOUT_END();
  33. using mgb::opr::intl::WorkspaceLimitGetter;
  34. using namespace megdnn;
  35. using namespace mgb;
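//! note: APPLY forwards a std::tuple (typically the tensor layouts) to `statement`
//! as the expanded argument pack `args...` via mgb::apply and std::tuple_cat.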
  36. #define APPLY(statement, ...) \
  37. mgb::apply([&](const auto&... args) { return statement; }, \
  38. std::tuple_cat(__VA_ARGS__))
  39. // timeout delta added to the fastest known algorithm's time when profiling new algos
  40. constexpr double TIMEOUT_TOLERANCE = 2;
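//! version tag appended to the profile cache name (see profile_name below); bumping
//! it keeps results recorded under an older format from being reused.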
  41. #define CACHE_KEY_VERSION "v5"
  42. namespace {
  43. template <typename Opr>
  44. std::string profile_name(Opr* opr) {
  45. std::string ret =
  46. std::string(MegDNNOpr2MGBOpr<Opr>::MGBOpr::typeinfo()->name) +
  47. CACHE_KEY_VERSION;
  48. ret.append(opr->get_algorithm_set_name());
  49. return ret;
  50. }
  51. template <typename Opr>
  52. std::string format_fixlayouts(
  53. const typename opr::AlgoChooser<Opr>::FixedTensorLayouts& layouts,
  54. size_t arity_in, size_t arity_out) {
  55. std::string ret;
  56. ret.append(": tensor layouts(");
  57. for (size_t i = 0; i < arity_in; ++i) {
  58. if (i) {
  59. ret.append(", ");
  60. }
  61. ret.append(layouts[i].to_string() + " ");
  62. }
  63. ret.append(") -> (");
  64. for (size_t i = 0; i < arity_out; ++i) {
  65. if (i) {
  66. ret.append(", ");
  67. }
  68. ret.append(layouts[i + arity_in].to_string() + " ");
  69. }
  70. return ret;
  71. }
  72. /**
  73. * \brief Check if the sub opr list has circular dependence.
  74. */
  75. class CircularDepsChecker {
  76. struct SearchItemStorage {
  77. std::string data_hold;
  78. size_t hash = 0;
  79. SearchItemStorage(const Algorithm::SearchItem& item) {
  80. Algorithm::serialize_write_pod(item.opr_type, data_hold);
  81. for (auto&& layout : item.layouts) {
  82. data_hold += layout.serialize();
  83. }
  84. data_hold += item.param;
  85. }
  86. SearchItemStorage& init_hash() {
  87. hash = XXHash64CT::hash(data_hold.data(), data_hold.size(),
  88. 20201225);
  89. return *this;
  90. }
  91. bool operator==(const SearchItemStorage& rhs) const {
  92. return data_hold == rhs.data_hold;
  93. }
  94. struct Hash {
  95. size_t operator()(const SearchItemStorage& s) const {
  96. return s.hash;
  97. }
  98. };
  99. };
  100. std::unordered_set<SearchItemStorage, SearchItemStorage::Hash> m_set;
  101. public:
  102. void put(const megdnn::Algorithm::SearchItem& key) {
  103. SearchItemStorage key_storage(key);
  104. key_storage.init_hash();
  105. mgb_assert(m_set.find(key_storage) == m_set.end(),
  106. "Circular dependency during flatten search space");
  107. auto ret = m_set.insert(std::move(key_storage));
  108. mgb_assert(ret.second);
  109. }
  110. void remove(const megdnn::Algorithm::SearchItem& key) {
  111. SearchItemStorage key_storage(key);
  112. key_storage.init_hash();
  113. auto&& iter = m_set.find(key_storage);
  114. mgb_assert(iter != m_set.end());
  115. m_set.erase(iter);
  116. }
  117. };
  118. ///////////////// OprTypeTrait /////////////////////////////
  119. template <megdnn::Algorithm::OprType>
  120. struct OprFromOprTypeTrait;
  121. template <typename Opr>
  122. struct OprTypeFromOprTrait;
  123. #define cb(_opr_type, _opr) \
  124. template <> \
  125. struct OprFromOprTypeTrait<megdnn::Algorithm::OprType::_opr_type> { \
  126. using Opr = megdnn::_opr; \
  127. }; \
  128. template <> \
  129. struct OprTypeFromOprTrait<megdnn::_opr> { \
  130. constexpr static megdnn::Algorithm::OprType opr_type = \
  131. megdnn::Algorithm::OprType::_opr_type; \
  132. }
  133. cb(MATRIX_MUL_FORWARD, MatrixMulForward);
  134. cb(BATCHED_MATRIX_MUL_FORWARD, BatchedMatrixMulForward);
  135. cb(CONVOLUTION_FORWARD, ConvolutionForward);
  136. cb(CONVOLUTION_BACKWARD_DATA, ConvolutionBackwardData);
  137. cb(CONVOLUTION_BACKWARD_FILTER, ConvolutionBackwardFilter);
  138. cb(CONVOLUTION3D_FORWARD, Convolution3DForward);
  139. cb(CONVOLUTION3D_BACKWARD_DATA, Convolution3DBackwardData);
  140. cb(CONVOLUTION3D_BACKWARD_FILTER, Convolution3DBackwardFilter);
  141. cb(LOCAL_SHARE_FORWARD, LocalShareForward);
  142. cb(LOCAL_SHARE_BACKWARD_DATA, LocalShareBackwardData);
  143. cb(LOCAL_SHARE_BACKWARD_FILTER, LocalShareBackwardFilter);
  144. cb(DEFORMABLE_CONV_FORWARD, DeformableConvForward);
  145. cb(DEFORMABLE_CONV_BACKWARD_DATA, DeformableConvBackwardData);
  146. cb(DEFORMABLE_CONV_BACKWARD_FILTER, DeformableConvBackwardFilter);
  147. cb(BATCH_CONV_FORWARD, BatchConvBiasForward);
  148. cb(CONVBIAS_FORWARD, ConvBiasForward);
  149. #undef cb
  150. // clang-format off
  151. #define FOREACH_OPR_TYPE_WITH_STMT(cb, stmt) \
  152. cb(MATRIX_MUL_FORWARD, stmt) \
  153. cb(BATCHED_MATRIX_MUL_FORWARD, stmt) \
  154. cb(CONVOLUTION_FORWARD, stmt) \
  155. cb(CONVOLUTION_BACKWARD_DATA, stmt) \
  156. cb(CONVOLUTION_BACKWARD_FILTER, stmt) \
  157. cb(CONVOLUTION3D_FORWARD, stmt) \
  158. cb(CONVOLUTION3D_BACKWARD_DATA, stmt) \
  159. cb(CONVOLUTION3D_BACKWARD_FILTER, stmt) \
  160. cb(LOCAL_SHARE_FORWARD, stmt) \
  161. cb(LOCAL_SHARE_BACKWARD_DATA, stmt) \
  162. cb(LOCAL_SHARE_BACKWARD_FILTER, stmt) \
  163. cb(DEFORMABLE_CONV_FORWARD, stmt) \
  164. cb(DEFORMABLE_CONV_BACKWARD_DATA, stmt) \
  165. cb(DEFORMABLE_CONV_BACKWARD_FILTER, stmt) \
  166. cb(BATCH_CONV_FORWARD, stmt) \
  167. cb(CONVBIAS_FORWARD, stmt)
  168. // clang-format on
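//! the two macros below dispatch each Algorithm::SearchItem to `_stmt` with `_Opr`
//! bound to the concrete megdnn operator type recorded in the item.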
  169. #define _OPR_TYPE_CASE(_opr_type, _stmt) \
  170. case Algorithm::OprType::_opr_type: { \
  171. using _Opr = typename OprFromOprTypeTrait< \
  172. Algorithm::OprType::_opr_type>::Opr; \
  173. _stmt; \
  174. break; \
  175. }
  176. #define FOREACH_OPR_TYPE_DISPATCH(_search_items, _stmt) \
  177. for (size_t _item_idx = 0; _item_idx < _search_items.size(); \
  178. _item_idx++) { \
  179. auto&& _item = _search_items[_item_idx]; \
  180. switch (_item.opr_type) { \
  181. FOREACH_OPR_TYPE_WITH_STMT(_OPR_TYPE_CASE, _stmt) \
  182. default: \
  183. mgb_throw(MegBrainError, "unknown opr_type"); \
  184. } \
  185. }
  186. template <typename Opr>
  187. TensorLayoutArray to_layout_array(
  188. const typename opr::AlgoChooser<Opr>::FixedTensorLayouts& layouts) {
  189. TensorLayoutArray ret;
  190. for (auto&& layout : layouts) {
  191. ret.push_back(layout);
  192. }
  193. return ret;
  194. }
  195. template <typename Opr>
  196. typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts(
  197. const TensorLayoutArray& layouts) {
  198. typename opr::AlgoChooser<Opr>::FixedTensorLayouts ret;
  199. mgb_assert(ret.size() == layouts.size());
  200. size_t idx = 0;
  201. for (auto&& layout : layouts) {
  202. ret[idx++] = layout;
  203. }
  204. return ret;
  205. }
  206. /**
  207. * \brief flatten the search space in postorder traversal.
  208. * The sub-opr search constructs a search tree, e.g.
  209. *
  210. *             A
  211. *           /   \
  212. *      B1  B2    C
  213. *     /  \
  214. *   D1 D2 D3   E
  215. * We traverse the search tree in postorder:
  216. * D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A
  217. */
  218. template <typename Opr>
  219. std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
  220. const typename opr::AlgoChooser<Opr>::AlgoChooserHelper& helper,
  221. CircularDepsChecker& checker) {
  222. auto&& search_item = megdnn::Algorithm::SearchItem{
  223. OprTypeFromOprTrait<Opr>::opr_type, helper.param(),
  224. to_layout_array<Opr>(helper.layouts())};
  225. checker.put(search_item);
  226. std::vector<megdnn::Algorithm::SearchItem> ret;
  227. for (auto algo_info : helper.get_all_candidates()) {
  228. megdnn::Algorithm* algo =
  229. helper.get_algorithm_from_desc(algo_info.desc);
  230. mgb_assert(algo, "Unknown algo description");
  231. std::vector<megdnn::Algorithm::SearchItem>&& sub_items =
  232. algo->get_subopr_list(to_layout_array<Opr>(helper.layouts()),
  233. helper.megdnn_opr());
  234. FOREACH_OPR_TYPE_DISPATCH(sub_items, {
  235. auto&& megdnn_opr =
  236. opr::intl::create_megdnn_opr<_Opr>(helper.comp_node());
  237. megdnn_opr->param() =
  238. Algorithm::deserialize_read_pod<typename _Opr::Param>(
  239. _item.param);
  240. typename opr::AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
  241. to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
  242. _item.param, helper.mgb_opr(), helper.comp_node(),
  243. helper.execution_policy(),
  244. helper.allow_weight_preprocess());
  245. auto space = flatten_search_space<_Opr>(sub_helper, checker);
  246. ret.insert(ret.end(), space.begin(), space.end());
  247. });
  248. }
  249. ret.push_back(search_item);
  250. checker.remove(search_item);
  251. return ret;
  252. }
  253. //! serialize an algo's desc to a string. The format is
  254. //! handle_type|algo_type|size_of_param|size_of_name|string_of_param|string_of_name
  255. static void serialize_write_pod(const Algorithm::Info::Desc& val,
  256. std::string& result) {
  257. megdnn::Algorithm::serialize_write_pod(val.handle_type, result);
  258. megdnn::Algorithm::serialize_write_pod(val.type, result);
  259. uint32_t param_size = val.param.size();
  260. uint32_t name_size = val.name.size();
  261. megdnn::Algorithm::serialize_write_pod<uint32_t>(param_size, result);
  262. megdnn::Algorithm::serialize_write_pod<uint32_t>(name_size, result);
  263. result += val.param;
  264. result += val.name;
  265. }
  266. static Algorithm::Info::Desc deserialize_read_pod(const std::string& data,
  267. size_t offset = 0) {
  268. Algorithm::Info::Desc ret;
  269. #define cb(_val, _type) \
  270. _val = megdnn::Algorithm::deserialize_read_pod<_type>(data.data(), \
  271. offset); \
  272. offset += sizeof(_val)
  273. cb(ret.handle_type, megdnn::Handle::HandleType);
  274. cb(ret.type, uint32_t);
  275. uint32_t param_size = 0;
  276. uint32_t name_size = 0;
  277. cb(param_size, uint32_t);
  278. cb(name_size, uint32_t);
  279. if (param_size > 0) {
  280. ret.param = std::string(data.data() + offset, param_size);
  281. offset += param_size;
  282. }
  283. if (name_size > 0) {
  284. ret.name = std::string(data.data() + offset, name_size);
  285. offset += name_size;
  286. }
  287. return ret;
  288. }
  289. } // namespace
  290. namespace mgb {
  291. namespace opr {
  292. ///////////////////////////// AlgoChooserHelper //////////////////////////
  293. template <typename Opr>
  294. AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper(
  295. const FixedTensorLayouts& layouts, Opr* megdnn_opr,
  296. const std::string& param_str, const cg::OperatorNodeBase* mgb_opr,
  297. const CompNode& cn,
  298. const megdnn::param::ExecutionPolicy& execution_policy,
  299. bool allow_weight_preprocess)
  300. : m_layouts{layouts},
  301. m_dnn_opr{megdnn_opr},
  302. m_param{param_str},
  303. m_base_mgb_opr{mgb_opr},
  304. m_cn{cn},
  305. m_execution_policy{execution_policy},
  306. m_allow_weight_preprocess{allow_weight_preprocess} {
  307. mgb_assert(m_layouts.size() == layouts.size());
  308. static_assert(std::tuple_size<FixedTensorLayouts>::value == 3 ||
  309. std::tuple_size<FixedTensorLayouts>::value == 5 ||
  310. std::tuple_size<FixedTensorLayouts>::value == 8,
  311. "Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for "
  312. "deformable conv)");
  313. }
  314. template <typename Opr>
  315. typename AlgoChooser<Opr>::ImplExecutionPolicy
  316. AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic(
  317. const ExecutionStrategy& selected_strategy) const {
  318. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_heuristic")))
  319. ImplExecutionPolicy policy;
  320. auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
  321. owner_graph(), m_cn, m_execution_policy.workspace_limit);
  322. auto attr = extract_algo_attribute(selected_strategy);
  323. policy.algo =
  324. APPLY(m_dnn_opr->get_algorithm_info_heuristic(
  325. args..., workspace_limit, attr.first, attr.second),
  326. m_layouts)
  327. .desc;
  328. Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
  329. mgb_assert(algo, "Unknown algo description");
  330. std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
  331. to_layout_array<Opr>(m_layouts), m_dnn_opr);
  332. FOREACH_OPR_TYPE_DISPATCH(sub_items, {
  333. auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
  334. megdnn_opr->param() =
  335. Algorithm::deserialize_read_pod<typename _Opr::Param>(
  336. _item.param);
  337. typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
  338. to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
  339. _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
  340. m_allow_weight_preprocess);
  341. policy.sub_policy.push_back(
  342. sub_helper.choose_by_heuristic(selected_strategy));
  343. });
  344. return policy;
  345. MIDOUT_E
  346. }
  347. template <typename Opr>
  348. typename AlgoChooser<Opr>::ImplExecutionPolicy
  349. AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile(
  350. const ExecutionStrategy& selected_strategy, bool enable_update) const {
  351. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile")))
  352. if (owner_graph()->options().no_profiling_on_shape_change) {
  353. auto policy = m_dnn_opr->execution_policy();
  354. if (policy.algo.valid()) {
  355. return policy;
  356. }
  357. if (!algo_usable_on_shape_change<Opr>()) {
  358. mgb_log_warn(
  359. "choose algo by heuristic, which may cause performance "
  360. "regression.");
  361. return choose_by_heuristic(selected_strategy);
  362. }
  363. }
  364. typename AlgoChooser<Opr>::ImplExecutionPolicy tmp_policy;
  365. bool retrive_from_cache = true;
  366. bool allow_log = false;
  367. construct_execution_policy(selected_strategy, tmp_policy,
  368. retrive_from_cache, allow_log);
  369. if (tmp_policy.algo.valid()) {
  370. // return the policy when construction succeeded
  371. return tmp_policy;
  372. }
  373. if (enable_update) {
  374. CircularDepsChecker circular_deps_checker;
  375. auto&& search_items =
  376. flatten_search_space<Opr>(*this, circular_deps_checker);
  377. FOREACH_OPR_TYPE_DISPATCH(search_items, {
  378. auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
  379. megdnn_opr->param() =
  380. Algorithm::deserialize_read_pod<typename _Opr::Param>(
  381. _item.param);
  382. typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
  383. to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
  384. _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
  385. m_allow_weight_preprocess);
  386. sub_helper.profile(selected_strategy);
  387. });
  388. }
  389. typename AlgoChooser<Opr>::ImplExecutionPolicy policy;
  390. construct_execution_policy(selected_strategy, policy);
  391. return policy;
  392. MIDOUT_E
  393. }
  394. template <typename Opr>
  395. typename AlgoChooser<Opr>::ImplAlgoDesc
  396. AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
  397. const ExecutionStrategy& selected_strategy) const {
  398. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache")))
  399. AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
  400. typename Opr::Param origin_param = m_dnn_opr->param();
  401. AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(),
  402. &origin_param, sizeof(origin_param)};
  403. auto&& rst = cache.get(cache_key);
  404. if (!rst.valid())
  405. return {};
  406. auto&& prof = rst.val();
  407. if (prof.empty())
  408. return {};
  409. auto target_attr = extract_algo_attribute(selected_strategy);
  410. bool skip_by_negative = false;
  411. for (auto&& i : prof) {
  412. auto attr_of_algo =
  413. static_cast<megdnn::Algorithm::Attribute>(i.attribute);
  414. bool contain_attr_all_positive =
  415. (target_attr.first == (attr_of_algo & target_attr.first));
  416. bool contain_attr_any_negative =
  417. static_cast<bool>(attr_of_algo & target_attr.second);
  418. if (contain_attr_all_positive) {
  419. if (!contain_attr_any_negative) {
  420. Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
  421. return algo_desc;
  422. } else {
  423. skip_by_negative = true;
  424. }
  425. }
  426. }
  427. if (skip_by_negative) {
  428. mgb_log_error(
  429. "No usable algo. There are available algos match positive "
  430. "strategy(%s), but filtered by negative stategy(%s).",
  431. Algorithm::attribute_str(target_attr.first).c_str(),
  432. Algorithm::attribute_str(target_attr.second).c_str());
  433. } else {
  434. mgb_log_error(
  435. "No usable algo. algos read from cache could not satisfy "
  436. "positive strategy(%s)",
  437. Algorithm::attribute_str(target_attr.first).c_str());
  438. }
  439. mgb_trap();
  440. MIDOUT_E
  441. }
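//! construct_execution_policy (below) fills `policy` and, recursively, its
//! sub-policies either from the profile cache or from the heuristic; if any
//! sub-policy cannot be constructed, the whole policy is reset to invalid.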
  442. template <typename Opr>
  443. void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
  444. const ExecutionStrategy& selected_strategy,
  445. typename AlgoChooser<Opr>::ImplExecutionPolicy& policy,
  446. bool retrive_from_cache, bool allow_log) const {
  447. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy")))
  448. if (!policy.algo.valid()) {
  449. if (retrive_from_cache) {
  450. policy.algo = get_profile_result_from_cache(selected_strategy);
  451. if (!policy.algo.valid()) {
  452. if (allow_log) {
  453. auto target_attr =
  454. extract_algo_attribute(selected_strategy);
  455. std::string layouts_str = format_fixlayouts<Opr>(
  456. m_layouts, arity_in, arity_out);
  457. std::string msg = ssprintf(
  458. "(opr : %s, layouts %s, with attribute(%s) and "
  459. "without attribute(%s)",
  460. m_base_mgb_opr->dyn_typeinfo()->name,
  461. layouts_str.c_str(),
  462. Algorithm::attribute_str(target_attr.first).c_str(),
  463. Algorithm::attribute_str(target_attr.second)
  464. .c_str());
  465. mgb_log_warn(
  466. "No algo get from cache for %s. This may caused by "
  467. "mismatch with model and cache file or imcomplete "
  468. "cache file. ex. profiling with version1, but "
  469. "inferencing on version2 or profiling modelA but "
  470. "inferencing modelB",
  471. msg.c_str());
  472. }
  473. return;
  474. }
  475. } else {
  476. auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
  477. owner_graph(), m_cn, m_execution_policy.workspace_limit);
  478. auto attr = extract_algo_attribute(selected_strategy);
  479. policy.algo = APPLY(m_dnn_opr->get_algorithm_info_heuristic(
  480. args..., workspace_limit, attr.first,
  481. attr.second),
  482. m_layouts)
  483. .desc;
  484. mgb_assert(policy.algo.valid(),
  485. "No algo found from heuristic with strategy %u and "
  486. "workspace limit %zu",
  487. static_cast<uint32_t>(selected_strategy),
  488. workspace_limit);
  489. }
  490. }
  491. Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
  492. mgb_assert(algo, "Unknown algo description");
  493. std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list(
  494. to_layout_array<Opr>(m_layouts), m_dnn_opr);
  495. FOREACH_OPR_TYPE_DISPATCH(sub_items, {
  496. auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn);
  497. megdnn_opr->param() =
  498. Algorithm::deserialize_read_pod<typename _Opr::Param>(
  499. _item.param);
  500. typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper(
  501. to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(),
  502. _item.param, m_base_mgb_opr, m_cn, m_execution_policy,
  503. m_allow_weight_preprocess);
  504. policy.sub_policy.push_back({});
  505. sub_helper.construct_execution_policy(selected_strategy,
  506. policy.sub_policy.back(),
  507. retrive_from_cache, allow_log);
  508. if (!policy.sub_policy.back().algo.valid()) {
  509. // means sub_helper.construct_execution_policy fails. clean up
  510. // policy.algo and return
  511. policy = {};
  512. return;
  513. }
  514. });
  515. MIDOUT_E
  516. }
  517. template <typename Opr>
  518. size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes(
  519. const ImplExecutionPolicy& policy) const {
  520. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes")))
  521. m_dnn_opr->execution_policy() = policy;
  522. size_t result;
  523. if_constexpr<opr_supports_preprocess<Opr>()>(
  524. [&](auto _) {
  525. auto&& opr = _(m_dnn_opr);
  526. auto prep = this->construct_fake_preprocess_filter();
  527. PreprocessFilter<Opr>* prep_ptr =
  528. prep.valid() ? &prep.val() : nullptr;
  529. result = std::max(
  530. APPLY(opr->get_preprocess_workspace_in_bytes(args...),
  531. m_layouts),
  532. APPLY(opr->get_workspace_in_bytes(args..., prep_ptr),
  533. m_layouts));
  534. },
  535. /* else */
  536. [&](auto _) {
  537. result = APPLY(_(m_dnn_opr)->get_workspace_in_bytes(args...),
  538. m_layouts);
  539. });
  540. return result;
  541. MIDOUT_E
  542. }
  543. template <typename Opr>
  544. std::vector<typename AlgoChooser<Opr>::ImplAlgo>
  545. AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const {
  546. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates")))
  547. auto heu = choose_by_heuristic(m_execution_policy.strategy);
  548. auto&& ret =
  549. APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_layouts);
  550. bool found = false;
  551. for (size_t i = 0; i < ret.size(); ++i) {
  552. if (ret[i].desc == heu.algo) {
  553. found = true;
  554. std::swap(ret[i], ret[0]);
  555. break;
  556. }
  557. }
  558. Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo);
  559. mgb_assert(palgo, "Unknown algo description");
  560. mgb_assert(found,
  561. "algo %s got by heuristic not found in "
  562. "candidate list",
  563. palgo->name());
  564. return std::move(ret);
  565. MIDOUT_E
  566. }
  567. template <typename Opr>
  568. Maybe<AlgoChooserProfileCache::ResultEntry>
  569. AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo(
  570. const ImplExecutionPolicy& policy, double& timeout) const {
  571. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo")))
  572. typename TimedProfiler<Opr>::Param param;
  573. // force check copy size <= dest len-1 from gcc8 for safe
  574. param.execution_policy =
  575. TimedProfiler<Opr>::Param::ExecutionPolicyBlob::serialize(policy);
  576. param.workspace = get_workspace_size_bytes(policy);
  577. for (int i = 0; i < arity; ++i) {
  578. auto&& src = m_layouts[i];
  579. mgb_assert(src.format.is_default() &&
  580. (src.dtype.category() == DTypeCategory::FLOAT ||
  581. src.dtype.category() == DTypeCategory::INT ||
  582. src.dtype.category() == DTypeCategory::QUANTIZED),
  583. "unsupported layout in profiling: %s",
  584. src.to_string().c_str());
  585. param.dtypes[i] = src.dtype.enumv();
  586. }
  587. param.comp_node_loc = m_cn.locator();
  588. mgb_assert(param.shapes.size() == m_layouts.size());
  589. for (size_t i = 0; i < param.shapes.size(); ++i)
  590. param.shapes[i] = m_layouts[i];
  591. param.opr_param = m_dnn_opr->param();
  592. param.allow_weight_preprocess = m_allow_weight_preprocess;
  593. Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
  594. mgb_assert(palgo, "cannot find algo when profiling a single algo");
  595. auto rst = TimedProfiler<Opr>::profile(param, timeout);
  596. // MIOpen conv profiles all available algos when a specific shape is
  597. // provided for the first time, which probably adds to the result time.
  598. // Therefore, a second profile execution is needed.
  599. if (strncmp(palgo->name(), "MIOpen", 6) == 0) {
  600. rst = TimedProfiler<Opr>::profile(param, timeout);
  601. }
  602. if (!rst.valid())
  603. return None;
  604. std::string algo_desc;
  605. serialize_write_pod(policy.algo, algo_desc);
  606. return AlgoChooserProfileCache::ResultEntry{
  607. algo_desc, static_cast<uint32_t>(palgo->attribute()),
  608. rst.val().time, param.workspace};
  609. MIDOUT_E
  610. }
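//! profile (below) measures every candidate algorithm that satisfies the strategy
//! and the workspace limit, then stores the results in the per-opr profile cache.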
  611. template <typename Opr>
  612. void AlgoChooser<Opr>::AlgoChooserHelper::profile(
  613. const ExecutionStrategy& selected_strategy) const {
  614. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile")))
  615. if (get_profile_result_from_cache(selected_strategy).valid())
  616. return;
  617. AlgoChooserProfileCache::Result prof_rst;
  618. auto target_attr = extract_algo_attribute(selected_strategy);
  619. std::string layouts_str =
  620. format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
  621. double cur_timeout = 0;
  622. auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
  623. owner_graph(), m_cn, m_execution_policy.workspace_limit);
  624. RealTimer timer;
  625. for (auto algo : get_all_candidates()) {
  626. Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;
  627. ImplExecutionPolicy policy;
  628. policy.algo = algo.desc;
  629. //! skip algos that contain any attribute excluded by the strategy
  630. auto palgo = m_dnn_opr->get_algorithm_from_desc(policy.algo);
  631. if (palgo->contain_attribute_any(target_attr.second)) {
  632. mgb_log_debug(
  633. "skip algo %s, which matches the profile strategy required "
  634. "'not contain attribute(%s).'",
  635. algo.desc.name.c_str(),
  636. Algorithm::attribute_str(target_attr.second).c_str());
  637. continue;
  638. }
  639. //! check workspace limit
  640. construct_execution_policy(selected_strategy, policy);
  641. mgb_assert(policy.algo.valid(),
  642. "construct execution policy must success when profiling");
  643. if (get_workspace_size_bytes(policy) > workspace_limit) {
  644. continue;
  645. }
  646. std::string msg = ssprintf("profiling %s algorithm %s %s",
  647. m_base_mgb_opr->dyn_typeinfo()->name,
  648. algo.desc.name.c_str(), layouts_str.c_str());
  649. timer.reset();
  650. MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); }
  651. MGB_CATCH(std::exception & exc, {
  652. mgb_log_warn("caught exception during %s: %s", msg.c_str(),
  653. exc.what());
  654. continue;
  655. })
  656. MGB_CATCH(..., {
  657. mgb_log_warn("caught exception during %s", msg.c_str());
  658. continue;
  659. })
  660. if (!cur_rst.valid()) {
  661. mgb_log_warn("timeout when %s; timeout setting: %.3fsec",
  662. msg.c_str(), cur_timeout);
  663. continue;
  664. }
  665. if (!cur_timeout) {
  666. cur_timeout = timer.get_secs() + TIMEOUT_TOLERANCE;
  667. } else {
  668. cur_timeout =
  669. std::min(cur_timeout, timer.get_secs() + TIMEOUT_TOLERANCE);
  670. }
  671. auto&& rst = cur_rst.val();
  672. mgb_log_debug("%s: workspace: %zu; time: %.3gsec", msg.c_str(),
  673. rst.workspace, rst.time);
  674. prof_rst.push_back(rst);
  675. }
  676. std::string msg = ssprintf(
  677. "no usable %s algorithm %s without attribute(%s) or could not meet "
  678. "workspace limite requirement(%zu)",
  679. m_base_mgb_opr->dyn_typeinfo()->name, layouts_str.c_str(),
  680. Algorithm::attribute_str(target_attr.second).c_str(),
  681. workspace_limit);
  682. mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
  683. FixedTensorLayouts origin_layouts = m_layouts;
  684. typename Opr::Param origin_param = m_dnn_opr->param();
  685. AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
  686. origin_layouts.size(), &origin_param,
  687. sizeof(origin_param)};
  688. AlgoChooserProfileCache cache(m_cn, profile_name(m_dnn_opr).c_str());
  689. cache.put(cache_key, prof_rst);
  690. MIDOUT_E
  691. }
  692. template <typename Opr>
  693. Maybe<PreprocessFilter<Opr>>
  694. AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const {
  695. MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_fake_preprocess_filter")))
  696. Maybe<PreprocessFilter<Opr>> result = None;
  697. if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) {
  698. if (!m_allow_weight_preprocess)
  699. return;
  700. auto opr = _(m_dnn_opr);
  701. auto layouts = APPLY(opr->deduce_preprocessed_filter_layout(args...),
  702. m_layouts);
  703. //! no preprocessed filter layouts means weight preprocessing is not needed
  704. if (layouts.empty()) {
  705. return;
  706. }
  707. //! all layouts being empty also means weight preprocessing is not needed
  708. bool layout_valid = false;
  709. for (auto&& layout : layouts) {
  710. if (!layout.is_empty()) {
  711. layout_valid = true;
  712. }
  713. }
  714. if (!layout_valid) {
  715. return;
  716. }
  717. result = PreprocessFilter<Opr>{};
  718. auto& res = result.val();
  719. res.algorithm_id = nullptr;
  720. res.tensors.resize(layouts.size());
  721. for (size_t i = 0; i < layouts.size(); i++) {
  722. res.tensors[i] = megdnn::TensorND(nullptr, layouts[i]);
  723. }
  724. });
  725. return result;
  726. MIDOUT_E
  727. }
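//! extract_algo_attribute (below) maps an ExecutionStrategy to a pair of algo
//! attributes: `first` holds attributes a candidate must contain, `second` holds
//! attributes it must not contain.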
  728. template <typename Opr>
  729. std::pair<AlgoAttribute, AlgoAttribute>
  730. AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute(
  731. const ExecutionStrategy& strategy) const {
  732. std::pair<AlgoAttribute, AlgoAttribute> ret =
  733. std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
  734. //! from strategy
  735. if (strategy & ExecutionStrategy::REPRODUCIBLE) {
  736. ret.first |= AlgoAttribute::REPRODUCIBLE;
  737. }
  738. if (strategy & ExecutionStrategy::OPTMIZED) {
  739. ret.second |= AlgoAttribute::NAIVE;
  740. }
  741. return ret;
  742. }
  743. #define INST(Opr) \
  744. template AlgoChooser<megdnn::Opr>::AlgoChooserHelper::AlgoChooserHelper( \
  745. const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
  746. const std::string& param_str, const cg::OperatorNodeBase* mgb_opr, \
  747. const CompNode& cn, \
  748. const megdnn::param::ExecutionPolicy& execution_policy, \
  749. bool allow_weight_preprocess); \
  750. template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
  751. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_heuristic( \
  752. const ExecutionStrategy& select_strategy) const; \
  753. template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
  754. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \
  755. const ExecutionStrategy& select_strategy, bool enable_update) \
  756. const; \
  757. template typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc \
  758. AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \
  759. get_profile_result_from_cache( \
  760. const ExecutionStrategy& select_strategy) const; \
  761. template void \
  762. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::construct_execution_policy( \
  763. const ExecutionStrategy& select_strategy, \
  764. typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \
  765. bool retrive_from_cache, bool allow_log) const; \
  766. template size_t \
  767. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_workspace_size_bytes( \
  768. const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
  769. policy) const; \
  770. template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \
  771. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_all_candidates() const; \
  772. template Maybe<AlgoChooserProfileCache::ResultEntry> \
  773. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile_single_algo( \
  774. const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
  775. policy, \
  776. double& timeout) const; \
  777. template std::pair<AlgoAttribute, AlgoAttribute> \
  778. AlgoChooser<megdnn::Opr>::AlgoChooserHelper::extract_algo_attribute( \
  779. const ExecutionStrategy& strategy) const; \
  780. template void AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile( \
  781. const ExecutionStrategy& selected_strategy) const;
  782. MGB_FOREACH_FASTRUN_OPR(INST)
  783. #undef INST
  784. //////////////////////////////// AlgoChooser /////////////////////////////
  785. template <typename Opr>
  786. typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy(
  787. const AlgoChooserHelper& helper) {
  788. auto opr_strategy = helper.execution_policy().strategy;
  789. if (opr_strategy & ExecutionStrategy::HEURISTIC) {
  790. if (opr_strategy & ExecutionStrategy::PROFILE) {
  791. //! this strategy chooses from the cache first, then falls back to the
  792. //! heuristic if that fails.
  793. ImplExecutionPolicy policy =
  794. helper.choose_by_profile(opr_strategy, false);
  795. if (!policy.algo.valid()) {
  796. policy = helper.choose_by_heuristic(opr_strategy);
  797. }
  798. return policy;
  799. } else {
  800. return helper.choose_by_heuristic(opr_strategy);
  801. }
  802. }
  803. #if MGB_ENABLE_FASTRUN
  804. else if (opr_strategy & ExecutionStrategy::PROFILE) {
  805. return helper.choose_by_profile(opr_strategy, true);
  806. }
  807. #endif
  808. else {
  809. mgb_throw(GraphError, "bad ExecutionPolicy strategy");
  810. }
  811. }
  812. template <typename Opr>
  813. size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts,
  814. Opr* megdnn_opr, const MGBOpr* mgb_opr,
  815. bool allow_weight_preprocess) {
  816. if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) {
  817. return 0;
  818. }
  819. std::string param_str;
  820. Algorithm::serialize_write_pod(megdnn_opr->param(), param_str);
  821. AlgoChooserHelper helper(layouts, megdnn_opr, param_str, mgb_opr,
  822. mgb_opr->comp_node(), mgb_opr->execution_policy(),
  823. allow_weight_preprocess);
  824. ImplExecutionPolicy policy;
  825. if (auto algo_choose_hook = mgb_opr->algo_chooser()) {
  826. policy = algo_choose_hook(mgb_opr);
  827. auto strategy =
  828. ExecutionStrategy::HEURISTIC | ExecutionStrategy::REPRODUCIBLE;
  829. bool retrive_from_cache = false;
  830. helper.construct_execution_policy(strategy, policy, retrive_from_cache);
  831. }
  832. if (!policy.algo.valid()) {
  833. policy = get_policy(helper);
  834. }
  835. size_t workspace = helper.get_workspace_size_bytes(policy);
  836. std::string ret;
  837. ret.append(mgb_opr->dyn_typeinfo()->name);
  838. ret += format_fixlayouts<Opr>(layouts, arity_in, arity_out);
  839. Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo);
  840. mgb_assert(palgo, "Unknown algo description");
  841. ret.append("): algo=" + std::string(palgo->name()));
  842. ret.append(ssprintf(" workspace=%.2fMiB attribute=%d",
  843. workspace / (1024 * 1024.0),
  844. static_cast<uint32_t>(palgo->attribute())));
  845. mgb_log_debug("%s", ret.c_str());
  846. megdnn_opr->execution_policy() = policy;
  847. return workspace;
  848. }
  849. #define INST(Opr) \
  850. template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \
  851. AlgoChooser<megdnn::Opr>::get_policy(const AlgoChooserHelper& proxy); \
  852. template size_t AlgoChooser<megdnn::Opr>::setup_algo( \
  853. const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
  854. const MGBOpr* mgb_opr, bool allow_weight_preprocess);
  855. MGB_FOREACH_FASTRUN_OPR(INST)
  856. #undef INST
  857. } // namespace opr
  858. } // namespace mgb
  859. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no need to distinguish between CPU and GPU builds. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.