You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

fastrun_options.cpp 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. #include <gflags/gflags.h>
  2. #if defined(_WIN32)
  3. #include <io.h>
  4. #define F_OK 0
  5. #define access(a, b) _access(a, b)
  6. #elif __linux__ || __unix__ || __APPLE__
  7. #include <unistd.h>
  8. #endif
  9. #include "fastrun_options.h"
  10. #include "megbrain/gopt/inference.h"
  11. #include "megbrain/utils/infile_persistent_cache.h"
  12. #include "misc.h"
  13. #include "models/model_lite.h"
  14. #include "models/model_mdl.h"
  15. namespace lar {
  16. template <>
  17. void FastRunOption::config_model_internel<ModelLite>(
  18. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  19. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  20. //! set the algo policy before model load
  21. using Strategy = ModelLite::Strategy;
  22. uint32_t strategy = 0;
  23. #if MGB_ENABLE_FASTRUN
  24. if (enable_full_run) {
  25. LITE_LOG("enable full-run strategy for algo profile");
  26. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy;
  27. } else if (enable_fast_run) {
  28. LITE_LOG("enable fast-run strategy for algo profile");
  29. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
  30. static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
  31. } else {
  32. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  33. }
  34. #else
  35. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  36. #endif
  37. if (batch_binary_equal || enable_reproducible) {
  38. LITE_LOG("enable reproducible strategy for algo profile");
  39. if (batch_binary_equal)
  40. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) |
  41. strategy;
  42. }
  43. auto lite_strategy = static_cast<Strategy>(strategy);
  44. model->set_lite_strategy(lite_strategy);
  45. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  46. auto&& lite_network = model->get_lite_network();
  47. auto&& lite_strategy = model->get_lite_strategy();
  48. //! set algo policy for model
  49. lite::Runtime::set_network_algo_policy(
  50. lite_network, lite_strategy, share_batch_size, batch_binary_equal);
  51. if (!m_fast_run_cache.empty()) {
  52. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  53. lite::set_persistent_cache(m_fast_run_cache);
  54. } else {
  55. lite::set_persistent_cache(m_fast_run_cache, true);
  56. }
  57. //! TODO:this is from mdl model settings but not matched settings in
  58. //! lite model
  59. // if (!enable_full_run && !enable_fast_run)
  60. // mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
  61. }
  62. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  63. #if MGB_ENABLE_FASTRUN
  64. //! dump algo cache
  65. if (!m_fast_run_cache.empty()) {
  66. lite::dump_persistent_cache(m_fast_run_cache);
  67. }
  68. #endif
  69. }
  70. }
  71. template <>
  72. void FastRunOption::config_model_internel<ModelMdl>(
  73. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  74. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  75. //! set the algo policy before model load
  76. using Strategy = ModelMdl::Strategy;
  77. auto strategy = static_cast<Strategy>(0);
  78. #if MGB_ENABLE_FASTRUN
  79. if (enable_full_run) {
  80. mgb_log("enable full-run strategy for algo profile");
  81. strategy = Strategy::PROFILE | strategy;
  82. } else if (enable_fast_run) {
  83. mgb_log("enable fast-run strategy for algo profile");
  84. strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy;
  85. } else {
  86. strategy = Strategy::HEURISTIC | strategy;
  87. }
  88. #else
  89. strategy = Strategy::HEURISTIC | strategy;
  90. #endif
  91. if (batch_binary_equal || enable_reproducible) {
  92. mgb_log("enable reproducible strategy for algo profile");
  93. strategy = Strategy::REPRODUCIBLE | strategy;
  94. }
  95. model->set_mdl_strategy(strategy);
  96. //! set binary_equal_between_batch and shared_batch_size
  97. if (batch_binary_equal) {
  98. mgb_log("enable batch binary equal");
  99. model->get_mdl_config()
  100. .comp_graph->options()
  101. .fast_run_config.binary_equal_between_batch = true;
  102. }
  103. if (share_batch_size > 0) {
  104. mgb_log("set shared shared batch");
  105. model->get_mdl_config()
  106. .comp_graph->options()
  107. .fast_run_config.shared_batch_size = share_batch_size;
  108. }
  109. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  110. auto& vars = model->get_mdl_load_result().output_var_list;
  111. auto&& strategy = model->get_mdl_strategy();
  112. mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
  113. // set algo cache path
  114. if (!m_fast_run_cache.empty()) {
  115. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  116. mgb::PersistentCache::set_impl(
  117. std::make_shared<mgb::InFilePersistentCache>(
  118. m_fast_run_cache.c_str()));
  119. } else {
  120. mgb::PersistentCache::set_impl(
  121. std::make_shared<mgb::InFilePersistentCache>());
  122. }
  123. #if MGB_ENABLE_FASTRUN
  124. if (!enable_full_run && !enable_fast_run)
  125. #endif
  126. mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
  127. }
  128. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  129. #if MGB_ENABLE_FASTRUN
  130. //! dump algo cache
  131. if (!m_fast_run_cache.empty()) {
  132. static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst())
  133. .dump_cache(m_fast_run_cache.c_str());
  134. }
  135. #endif
  136. }
  137. }
  138. } // namespace lar
  139. using namespace lar;
  140. bool FastRunOption::m_valid;
  141. void FastRunOption::update() {
  142. m_option_name = "fastrun";
  143. #if MGB_ENABLE_FASTRUN
  144. enable_fast_run = FLAGS_fast_run;
  145. enable_full_run = FLAGS_full_run;
  146. #endif
  147. batch_binary_equal = FLAGS_binary_equal_between_batch;
  148. enable_reproducible = FLAGS_reproducible;
  149. m_fast_run_cache = FLAGS_fast_run_algo_policy;
  150. share_batch_size = FLAGS_fast_run_shared_batch_size;
  151. m_option = {
  152. #if MGB_ENABLE_FASTRUN
  153. {"fast_run", lar::Bool::make(false)},
  154. {"full_run", lar::Bool::make(false)},
  155. #endif
  156. {"binary_equal_between_batch", lar::Bool::make(false)},
  157. {"reproducible", lar::Bool::make(false)}
  158. };
  159. #if MGB_ENABLE_FASTRUN
  160. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])
  161. ->set_value(FLAGS_fast_run);
  162. std::static_pointer_cast<lar::Bool>(m_option["full_run"])
  163. ->set_value(FLAGS_full_run);
  164. #endif
  165. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  166. ->set_value(FLAGS_binary_equal_between_batch);
  167. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])
  168. ->set_value(FLAGS_reproducible);
  169. #if MGB_ENABLE_FASTRUN
  170. //! while fastrun cache file path is not empty and can't be accessed
  171. if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) {
  172. mgb_assert(
  173. enable_full_run || enable_fast_run,
  174. "--fast-run or --full-run should be enabled");
  175. }
  176. if (share_batch_size) {
  177. mgb_assert(
  178. enable_full_run || enable_fast_run || !m_fast_run_cache.empty(),
  179. "--fast-run-shared-batch-size should be used with "
  180. "--fast-run|--full-run|--fast-run-algo-policy");
  181. }
  182. #endif
  183. }
  184. bool FastRunOption::is_valid() {
  185. bool ret = false;
  186. #if MGB_ENABLE_FASTRUN
  187. ret = ret || FLAGS_fast_run;
  188. ret = ret || FLAGS_full_run;
  189. #endif
  190. ret = ret || FLAGS_binary_equal_between_batch;
  191. ret = ret || FLAGS_fast_run_shared_batch_size > 0;
  192. ret = ret || FLAGS_reproducible;
  193. ret = ret || FLAGS_fast_run_algo_policy.size() > 0;
  194. return ret || m_valid;
  195. }
  196. std::shared_ptr<OptionBase> FastRunOption::create_option() {
  197. static std::shared_ptr<FastRunOption> option(new FastRunOption);
  198. if (FastRunOption::is_valid()) {
  199. option->update();
  200. return std::static_pointer_cast<OptionBase>(option);
  201. } else {
  202. return nullptr;
  203. }
  204. }
  205. void FastRunOption::config_model(
  206. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  207. #if MGB_ENABLE_FASTRUN
  208. enable_fast_run =
  209. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value();
  210. enable_full_run =
  211. std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value();
  212. mgb_throw_if(
  213. enable_fast_run && enable_full_run, mgb::AssertionError,
  214. "invalid options of both fast-run and full-run");
  215. #endif
  216. batch_binary_equal =
  217. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  218. ->get_value();
  219. enable_reproducible =
  220. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value();
  221. CONFIG_MODEL_FUN;
  222. }
  223. #if MGB_ENABLE_FASTRUN
  224. DEFINE_bool(fast_run, false, "whether to use fast-run in model run");
  225. DEFINE_bool(full_run, false, "whether to use full-run in model run");
  226. #endif
  227. DEFINE_bool(
  228. binary_equal_between_batch, false,
  229. "Each batch of output is promised binary equal if each batch of "
  230. "input is binary equal\n Note that if this option is turned on, "
  231. "`--reproducible` will also be turned on.");
  232. DEFINE_bool(
  233. reproducible, false,
  234. "Enable choose algo which is reproducible. It mainly used for "
  235. "cudnn algos.See "
  236. "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
  237. "index.html#reproducibility"
  238. "for more details.");
  239. DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
  240. DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
  241. REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
  242. REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);