You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fastrun_options.cpp 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. #include <gflags/gflags.h>
  2. #if defined(_WIN32)
  3. #include <io.h>
  4. #define F_OK 0
  5. #define access(a, b) _access(a, b)
  6. #elif __linux__ || __unix__ || __APPLE__
  7. #include <unistd.h>
  8. #endif
  9. #include "fastrun_options.h"
  10. #include "megbrain/gopt/inference.h"
  11. #include "megbrain/utils/infile_persistent_cache.h"
  12. #include "misc.h"
  13. #include "models/model_lite.h"
  14. #include "models/model_mdl.h"
  15. namespace lar {
  16. template <>
  17. void FastRunOption::config_model_internel<ModelLite>(
  18. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  19. if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
  20. //! set the algo policy before model load
  21. using Strategy = ModelLite::Strategy;
  22. uint32_t strategy = 0;
  23. #if MGB_ENABLE_FASTRUN
  24. if (enable_full_run) {
  25. LITE_LOG("enable full-run strategy for algo profile");
  26. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy;
  27. } else if (enable_fast_run) {
  28. LITE_LOG("enable fast-run strategy for algo profile");
  29. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
  30. static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
  31. } else {
  32. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  33. }
  34. #else
  35. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
  36. #endif
  37. if (batch_binary_equal || enable_reproducible) {
  38. LITE_LOG("enable reproducible strategy for algo profile");
  39. if (batch_binary_equal)
  40. strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) |
  41. strategy;
  42. }
  43. auto lite_strategy = static_cast<Strategy>(strategy);
  44. //! set algo policy for model
  45. auto&& lite_network = model->get_lite_network();
  46. lite::Runtime::set_network_algo_policy(
  47. lite_network, lite_strategy, share_batch_size, batch_binary_equal);
  48. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  49. if (!m_fast_run_cache.empty()) {
  50. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  51. lite::set_persistent_cache(m_fast_run_cache);
  52. } else {
  53. lite::set_persistent_cache(m_fast_run_cache, true);
  54. }
  55. }
  56. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  57. #if MGB_ENABLE_FASTRUN
  58. //! dump algo cache
  59. if (!m_fast_run_cache.empty()) {
  60. lite::dump_persistent_cache(m_fast_run_cache);
  61. }
  62. #endif
  63. }
  64. }
  65. template <>
  66. void FastRunOption::config_model_internel<ModelMdl>(
  67. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  68. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  69. //! set the algo policy before model load
  70. using Strategy = ModelMdl::Strategy;
  71. auto strategy = static_cast<Strategy>(0);
  72. #if MGB_ENABLE_FASTRUN
  73. if (enable_full_run) {
  74. mgb_log("enable full-run strategy for algo profile");
  75. strategy = Strategy::PROFILE | strategy;
  76. } else if (enable_fast_run) {
  77. mgb_log("enable fast-run strategy for algo profile");
  78. strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy;
  79. } else {
  80. strategy = Strategy::HEURISTIC | strategy;
  81. }
  82. #else
  83. strategy = Strategy::HEURISTIC | strategy;
  84. #endif
  85. if (batch_binary_equal || enable_reproducible) {
  86. mgb_log("enable reproducible strategy for algo profile");
  87. strategy = Strategy::REPRODUCIBLE | strategy;
  88. }
  89. model->set_mdl_strategy(strategy);
  90. //! set binary_equal_between_batch and shared_batch_size
  91. if (batch_binary_equal) {
  92. mgb_log("enable batch binary equal");
  93. model->get_mdl_config()
  94. .comp_graph->options()
  95. .fast_run_config.binary_equal_between_batch = true;
  96. }
  97. if (share_batch_size > 0) {
  98. mgb_log("set shared shared batch");
  99. model->get_mdl_config()
  100. .comp_graph->options()
  101. .fast_run_config.shared_batch_size = share_batch_size;
  102. }
  103. } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  104. auto& vars = model->get_mdl_load_result().output_var_list;
  105. auto&& strategy = model->get_mdl_strategy();
  106. mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy);
  107. // set algo cache path
  108. if (!m_fast_run_cache.empty()) {
  109. if (!access(m_fast_run_cache.c_str(), F_OK)) {
  110. mgb::PersistentCache::set_impl(
  111. std::make_shared<mgb::InFilePersistentCache>(
  112. m_fast_run_cache.c_str()));
  113. } else {
  114. mgb::PersistentCache::set_impl(
  115. std::make_shared<mgb::InFilePersistentCache>());
  116. }
  117. #if MGB_ENABLE_FASTRUN
  118. if (!enable_full_run && !enable_fast_run)
  119. #endif
  120. mgb::gopt::enable_opr_use_profiling_cache_inplace(vars);
  121. }
  122. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  123. #if MGB_ENABLE_FASTRUN
  124. //! dump algo cache
  125. if (!m_fast_run_cache.empty()) {
  126. static_cast<mgb::InFilePersistentCache&>(mgb::PersistentCache::inst())
  127. .dump_cache(m_fast_run_cache.c_str());
  128. }
  129. #endif
  130. }
  131. }
  132. } // namespace lar
  133. using namespace lar;
  134. bool FastRunOption::m_valid;
  135. void FastRunOption::update() {
  136. m_option_name = "fastrun";
  137. #if MGB_ENABLE_FASTRUN
  138. enable_fast_run = FLAGS_fast_run;
  139. enable_full_run = FLAGS_full_run;
  140. #endif
  141. batch_binary_equal = FLAGS_binary_equal_between_batch;
  142. enable_reproducible = FLAGS_reproducible;
  143. m_fast_run_cache = FLAGS_fast_run_algo_policy;
  144. share_batch_size = FLAGS_fast_run_shared_batch_size;
  145. m_option = {
  146. #if MGB_ENABLE_FASTRUN
  147. {"fast_run", lar::Bool::make(false)},
  148. {"full_run", lar::Bool::make(false)},
  149. #endif
  150. {"binary_equal_between_batch", lar::Bool::make(false)},
  151. {"reproducible", lar::Bool::make(false)}
  152. };
  153. #if MGB_ENABLE_FASTRUN
  154. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])
  155. ->set_value(FLAGS_fast_run);
  156. std::static_pointer_cast<lar::Bool>(m_option["full_run"])
  157. ->set_value(FLAGS_full_run);
  158. #endif
  159. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  160. ->set_value(FLAGS_binary_equal_between_batch);
  161. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])
  162. ->set_value(FLAGS_reproducible);
  163. #if MGB_ENABLE_FASTRUN
  164. //! while fastrun cache file path is not empty and can't be accessed
  165. if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) {
  166. mgb_assert(
  167. enable_full_run || enable_fast_run,
  168. "--fast-run or --full-run should be enabled");
  169. }
  170. if (share_batch_size) {
  171. mgb_assert(
  172. enable_full_run || enable_fast_run || !m_fast_run_cache.empty(),
  173. "--fast-run-shared-batch-size should be used with "
  174. "--fast-run|--full-run|--fast-run-algo-policy");
  175. }
  176. #endif
  177. }
  178. bool FastRunOption::is_valid() {
  179. bool ret = false;
  180. #if MGB_ENABLE_FASTRUN
  181. ret = ret || FLAGS_fast_run;
  182. ret = ret || FLAGS_full_run;
  183. #endif
  184. ret = ret || FLAGS_binary_equal_between_batch;
  185. ret = ret || FLAGS_fast_run_shared_batch_size > 0;
  186. ret = ret || FLAGS_reproducible;
  187. ret = ret || FLAGS_fast_run_algo_policy.size() > 0;
  188. return ret || m_valid;
  189. }
  190. std::shared_ptr<OptionBase> FastRunOption::create_option() {
  191. static std::shared_ptr<FastRunOption> option(new FastRunOption);
  192. if (FastRunOption::is_valid()) {
  193. option->update();
  194. return std::static_pointer_cast<OptionBase>(option);
  195. } else {
  196. return nullptr;
  197. }
  198. }
  199. void FastRunOption::config_model(
  200. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
  201. #if MGB_ENABLE_FASTRUN
  202. enable_fast_run =
  203. std::static_pointer_cast<lar::Bool>(m_option["fast_run"])->get_value();
  204. enable_full_run =
  205. std::static_pointer_cast<lar::Bool>(m_option["full_run"])->get_value();
  206. mgb_throw_if(
  207. enable_fast_run && enable_full_run, mgb::AssertionError,
  208. "invalid options of both fast-run and full-run");
  209. #endif
  210. batch_binary_equal =
  211. std::static_pointer_cast<lar::Bool>(m_option["binary_equal_between_batch"])
  212. ->get_value();
  213. enable_reproducible =
  214. std::static_pointer_cast<lar::Bool>(m_option["reproducible"])->get_value();
  215. CONFIG_MODEL_FUN;
  216. }
  217. #if MGB_ENABLE_FASTRUN
  218. DEFINE_bool(fast_run, false, "whether to use fast-run in model run");
  219. DEFINE_bool(full_run, false, "whether to use full-run in model run");
  220. #endif
  221. DEFINE_bool(
  222. binary_equal_between_batch, false,
  223. "Each batch of output is promised binary equal if each batch of "
  224. "input is binary equal\n Note that if this option is turned on, "
  225. "`--reproducible` will also be turned on.");
  226. DEFINE_bool(
  227. reproducible, false,
  228. "Enable choose algo which is reproducible. It mainly used for "
  229. "cudnn algos.See "
  230. "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
  231. "index.html#reproducibility"
  232. "for more details.");
  233. DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
  234. DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
  235. REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
  236. REGIST_OPTION_VALIDATER(fastrun, lar::FastRunOption::set_valid);