You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fastrun_options.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. #include <gflags/gflags.h>
  2. #if defined(_WIN32)
  3. #include <io.h>
  4. #define F_OK 0
  5. #define access(a, b) _access(a, b)
  6. #elif __linux__ || __unix__ || __APPLE__
  7. #include <unistd.h>
  8. #endif
  9. #include "fastrun_options.h"
  10. #include "megbrain/gopt/inference.h"
  11. #include "megbrain/utils/infile_persistent_cache.h"
  12. #include "misc.h"
  13. #include "models/model_lite.h"
  14. #include "models/model_mdl.h"
  15. namespace lar {
  16. template <>
  17. void FastRunOption::config_model_internel<ModelLite>(
  18. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  19. if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
  20. //! set the algo policy before model load
  21. using Strategy = ModelLite::Strategy;
  22. uint32_t strategy = 0;
  23. #if MGB_ENABLE_FASTRUN
  24. if (enable_full_run) {
  25. LITE_LOG("enable full-run strategy for algo profile");
  26. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy;
  27. } else if (enable_fast_run) {
  28. LITE_LOG("enable fast-run strategy for algo profile");
  29. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
  30. static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
  31. } else if ((!m_fast_run_cache.empty() &&
  32. !access(m_fast_run_cache.c_str(), F_OK))) {
  33. LITE_LOG(
  34. "detect fast-run cache usable set LITE_ALGO_PROFILE for algo "
  35. "profile");
  36. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
  37. static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  38. } else {
  39. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  40. }
  41. #else
  42. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  43. #endif
  44. if (batch_binary_equal || enable_reproducible) {
  45. LITE_LOG("enable reproducible strategy for algo profile");
  46. if (batch_binary_equal)
  47. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) |
  48. strategy;
  49. }
  50. auto lite_strategy = static_cast<Strategy>(strategy);
  51. //! set algo policy for model
  52. auto&& lite_network = model->get_lite_network();
  53. lite::Runtime::set_network_algo_policy(
  54. lite_network, lite_strategy, share_batch_size, batch_binary_equal);
  55. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  56. if (!m_fast_run_cache.empty()) {
  57. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  58. lite::set_persistent_cache(m_fast_run_cache);
  59. } else {
  60. lite::set_persistent_cache(m_fast_run_cache, true);
  61. }
  62. }
  63. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  64. #if MGB_ENABLE_FASTRUN
  65. //! dump algo cache
  66. if (!m_fast_run_cache.empty()) {
  67. lite::dump_persistent_cache(m_fast_run_cache);
  68. }
  69. #endif
  70. }
  71. }
  72. template <>
  73. void FastRunOption::config_model_internel<ModelMdl>(
  74. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  75. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  76. //! set the algo policy before model load
  77. using Strategy = ModelMdl::Strategy;
  78. auto strategy = static_cast<Strategy>(0);
  79. #if MGB_ENABLE_FASTRUN
  80. if (enable_full_run) {
  81. mgb_log("enable full-run strategy for algo profile");
  82. strategy = Strategy::PROFILE | strategy;
  83. } else if (enable_fast_run) {
  84. mgb_log("enable fast-run strategy for algo profile");
  85. strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy;
  86. } else {
  87. strategy = Strategy::HEURISTIC | strategy;
  88. }
  89. #else
  90. strategy = Strategy::HEURISTIC | strategy;
  91. #endif
  92. if (batch_binary_equal || enable_reproducible) {
  93. mgb_log("enable reproducible strategy for algo profile");
  94. strategy = Strategy::REPRODUCIBLE | strategy;
  95. }
  96. model->set_mdl_strategy(strategy);
  97. //! set binary_equal_between_batch and shared_batch_size
  98. if (batch_binary_equal) {
  99. mgb_log("enable batch binary equal");
  100. model->get_mdl_config()
  101. .comp_graph->options()
  102. .fast_run_config.binary_equal_between_batch = true;
  103. }
  104. if (share_batch_size > 0) {
  105. mgb_log("set shared shared batch");
  106. model->get_mdl_config()
  107. .comp_graph->options()
  108. .fast_run_config.shared_batch_size = share_batch_size;
  109. }
  110. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  111. auto& vars = model->get_mdl_load_result().output_var_list;
  112. auto&& strategy = model->get_mdl_strategy();
  113. mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
  114. // set algo cache path
  115. if (!m_fast_run_cache.empty()) {
  116. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  117. mgb::PersistentCache::set_impl(
  118. std::make_shared<mgb::InFilePersistentCache>(
  119. m_fast_run_cache.c_str()));
  120. } else {
  121. mgb::PersistentCache::set_impl(
  122. std::make_shared<mgb::InFilePersistentCache>());
  123. }
  124. #if MGB_ENABLE_FASTRUN
  125. if (!enable_full_run && !enable_fast_run)
  126. #endif
  127. mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
  128. }
  129. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  130. #if MGB_ENABLE_FASTRUN
  131. //! dump algo cache
  132. if (!m_fast_run_cache.empty()) {
  133. static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst())
  134. .dump_cache(m_fast_run_cache.c_str());
  135. }
  136. #endif
  137. }
  138. }
  139. } // namespace lar
  140. using namespace lar;
  141. bool FastRunOption::m_valid;
  142. void FastRunOption::update() {
  143. m_option_name = "fastrun";
  144. #if MGB_ENABLE_FASTRUN
  145. enable_fast_run = FLAGS_fast_run;
  146. enable_full_run = FLAGS_full_run;
  147. #endif
  148. batch_binary_equal = FLAGS_binary_equal_between_batch;
  149. enable_reproducible = FLAGS_reproducible;
  150. m_fast_run_cache = FLAGS_fast_run_algo_policy;
  151. share_batch_size = FLAGS_fast_run_shared_batch_size;
  152. m_option = {
  153. #if MGB_ENABLE_FASTRUN
  154. {"fast_run", lar::Bool::make(false)},
  155. {"full_run", lar::Bool::make(false)},
  156. #endif
  157. {"binary_equal_between_batch", lar::Bool::make(false)},
  158. {"reproducible", lar::Bool::make(false)}
  159. };
  160. #if MGB_ENABLE_FASTRUN
  161. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])
  162. ->set_value(FLAGS_fast_run);
  163. std::static_pointer_cast<lar::Bool>(m_option["full_run"])
  164. ->set_value(FLAGS_full_run);
  165. #endif
  166. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  167. ->set_value(FLAGS_binary_equal_between_batch);
  168. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])
  169. ->set_value(FLAGS_reproducible);
  170. #if MGB_ENABLE_FASTRUN
  171. //! while fastrun cache file path is not empty and can't be accessed
  172. if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) {
  173. mgb_assert(
  174. enable_full_run || enable_fast_run,
  175. "--fast-run or --full-run should be enabled");
  176. }
  177. if (share_batch_size) {
  178. mgb_assert(
  179. enable_full_run || enable_fast_run || !m_fast_run_cache.empty(),
  180. "--fast-run-shared-batch-size should be used with "
  181. "--fast-run|--full-run|--fast-run-algo-policy");
  182. }
  183. #endif
  184. }
  185. bool FastRunOption::is_valid() {
  186. bool ret = false;
  187. #if MGB_ENABLE_FASTRUN
  188. ret = ret || FLAGS_fast_run;
  189. ret = ret || FLAGS_full_run;
  190. #endif
  191. ret = ret || FLAGS_binary_equal_between_batch;
  192. ret = ret || FLAGS_fast_run_shared_batch_size > 0;
  193. ret = ret || FLAGS_reproducible;
  194. ret = ret || FLAGS_fast_run_algo_policy.size() > 0;
  195. return ret || m_valid;
  196. }
  197. std::shared_ptr<OptionBase> FastRunOption::create_option() {
  198. static std::shared_ptr<FastRunOption> option(new FastRunOption);
  199. if (FastRunOption::is_valid()) {
  200. option->update();
  201. return std::static_pointer_cast<OptionBase>(option);
  202. } else {
  203. return nullptr;
  204. }
  205. }
  206. void FastRunOption::config_model(
  207. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  208. #if MGB_ENABLE_FASTRUN
  209. enable_fast_run =
  210. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value();
  211. enable_full_run =
  212. std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value();
  213. mgb_throw_if(
  214. enable_fast_run && enable_full_run, mgb::AssertionError,
  215. "invalid options of both fast-run and full-run");
  216. #endif
  217. batch_binary_equal =
  218. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  219. ->get_value();
  220. enable_reproducible =
  221. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value();
  222. CONFIG_MODEL_FUN;
  223. }
  224. #if MGB_ENABLE_FASTRUN
  225. DEFINE_bool(fast_run, false, "whether to use fast-run in model run");
  226. DEFINE_bool(full_run, false, "whether to use full-run in model run");
  227. #endif
  228. DEFINE_bool(
  229. binary_equal_between_batch, false,
  230. "Each batch of output is promised binary equal if each batch of "
  231. "input is binary equal\n Note that if this option is turned on, "
  232. "`--reproducible` will also be turned on.");
  233. DEFINE_bool(
  234. reproducible, false,
  235. "Enable choose algo which is reproducible. It mainly used for "
  236. "cudnn algos.See "
  237. "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
  238. "index.html#reproducibility"
  239. "for more details.");
  240. DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
  241. DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
  242. REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
  243. REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);