
network_impl.cpp

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "lite/network.h"
#include "memory_allocator.h"
#include "network_impl.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"

#if MGB_OPENCL
#include "megcore_opencl.h"
#endif

#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif

#include <fstream>
#include <memory>
#include <set>

using namespace lite;
using namespace mgb;

LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);

void NetworkImplDft::set_config(const Config& config) {
    *m_user_config = config;
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}

void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(
            src_impl.m_loader,
            "Clone network must be called after the network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);
    configure_after_loaded();
}

void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    //! when the device id is not configured, configure it
    if (m_compnode_locator.device == -1) {
        m_compnode_locator.device = m_user_config->device_id;
    }
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        if (m_compnode_locator.device == -1) {
            m_compnode_locator.device = m_user_config->device_id;
        }
    }
    //! model options
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(
            force_output_use_user_specified_memory,
            force_output_use_user_specified_memory);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit is only supported on cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);

#undef ConfigOption
#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform
    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }

    //! if the device is LITE_NONE, the compnode information is stored in the
    //! model, or it is xpu in MegEngine
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
            if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
                loc.type = m_compnode_locator.type;
            }
            loc.device = m_compnode_locator.device;
            //! if the user set the thread number and the compnode is multithread
            if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
                m_nr_threads != 1) {
                loc.stream = m_nr_threads;
            } else {
                loc.stream = m_compnode_locator.stream;
            }
        };
    }
}
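
//! Note: the ConfigOption macro above simply forwards a field of the user
//! Config to the corresponding computing-graph option. For example,
//! ConfigOption(graph_opt.weight_preprocess, weight_preprocess) expands to:
//!     options.graph_opt.weight_preprocess =
//!             m_user_config->options.weight_preprocess;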

void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network; the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}

void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only available on CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only available on CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}

void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        if (m_is_cpu_inplace_mode) {
            m_compnode_locator.device =
                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
            m_user_config->device_id =
                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
        }
        m_compnode_locator.nr_threads = nr_threads;
    }
}
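
//! A minimal usage sketch (assuming direct access to a NetworkImplDft
//! instance; in practice these setters are reached through the public
//! lite::Network wrappers): configure the CPU backend before load_model()
//! so that application_config() picks the settings up, e.g.
//!     impl.set_cpu_inplace_mode();      // zero-copy CPU mode
//!     impl.set_cpu_threads_number(4);   // or a 4-thread MULTITHREAD compnode
//!     impl.load_model(model_mem, size, {});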

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}

void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode does not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Now only cpu and cuda(>10.0) support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}
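
//! Async usage sketch (a hedged illustration; see make_output_spec() below
//! for the actual wiring): once a callback is registered, forward() does not
//! block and the callback is invoked from a compnode callback after all
//! outputs have been synchronized, e.g.
//!     impl.set_async_callback([]() { /* outputs are ready here */ });
//!     impl.forward();  // returns without waiting for execution to finish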

void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            //! if writing to user-specified memory, the CallbackCaller must be nullptr.
            if (m_user_config->options.force_output_use_user_specified_memory ||
                m_user_config->options.force_output_dynamic_alloc) {
                m_output_spec.emplace_back(load_out, nullptr);
            } else {
                m_output_spec.emplace_back(load_out, std::move(cb));
            }
        } else {
            LITE_THROW(ssprintf("no output named : %s in the model", out.name.c_str()));
        }
    }
}

void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, so that the user can directly change the dev
    //! tensor storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;

    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;

    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }

    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;

    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}
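
//! Sketch of the rewrite performed above (illustration only, names are
//! hypothetical): for a device input named "data", the original chain
//!     Host2DeviceCopy(data) -> ConvBias -> ... -> output
//! becomes
//!     VolatileSharedDeviceTensor(data) -> ConvBias -> ... -> output
//! so resetting the lite tensor's device storage before forward() feeds the
//! graph directly, without a host-to-device copy.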

void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}

void NetworkImplDft::adapt_option_valid() {
    auto&& options = m_load_config.comp_graph->options();
    if (m_user_config->options.force_output_use_user_specified_memory) {
        for (auto&& out : m_load_result.output_var_list) {
            auto opr = out.node()->owner_opr();
            //! all dest operators that inherit from ReadonlyFwdHelper can't
            //! support the force_output_use_user_specified_memory option
            if (opr->try_cast_final<mgb::opr::Reshape>() ||
                opr->try_cast_final<mgb::opr::Broadcast>() ||
                opr->try_cast_final<mgb::opr::Subtensor>() ||
                opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
                opr->try_cast_final<mgb::opr::Dimshuffle>()) {
                m_user_config->options.force_output_use_user_specified_memory = false;
                options.force_output_use_user_specified_memory = false;
                LITE_WARN(
                        "detected the unsupported dest operator %s when configuring "
                        "force_output_use_user_specified_memory, set "
                        "force_output_use_user_specified_memory to false\n",
                        opr->cname());
                break;
            }
        }
    }
}

void NetworkImplDft::layout_transform_optimization() {
    if (m_set_layout_transform) {
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
        m_load_result.update_output_var_list(output_var_array);
    } else if (m_user_config->auto_optimize_inference) {
        //! enable model weight preprocess
        m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
        LITE_LOG(
                "weight_preprocess is enabled, this may use more memory during "
                "inference.");
        //! get the current format and data type of the model
        bool is_model_nchw = true;
        //! whether any convolution is int8
        bool is_model_int8 = false;
        //! whether all convolutions are float32
        bool is_model_float32 = true;
        float conv_cnt = 0;
        float dimshuffle_cnt = 0;
        auto detect_int8_model = [&](const VarNode* input) {
            if (input->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
                input->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm) {
                is_model_int8 = true;
                is_model_float32 = false;
            } else if (input->dtype().enumv() == megdnn::DTypeEnum::Float32) {
                is_model_float32 = (is_model_float32 && true);
            } else {
                is_model_float32 = false;
            }
        };
        cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
            if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
                if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv->input(0));
            } else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
                if (conv_bias->param().format !=
                    megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv_bias->input(0));
            } else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
                LITE_MARK_USED_VAR(dimshuffle);
                dimshuffle_cnt++;
            }
        });
        for (auto&& i : m_load_result.output_var_list)
            dep.add(i);
        float ratio_dimshuffle_conv = 0;
        if (conv_cnt > 0) {
            ratio_dimshuffle_conv = dimshuffle_cnt / conv_cnt;
        }
        //! format optimization can only be applied to nchw models;
        //! shufflenet-like models will hurt performance when using the nchw88 or
        //! nchw44 format, so here we just heuristically gate on the ratio of
        //! dimshuffle to convolution operators
        if (!is_model_nchw || ratio_dimshuffle_conv > 0.15f) {
            return;
        }
        //! determine the layout by the device information
        //! TODO: shufflenet-like models using nchw88 or nchw44 will hurt
        //! performance
        if (m_user_config->device_type == LITE_CPU) {
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
            cpuinfo_initialize();
            //! if all convolution and matmul data types are float32
            if (is_model_float32) {
                //! if the device is x86:
                //! if x86 supports avx, use the nchw88 format
                if (cpuinfo_has_x86_avx()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw88();
                    LITE_LOG("Configure model inference with nchw88 format.");
                } else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) {
                    //! if x86 only supports sse2, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            } else if (is_model_int8) {
                //! if the data type of convolution is int8:
                //! if the device is arm and supports dot, use the nchw44-dot format
                if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot();
                    LITE_LOG("Configure model inference with nchw44-dot format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm and does not support dot, use nchw44
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            }
#endif
        }
    }
}
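
//! Worked example of the gating heuristic above (numbers are illustrative):
//! a model with 20 convolutions and 2 dimshuffles has a ratio of
//! 2 / 20 = 0.1 <= 0.15, so an NCHWxx layout may be enabled; a
//! shufflenet-like model with 20 convolutions and 4 dimshuffles has a ratio
//! of 0.2 > 0.15 and keeps the plain NCHW format.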

void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!m_format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), m_format.val());
    }

    //! apply the user configuration to the mge model
    application_config();

    //! configure some flags taken from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }

    m_load_result = m_loader->load(m_load_config, true);

    configure_after_loaded();
}
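
//! Note on separate_config_map: the keys recognized above are "device_id"
//! (int), "number_threads" (uint32_t), "enable_inplace_model" (bool) and
//! "use_tensorrt" (bool); any other key is ignored. A caller-side sketch
//! (assuming LiteAny can wrap these value types, as safe_cast above implies):
//!     std::unordered_map<std::string, LiteAny> cfg;
//!     cfg["number_threads"] = LiteAny(uint32_t(4));  // hypothetical wrapping
//!     impl.load_model(model_mem, size, cfg);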

void NetworkImplDft::configure_after_loaded() {
    modify_exection_policy();

    layout_transform_optimization();

    //! some optimization options may be invalid in some cases, so here we just
    //! automatically determine whether they will be applied.
    adapt_option_valid();

    //! find out how many compnodes the model has; this should be called before
    //! update_io
    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}

void NetworkImplDft::compile_graph() {
    replace_dev_input_pass();
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}

void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    if (m_load_config.comp_graph &&
        m_user_config->options.comp_node_seq_record_level == 2) {
        m_load_config.comp_graph.reset();
    }
    LITE_ASSERT(m_execute_func, "forward must be called after the network is loaded.");
    m_execute_func->execute();
}

void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}
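
//! Synchronous inference sketch (illustrative call order only): after
//! load_model() and filling the input tensors obtained via get_io_tensor(),
//! a typical caller runs
//!     impl.forward();  // kick off execution (and the start callback)
//!     impl.wait();     // block until done, then run the finish callback
//! repeatedly for each batch.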

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(m_async_callback, "The callback func must be set in async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}

void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
    if (var.node()->capable_shape_infer()) {
        auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite infer output shape failed, maybe the model has dynamic "
                    "shape.\n");
            LITE_ASSERT(
                    !m_user_config->options.force_output_use_user_specified_memory,
                    "force_output_use_user_specified_memory can't be used when the "
                    "output shape can't be derived.");
            return;
        }
        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}

void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! on cpu all inputs and outputs are host
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if it is a cross-compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if the tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the origin compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not come from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                TensorHelper::implement(config_in.lite_tensor)
                        ->cast_final_safe<TensorImplDft>()
                        .m_record_reset =
                        m_user_config->options.comp_node_seq_record_level > 0;
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    //! delete the IO that is not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not a network input, ignore it.", it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}

void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the output that is not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not a network output, ignore it.", out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! the user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute configured output only with no configured output.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
            output_tensor_copy_optimize(var, out_it->lite_tensor);
            TensorHelper::implement(out_it->lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
        //! the user did not set outputs, use the default outputs
    } else {
        for (auto&& out : m_load_result.output_var_list) {
            std::shared_ptr<Tensor> lite_tensor = nullptr;
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
                lite_tensor = it->lite_tensor;
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
                lite_tensor = output.lite_tensor;
            }
            output_tensor_copy_optimize(out, lite_tensor);
            TensorHelper::implement(lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
    }
}

void NetworkImplDft::output_tensor_copy_optimize(
        Var var, std::shared_ptr<Tensor> tensor) {
    LITE_ASSERT(
            !(m_user_config->options.force_output_use_user_specified_memory &&
              m_user_config->options.force_output_dynamic_alloc),
            "Can't set force_output_use_user_specified_memory and "
            "force_output_dynamic_alloc at the same time.");
    if (m_user_config->options.force_output_use_user_specified_memory) {
        bool in_record = m_user_config->options.comp_node_seq_record_level > 0;
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) {
                    dft_tensor->device_share_host_memory();
                    auto dv = dft_tensor->dev_tensor().get();
                    dv->comp_node(var.node()->comp_node(), true);
                    var.node()->init_mem_plan(dv);
                    if (in_record) {
                        auto&& device_tensor = var.node()->mutable_dev_tensor();
                        device_tensor.only_reset_raw_storage(dv->storage());
                    } else {
                        var.node()->reset_dev_tensor_from_tensor(*dv);
                    }
                });
    }
    if (m_user_config->options.force_output_dynamic_alloc) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
                    if (dft_tensor->is_host()) {
                        auto host_tensor = dft_tensor->m_host_tensor;
                        *host_tensor =
                                HostTensorND::make_proxy(var.node()->dev_tensor());
                    } else {
                        auto dev_tensor = dft_tensor->m_dev_tensor;
                        *dev_tensor = var.node()->dev_tensor();
                    }
                });
    }
}

std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                return config_in.lite_tensor;
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "tensor name %s must be an input tensor name or a registered "
            "output tensor name; if NetworkIO is not set, the output tensors "
            "are all the network output tensors, otherwise only the registered "
            "tensors.",
            io_name.c_str()));
    return nullptr;
}
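
//! Usage sketch (tensor names are illustrative): fetch IO tensors by name
//! after the network is loaded, fill the input, then run forward()/wait():
//!     auto in = impl.get_io_tensor("data", LiteTensorPhase::LITE_INPUT);
//!     auto out = impl.get_io_tensor("prob", LiteTensorPhase::LITE_OUTPUT);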

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}

//! set opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    if (static_cast<uint32_t>(dst_strategy) != 0)
        m_execution_policy = dst_strategy;

    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;

    if (m_execute_func) {
        LITE_WARN(
                "set_network_algo_policy may cause errors after the network is "
                "loaded!!!!");
        modify_exection_policy();
    }
}
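
//! Example (illustrative; assuming the LiteAlgoSelectStrategy enum provides
//! bitwise OR, as the bit tests above suggest): the strategies are bit flags
//! and can be combined, e.g. LITE_ALGO_PROFILE | LITE_ALGO_REPRODUCIBLE maps
//! to the megdnn strategy S::PROFILE | S::REPRODUCIBLE; calling this before
//! load_model() avoids the warning above.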

void NetworkImplDft::modify_exection_policy() {
    auto& vars = m_load_result.output_var_list;
    if (static_cast<uint32_t>(m_execution_policy) != 0) {
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
    }
}

//! set the opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get the output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get the input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of the graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is larger than the total number of outputs.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of the graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is larger than the total number of inputs.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}

//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disabled at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

void NetworkImplDft::enable_global_layout_transform() {
    m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;

    switch (m_user_config->device_type) {
        case LiteDeviceType::LITE_CPU:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
            break;
        default:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
            LITE_WARN(
                    "lite compnode type (enum value: %d) has no specialized target "
                    "for layout transform",
                    (int)(m_user_config->device_type));
    }
    m_set_layout_transform = true;
}

void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
    if (m_set_layout_transform) {
        auto out_file = mgb::serialization::OutputFile::make_fs(
                optimized_model_path.c_str(), 'w');
        using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
        DumpConfig config{1, false, false};
        auto dumper = mgb::serialization::GraphDumper::make(
                std::move(out_file), m_format.val());
        dumper->dump(m_load_result.output_var_list, config);
    } else {
        LITE_THROW(
                ssprintf("dump_layout_transform_model should be called after "
                         "enable_global_layout_transform"));
    }
}
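
//! Usage order sketch: enable_global_layout_transform() must be called before
//! the model is loaded (so layout_transform_optimization() picks it up during
//! configure_after_loaded()), and dump_layout_transform_model() afterwards:
//!     impl.enable_global_layout_transform();
//!     impl.load_model(model_mem, size, {});
//!     impl.dump_layout_transform_model("optimized.mge");  // illustrative path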

NetworkIO lite::get_model_io_info_dft(
        const std::string& model_path, const Config& config) {
    FILE* fin = fopen(model_path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    return get_model_io_info_dft(ptr, size, config);
}

NetworkIO lite::get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config) {
    std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
    auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
    auto format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
    if (!format.valid()) {
        LITE_THROW("invalid model format");
    }
    auto loader =
            mgb::serialization::GraphLoader::make(std::move(input_file), format.val());

    mgb::serialization::GraphLoadConfig load_config;
    load_config.comp_graph = mgb::ComputingGraph::make();
    if (config.has_compression) {
        load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }
    auto compnode_locator = to_compnode_locator(config.device_type);
    load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
        if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
            loc.type = compnode_locator.type;
        }
        loc.device = compnode_locator.device;
    };
    auto load_result = loader->load(load_config, true);

    NetworkIO IOs;
    for (auto&& in_tensor_iter : load_result.tensor_map) {
        IO in_io;
        in_io.name = in_tensor_iter.first;
        in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
        IOs.inputs.push_back(in_io);
    }
    auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
        auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
        using InferType = mgb::cg::static_infer::InferType;
        if (static_infer_mgr.get_infer_type(var.node()).shape &
            (InferType::CONST | InferType::RT_STATIC)) {
            return static_infer_mgr.infer_shape_fallible(var.node());
        } else {
            return nullptr;
        }
    };
    for (auto&& out : load_result.output_var_list) {
        IO out_io;
        out_io.name = out.node()->name();
        if (auto shape = infer_shape(out)) {
            out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
        } else {
            out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
        }
        IOs.outputs.push_back(out_io);
    }
    return IOs;
}
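
//! Usage sketch: the two overloads above let a caller inspect a model's IO
//! without building a full Network, e.g. (the path is illustrative, and a
//! default-constructed Config is assumed to be acceptable):
//!     NetworkIO io = lite::get_model_io_info_dft("model.mge", Config{});
//!     for (auto&& in : io.inputs) { /* in.name, in.config_layout */ }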

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}