You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

plugin_options.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. #include "plugin_options.h"
  2. #include <map>
  3. #include "misc.h"
  4. #include "models/model_lite.h"
  5. #include "models/model_mdl.h"
  6. ///////////////////// Plugin options///////////////////////////
  7. namespace lar {
  8. template <>
  9. void PluginOption::config_model_internel<ModelLite>(
  10. RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
  11. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  12. LITE_ASSERT(range == 0, "lite model don't support NumRangeChecker plugin");
  13. LITE_ASSERT(
  14. !enable_check_dispatch,
  15. "lite model don't support CPUDispatchChecker plugin");
  16. LITE_ASSERT(
  17. var_value_check_str.empty(),
  18. "lite model don't support VarValueChecker plugin");
  19. }
  20. #if MGB_ENABLE_JSON
  21. else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
  22. if (!profile_path.empty()) {
  23. if (!enable_profile_host) {
  24. LITE_LOG("enable profiling");
  25. model->get_lite_network()->enable_profile_performance(profile_path);
  26. } else {
  27. LITE_LOG("enable profiling for host");
  28. model->get_lite_network()->enable_profile_performance(profile_path);
  29. }
  30. }
  31. }
  32. #endif
  33. }
template <>
void PluginOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    // Install the requested checker/profiler plugins on a MegBrain (mdl)
    // model, dispatching on the current run stage.
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        auto&& config = model->get_mdl_config();
        if (range > 0) {
            mgb_log("enable number range check");
            model->set_num_range_checker(float(range));
        }
        if (enable_check_dispatch) {
            mgb_log("enable cpu dispatch check");
            cpu_dispatch_checker =
                    std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get());
        }
        if (!var_value_check_str.empty()) {
            mgb_log("enable variable value check");
            // Flag format is "interval" or "interval:init_idx" (see the
            // --check-var-value gflag help text below).
            // NOTE(review): std::stoul throws on malformed input — presumably
            // acceptable for a CLI tool; confirm no upstream validation exists.
            size_t init_idx = 0, switch_interval;
            auto sep = var_value_check_str.find(':');
            if (sep != std::string::npos) {
                switch_interval = std::stoul(var_value_check_str.substr(0, sep));
                init_idx = std::stoul(var_value_check_str.substr(sep + 1));
            } else {
                // No ':' given — init_idx keeps its default of 0.
                switch_interval = std::stoul(var_value_check_str);
            }
            var_value_checker = std::make_unique<mgb::VarValueChecker>(
                    config.comp_graph.get(), switch_interval, init_idx);
        }
#if MGB_ENABLE_JSON
        if (!profile_path.empty()) {
            // Same profiler is installed in both modes; only the log differs.
            if (!enable_profile_host) {
                mgb_log("enable profiling");
            } else {
                mgb_log("enable profiling for host");
            }
            model->set_profiler();
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
#if MGB_ENABLE_JSON
        // After the run completes, dump the collected profile (if any) as
        // JSON to the path given by --profile/--profile-host.
        if (!profile_path.empty()) {
            if (model->get_profiler()) {
                model->get_profiler()
                        ->to_json_full(model->get_async_func().get())
                        ->writeto_fpath(profile_path);
                mgb_log("profiling result written to %s", profile_path.c_str());
            }
        }
#endif
    }
}
  85. } // namespace lar
  86. using namespace lar;
  87. void PluginOption::update() {
  88. m_option_name = "plugin";
  89. range = FLAGS_range;
  90. enable_check_dispatch = FLAGS_check_dispatch;
  91. var_value_check_str = FLAGS_check_var_value;
  92. #if MGB_ENABLE_JSON
  93. enable_profile_host = false;
  94. if (!FLAGS_profile.empty()) {
  95. profile_path = FLAGS_profile;
  96. }
  97. if (!FLAGS_profile_host.empty()) {
  98. enable_profile_host = !FLAGS_profile_host.empty();
  99. profile_path = FLAGS_profile_host;
  100. }
  101. #endif
  102. }
  103. bool PluginOption::is_valid() {
  104. bool ret = FLAGS_check_dispatch;
  105. ret = ret || FLAGS_range > 0;
  106. ret = ret || !FLAGS_check_var_value.empty();
  107. #if MGB_ENABLE_JSON
  108. ret = ret || !FLAGS_profile.empty();
  109. ret = ret || !FLAGS_profile_host.empty();
  110. #endif
  111. return ret;
  112. }
  113. std::shared_ptr<OptionBase> PluginOption::create_option() {
  114. static std::shared_ptr<PluginOption> option(new PluginOption);
  115. if (PluginOption::is_valid()) {
  116. option->update();
  117. return std::static_pointer_cast<OptionBase>(option);
  118. } else {
  119. return nullptr;
  120. }
  121. }
void PluginOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    // Macro dispatches to the ModelLite/ModelMdl specialization of
    // config_model_internel matching the concrete model type.
    CONFIG_MODEL_FUN;
}
  126. ///////////////////// Debug options///////////////////////////
  127. namespace lar {
  128. template <>
  129. void DebugOption::format_and_print(
  130. const std::string& tablename, std::shared_ptr<ModelLite> model) {
  131. auto table = mgb::TextTable(tablename);
  132. auto&& network = model->get_lite_network();
  133. table.padding(1);
  134. table.align(mgb::TextTable::Align::Mid)
  135. .add("type")
  136. .add("name")
  137. .add("shape")
  138. .add("dtype")
  139. .eor();
  140. auto to_string = [&](lite::Layout& layout) {
  141. std::string shape("{");
  142. for (size_t i = 0; i < layout.ndim; i++) {
  143. if (i)
  144. shape.append(",");
  145. shape.append(std::to_string(layout.shapes[i]));
  146. }
  147. shape.append("}");
  148. return shape;
  149. };
  150. auto get_dtype = [&](lite::Layout& layout) {
  151. std::map<LiteDataType, std::string> type_map = {
  152. {LiteDataType::LITE_FLOAT, "float32"},
  153. {LiteDataType::LITE_HALF, "float16"},
  154. {LiteDataType::LITE_INT64, "int64"},
  155. {LiteDataType::LITE_INT, "int32"},
  156. {LiteDataType::LITE_UINT, "uint32"},
  157. {LiteDataType::LITE_INT16, "int16"},
  158. {LiteDataType::LITE_UINT16, "uint16"},
  159. {LiteDataType::LITE_INT8, "int8"},
  160. {LiteDataType::LITE_UINT8, "uint8"}};
  161. return type_map[layout.data_type];
  162. };
  163. auto input_name = network->get_all_input_name();
  164. for (auto& i : input_name) {
  165. auto layout = network->get_io_tensor(i)->get_layout();
  166. table.align(mgb::TextTable::Align::Mid)
  167. .add("INPUT")
  168. .add(i)
  169. .add(to_string(layout))
  170. .add(get_dtype(layout))
  171. .eor();
  172. }
  173. auto output_name = network->get_all_output_name();
  174. for (auto& i : output_name) {
  175. auto layout = network->get_io_tensor(i)->get_layout();
  176. table.align(mgb::TextTable::Align::Mid)
  177. .add("OUTPUT")
  178. .add(i)
  179. .add(to_string(layout))
  180. .add(get_dtype(layout))
  181. .eor();
  182. }
  183. std::stringstream ss;
  184. ss << table;
  185. LITE_LOG("\n%s\n", ss.str().c_str());
  186. }
  187. template <>
  188. void DebugOption::format_and_print(
  189. const std::string& tablename, std::shared_ptr<ModelMdl> model) {
  190. auto table = mgb::TextTable(tablename);
  191. table.padding(1);
  192. table.align(mgb::TextTable::Align::Mid)
  193. .add("type")
  194. .add("name")
  195. .add("shape")
  196. .add("dtype")
  197. .eor();
  198. auto get_dtype = [&](megdnn::DType data_type) {
  199. std::map<megdnn::DTypeEnum, std::string> type_map = {
  200. {mgb::dtype::Float32().enumv(), "float32"},
  201. {mgb::dtype::Int32().enumv(), "int32"},
  202. {mgb::dtype::Int16().enumv(), "int16"},
  203. {mgb::dtype::Uint16().enumv(), "uint16"},
  204. {mgb::dtype::Int8().enumv(), "int8"},
  205. {mgb::dtype::Uint8().enumv(), "uint8"}};
  206. return type_map[data_type.enumv()];
  207. };
  208. for (auto&& i : model->get_mdl_load_result().tensor_map) {
  209. table.align(mgb::TextTable::Align::Mid)
  210. .add("INPUT")
  211. .add(i.first)
  212. .add(i.second->shape().to_string())
  213. .add(get_dtype(i.second->dtype()))
  214. .eor();
  215. }
  216. for (auto&& i : model->get_mdl_load_result().output_var_list) {
  217. table.align(mgb::TextTable::Align::Mid)
  218. .add("OUTPUT")
  219. .add(i.node()->name())
  220. .add(i.shape().to_string())
  221. .add(get_dtype(i.dtype()))
  222. .eor();
  223. }
  224. std::stringstream ss;
  225. ss << table;
  226. mgb_log("\n%s\n", ss.str().c_str());
  227. }
template <>
void DebugOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    // Apply debug flags to a lite model, dispatching on the run stage.
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        // Features only implemented for the MegBrain (mdl) path are rejected.
        LITE_ASSERT(
                !disable_assert_throw, "lite model don't support disable assert throw");
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
        LITE_ASSERT(
                static_mem_log_dir_path.empty(),
                "lite model don't support static memory information export");
#endif
#endif
        if (enable_verbose) {
            LITE_LOG("enable verbose");
            lite::set_log_level(LiteLogLevel::DEBUG);
        }
#if __linux__ || __unix__
        // --wait-gdb: print the PID and block on stdin so a debugger can
        // attach before any real work happens.
        if (enable_wait_gdb) {
            printf("wait for gdb attach (pid=%d): ", getpid());
            getchar();
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
        if (enable_display_model_info) {
            LITE_LOG("enable display model information");
            format_and_print<ModelLite>("Runtime Model Info", model);
        }
    } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
        // Printed again after running: shapes may have been refreshed by
        // the run itself.
        if (enable_display_model_info) {
            format_and_print<ModelLite>("Runtime Model Info", model);
        }
    }
}
  262. template <>
  263. void DebugOption::config_model_internel<ModelMdl>(
  264. RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
  265. if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
  266. if (enable_verbose) {
  267. mgb_log("enable verbose");
  268. mgb::set_log_level(mgb::LogLevel::DEBUG);
  269. }
  270. #if __linux__ || __unix__
  271. if (enable_wait_gdb) {
  272. printf("wait for gdb attach (pid=%d): ", getpid());
  273. getchar();
  274. }
  275. #endif
  276. } else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) {
  277. if (enable_display_model_info) {
  278. mgb_log("enable display model information");
  279. format_and_print<ModelMdl>("Runtime Model Info", model);
  280. }
  281. } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) {
  282. #ifndef __IN_TEE_ENV__
  283. #if MGB_ENABLE_JSON
  284. if (!static_mem_log_dir_path.empty()) {
  285. mgb_log("enable get static memeory information");
  286. model->get_async_func()->get_static_memory_alloc_info(
  287. static_mem_log_dir_path);
  288. }
  289. #endif
  290. #endif
  291. if (disable_assert_throw) {
  292. mgb_log("disable assert throw");
  293. auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
  294. if (opr->same_type<mgb::opr::AssertEqual>()) {
  295. opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error();
  296. }
  297. };
  298. mgb::cg::DepOprIter iter{on_opr};
  299. for (auto&& i : model->get_output_spec()) {
  300. iter.add(i.first.node()->owner_opr());
  301. }
  302. }
  303. } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
  304. if (enable_display_model_info) {
  305. format_and_print<ModelMdl>("Runtime Model Info", model);
  306. }
  307. }
  308. }
  309. } // namespace lar
  310. void DebugOption::update() {
  311. m_option_name = "debug";
  312. enable_display_model_info = FLAGS_model_info;
  313. enable_verbose = FLAGS_verbose;
  314. disable_assert_throw = FLAGS_disable_assert_throw;
  315. #if __linux__ || __unix__
  316. enable_wait_gdb = FLAGS_wait_gdb;
  317. #endif
  318. #ifndef __IN_TEE_ENV__
  319. #if MGB_ENABLE_JSON
  320. static_mem_log_dir_path = FLAGS_get_static_mem_info;
  321. #endif
  322. #endif
  323. }
  324. bool DebugOption::is_valid() {
  325. bool ret = FLAGS_model_info;
  326. ret = ret || FLAGS_verbose;
  327. ret = ret || FLAGS_disable_assert_throw;
  328. #if __linux__ || __unix__
  329. ret = ret || FLAGS_wait_gdb;
  330. #endif
  331. #ifndef __IN_TEE_ENV__
  332. #if MGB_ENABLE_JSON
  333. ret = ret || !FLAGS_get_static_mem_info.empty();
  334. #endif
  335. #endif
  336. return ret;
  337. }
  338. std::shared_ptr<OptionBase> DebugOption::create_option() {
  339. static std::shared_ptr<DebugOption> option(new DebugOption);
  340. if (DebugOption::is_valid()) {
  341. option->update();
  342. return std::static_pointer_cast<OptionBase>(option);
  343. } else {
  344. return nullptr;
  345. }
  346. }
void DebugOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    // Macro dispatches to the ModelLite/ModelMdl specialization of
    // config_model_internel matching the concrete model type.
    CONFIG_MODEL_FUN;
}
  351. ///////////////////// Plugin gflags///////////////////////////
  352. DEFINE_double(
  353. range, 0,
  354. "check whether absolute value of all numbers in computing graph "
  355. "is in the given range");
  356. DEFINE_bool(
  357. check_dispatch, false,
  358. "check whether an operator call dispatch on cpu comp nodes This is used to "
  359. "find potential bugs in MegDNN");
  360. DEFINE_string(
  361. check_var_value, "",
  362. "--check-var-value [interval]|[interval:init_idx], Enable "
  363. "VarValueChecker plugin. check values of all vars in a graph from given var "
  364. "ID(init_idx) with step interval");
  365. #if MGB_ENABLE_JSON
  366. DEFINE_string(
  367. profile, "",
  368. "Write profiling result to given file. The output file is in "
  369. "JSON format");
  370. DEFINE_string(
  371. profile_host, "",
  372. "focus on host time profiling For some backends(such as openCL)");
  373. #endif
  374. ///////////////////// Debug gflags///////////////////////////
  375. DEFINE_bool(
  376. model_info, false,
  377. " Format and display model input/output tensor inforamtion");
  378. DEFINE_bool(verbose, false, "get more inforamtion for debug");
  379. DEFINE_bool(disable_assert_throw, false, "disable assert throw on error check");
  380. #if __linux__ || __unix__
  381. DEFINE_bool(wait_gdb, false, "print current process PID and wait for gdb attach");
  382. #endif
  383. #ifndef __IN_TEE_ENV__
  384. #if MGB_ENABLE_JSON
  385. DEFINE_string(
  386. get_static_mem_info, "",
  387. "Record the static computing graph's static memory information");
  388. #endif
  389. #endif
  390. REGIST_OPTION_CREATOR(plugin, lar::PluginOption::create_option);
  391. REGIST_OPTION_CREATOR(debug, lar::DebugOption::create_option);