
network_impl.cpp

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "lite/network.h"
#include "memory_allocator.h"
#include "network_impl.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"

#if MGB_OPENCL
#include "megcore_opencl.h"
#endif

#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif

#include <fstream>
#include <memory>
#include <set>

using namespace lite;
using namespace mgb;

LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);

void NetworkImplDft::set_config(const Config& config) {
    *m_user_config = config;
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}

void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(
            src_impl.m_loader,
            "Clone network must be done after the network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);
    configure_after_loaded();
}

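//! apply the user Config/Options to the MegEngine load config: device locator,
//! computing-graph options, layout-transform switches, tensor decompression and
//! the compnode mapper used when loading the model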
void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    //! when the device id is not configured, configure it
    if (m_compnode_locator.device == -1) {
        m_compnode_locator.device = m_user_config->device_id;
    }
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        if (m_compnode_locator.device == -1) {
            m_compnode_locator.device = m_user_config->device_id;
        }
    }
    //! model options
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(
            force_output_use_user_specified_memory,
            force_output_use_user_specified_memory);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit is only supported on the cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);
#undef ConfigOption

#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform
    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }

    //! if the device is LITE_DEVICE_DEFAULT, the compnode information is stored
    //! in the model or decided as xpu by MegEngine
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
            if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
                loc.type = m_compnode_locator.type;
            }
            loc.device = m_compnode_locator.device;
            //! if the user set the thread number and the compnode is multithread
            if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
                m_nr_threads != 1) {
                loc.stream = m_nr_threads;
            } else {
                loc.stream = m_compnode_locator.stream;
            }
        };
    }
}

void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network, the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}

void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only available on CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only available on CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
        m_user_config->device_id = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}

void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        if (m_is_cpu_inplace_mode) {
            m_compnode_locator.device =
                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
            m_user_config->device_id =
                    mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
        }
        m_compnode_locator.nr_threads = nr_threads;
    }
}

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}

void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode does not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Currently only cpu and cuda (>10.0) support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}

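//! build the ComputingGraph output spec: for every configured output, register a
//! callback that copies the mge value (or only its shape) into the lite tensor,
//! and in async mode triggers finish() once every output has been synchronized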
void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            //! if writing to user-specified memory, the CallbackCaller must be nullptr.
            if (m_user_config->options.force_output_use_user_specified_memory ||
                m_user_config->options.force_output_dynamic_alloc) {
                m_output_spec.emplace_back(load_out, nullptr);
            } else {
                m_output_spec.emplace_back(load_out, std::move(cb));
            }
        } else {
            LITE_THROW(ssprintf("no output named %s in the model", out.name.c_str()));
        }
    }
}

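//! graph rewrite pass for discrete input: the single var produced by the
//! Host2DeviceCopy/VolatileSharedDeviceTensor of the discrete input is replaced
//! by per-sample source vars, which feed either a multi-src WarpPerspective or a
//! Concat opr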
void NetworkImplDft::replace_src_discrete_input_opr_pass() {
    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;

    auto dest_with_extra_deps =
            get_dest_vars_with_extra_deps(m_load_result.output_var_list);
    gopt::SubGraph graph{dest_with_extra_deps};
    auto rewriter = graph.make_rewriter();

    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        bool replace_output = false;
        for (auto inp : opr->input()) {
            if ((inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>() ||
                 inp->owner_opr()->same_type<mgb::opr::VolatileSharedDeviceTensor>()) &&
                inp->name() == m_user_config->discrete_input_name) {
                bool is_h2d =
                        inp->owner_opr()->same_type<mgb::opr::Host2DeviceCopy>();
                SymbolVarArray srcs;
                if (is_h2d) {
                    auto h2d = inp->owner_opr();
                    for (auto&& i :
                         get_discrete_tensors(m_user_config->discrete_input_name)) {
                        auto val = TensorHelper::implement(i)
                                           ->cast_final_safe<TensorImplDft>()
                                           .m_host_tensor;
                        LITE_ASSERT(val);
                        srcs.push_back(mgb::opr::Host2DeviceCopy::make(
                                *m_load_result.graph, val, h2d->config()));
                    }
                } else {
                    auto volatiled = inp->owner_opr();
                    for (auto&& i :
                         get_discrete_tensors(m_user_config->discrete_input_name)) {
                        auto val = TensorHelper::implement(i)
                                           ->cast_final_safe<TensorImplDft>()
                                           .m_dev_tensor;
                        LITE_ASSERT(val);
                        srcs.push_back(mgb::opr::VolatileSharedDeviceTensor::make(
                                *m_load_result.graph, val, volatiled->config()));
                    }
                }

                if (opr->same_type<mgb::opr::WarpPerspective>()) {
                    auto& warp = opr->cast_final<mgb::opr::WarpPerspective>();
                    SymbolVar new_out;
                    if (opr->input().size() == 3) {
                        new_out = mgb::opr::WarpPerspective::make(
                                srcs, warp.input(1), warp.input(2), warp.param(),
                                warp.config());
                    } else {
                        LITE_ASSERT(opr->input().size() == 4);
                        new_out = mgb::opr::WarpPerspective::make(
                                srcs, warp.input(1), warp.input(2), warp.input(3),
                                warp.param(), warp.config());
                    }
                    rewriter.replace_var(
                            warp.output(0), new_out.node(),
                            "replace WarpPerspective with the multi-src "
                            "WarpPerspective version.");
                    replace_output = true;
                } else {
                    auto concat = mgb::opr::Concat::make(srcs, 0);
                    rewriter.replace_var(inp, concat.node(), "add a concat opr.");
                }
            }
        }
        if (!replace_output) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    graph.iter(on_opr);
    rewriter.apply_inplace();

    auto new_ovar = graph.endpoint_vars();
    new_ovar.resize(m_load_result.output_var_list.size());
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}

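//! rewrite device inputs: on non-CPU compnodes, the H2D copies of device inputs
//! are replaced by VolatileSharedDeviceTensor oprs backed by the user lite
//! tensors, and the output var maps are updated accordingly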
void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, so the user can directly change the dev tensor
    //! storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;

    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;

    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }

    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;

    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
        //! reset lite_tensor in discrete mode
        if (config_in.name == m_user_config->discrete_input_name) {
            config_in.lite_tensor.reset();
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}

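//! count how many distinct device types the operators of the loaded model run
//! on; the result is stored in m_nr_device_type and later used by update_input()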
void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}

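//! layout optimization after loading: either run the global layout transform
//! requested by enable_global_layout_transform(), or, when auto optimization is
//! enabled, heuristically pick an nchw88/nchw44/nchw44-dot format for CPU
//! inference based on the model's conv format, data type and cpuinfo features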
void NetworkImplDft::layout_transform_optimization() {
    if (m_set_layout_transform) {
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
        m_load_result.update_output_var_list(output_var_array);
    } else if (m_user_config->auto_optimize_inference) {
        //! enable model weight preprocess
        m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
        LITE_LOG(
                "weight_preprocess is enabled, this may use more memory during "
                "inference.");
        //! get the current format and data type of the model
        bool is_model_nchw = true;
        //! whether any convolution is int8
        bool is_model_int8 = false;
        //! whether all convolutions are float32
        bool is_model_float32 = true;
        float conv_cnt = 0;
        float dimshuffle_cnt = 0;
        auto detect_int8_model = [&](const VarNode* input) {
            if (input->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
                input->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm) {
                is_model_int8 = true;
                is_model_float32 = false;
            } else if (input->dtype().enumv() == megdnn::DTypeEnum::Float32) {
                is_model_float32 = (is_model_float32 && true);
            } else {
                is_model_float32 = false;
            }
        };
        cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
            if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
                if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv->input(0));
            } else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
                if (conv_bias->param().format !=
                    megdnn::param::ConvBias::Format::NCHW) {
                    is_model_nchw = false;
                }
                conv_cnt++;
                detect_int8_model(conv_bias->input(0));
            } else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
                LITE_MARK_USED_VAR(dimshuffle);
                dimshuffle_cnt++;
            }
        });
        for (auto&& i : m_load_result.output_var_list)
            dep.add(i);
        float ratio_dimshuffle_conv = 0;
        if (conv_cnt > 0) {
            ratio_dimshuffle_conv = dimshuffle_cnt / conv_cnt;
        }
        //! format optimization can only be applied to nchw models, and
        //! shufflenet-like models hurt performance when using the nchw88 or
        //! nchw44 format, so heuristically gate on the ratio of dimshuffle to
        //! convolution oprs
        if (!is_model_nchw || ratio_dimshuffle_conv > 0.15f) {
            return;
        }
        //! determine the layout by the device information
        //! TODO: shufflenet-like models using nchw88 or nchw44 will hurt the
        //! performance
        if (m_user_config->device_type == LITE_CPU) {
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
            cpuinfo_initialize();
            //! if all convolution and matmul data types are float32
            if (is_model_float32) {
                //! if the device is x86 and supports avx, use the nchw88 format
                if (cpuinfo_has_x86_avx()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw88();
                    LITE_LOG("Configure model inference with nchw88 format.");
                } else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) {
                    //! if x86 only supports sse2, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm, use the nchw44 format
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            } else if (is_model_int8) {
                //! if the data type of convolution is int8:
                //! if the device is arm and supports dot, use the nchw44-dot format
                if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) {
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot();
                    LITE_LOG("Configure model inference with nchw44-dot format.");
                } else if (cpuinfo_has_arm_neon()) {
                    //! if the device is arm and does not support dot, use nchw44
                    m_load_config.comp_graph->options().graph_opt.enable_nchw44();
                    LITE_LOG("Configure model inference with nchw44 format.");
                }
            }
#endif
        }
    }
}

void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!m_format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), m_format.val());
    }

    //! apply the user configuration to the mge model
    application_config();

    //! configure some flags obtained from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }

    m_load_result = m_loader->load(m_load_config, true);
    configure_after_loaded();
}

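//! common post-load steps shared by load_model() and shared_weight_with()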
void NetworkImplDft::configure_after_loaded() {
    modify_exection_policy();

    layout_transform_optimization();

    //! find out how many compnodes the model uses, this should be called before
    //! update_io
    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}

void NetworkImplDft::compile_graph() {
    replace_dev_input_pass();
    if (!m_user_config->discrete_input_name.empty()) {
        replace_src_discrete_input_opr_pass();
    }
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}

void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    if (m_load_config.comp_graph &&
        m_user_config->options.comp_node_seq_record_level == 2) {
        m_load_config.comp_graph.reset();
    }
    LITE_ASSERT(m_execute_func, "forward must be called after the network is loaded.");
    m_execute_func->execute();
}

void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(m_async_callback, "The callback function must be set in async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}

void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
    if (var.node()->capable_shape_infer()) {
        auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite failed to infer the output shape, maybe the model has "
                    "dynamic shape.\n");
            LITE_ASSERT(
                    !m_user_config->options.force_output_use_user_specified_memory,
                    "force_output_use_user_specified_memory can't be used when the "
                    "output shape can't be derived.");
            return;
        }
        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}

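//! create or update the lite input tensors from the loaded tensor map, forcing
//! host inputs on CPU and on cross-device models, and dropping configured IOs
//! that do not exist in the network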
void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! on cpu, all inputs and outputs are host tensors
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if it is a cross-compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if the tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the original compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not come from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                TensorHelper::implement(config_in.lite_tensor)
                        ->cast_final_safe<TensorImplDft>()
                        .m_record_reset =
                        m_user_config->options.comp_node_seq_record_level > 0;
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    if (!m_user_config->discrete_input_name.empty()) {
        update_input_lite_tensors();
    }
    //! delete the IOs that are not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not an input of the network, ignore it.", it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}

//! initialize lite_tensors when the input is composed of multiple discrete tensors
void NetworkImplDft::update_input_lite_tensors() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        if (in_tensor_iter.first != m_user_config->discrete_input_name) {
            continue;
        }
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                size_t bs = in_tensor_iter.second->shape(0);
                auto shape = in_tensor_iter.second->shape();
                if (config_in.config_layout.ndim) {
                    bs = config_in.config_layout.shapes[0];
                    for (size_t i = 0; i < config_in.config_layout.ndim; ++i) {
                        shape.shape[i] = config_in.config_layout.shapes[i];
                    }
                }
                shape.shape[0] = 1;
                for (size_t i = 0; i < bs; ++i) {
                    HostTensorND tensor(
                            in_tensor_iter.second->comp_node(), shape,
                            in_tensor_iter.second->dtype(),
                            in_tensor_iter.second->format());
                    if (config_in.is_host) {
                        config_in.lite_tensors.push_back(std::make_shared<Tensor>(
                                device_id, stream_id, device_type, true));
                        TensorHelper::implement(config_in.lite_tensors[i])
                                ->cast_final_safe<TensorImplDft>()
                                .m_host_tensor = std::make_shared<HostTensorND>(tensor);
                        config_in.lite_tensors[i]->update_from_implement();
                    } else {
                        config_in.lite_tensors.push_back(std::make_shared<Tensor>(
                                device_id, stream_id, device_type));
                        config_in.lite_tensors[i]->set_layout(
                                to_lite_layout(tensor.layout()));
                    }
                    TensorHelper::implement(config_in.lite_tensors[i])
                            ->cast_final_safe<TensorImplDft>()
                            .m_record_reset =
                            m_user_config->options.comp_node_seq_record_level > 0;
                }
            }
        }
    }
}

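//! create or update the lite output tensors: either only the outputs configured
//! by the user, or all the network outputs, inferring layouts where possible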
void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the outputs that are not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not an output of the network, ignore it.",
                     out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! the user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute_configured_output_only is set but no output is configured.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
            output_tensor_copy_optimize(var, out_it->lite_tensor);
            TensorHelper::implement(out_it->lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
    } else {
        //! the user did not set outputs, use the default network outputs
        for (auto&& out : m_load_result.output_var_list) {
            std::shared_ptr<Tensor> lite_tensor = nullptr;
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
                lite_tensor = it->lite_tensor;
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
                lite_tensor = output.lite_tensor;
            }
            output_tensor_copy_optimize(out, lite_tensor);
            TensorHelper::implement(lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
    }
}

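//! install callbacks on the output lite tensor so that, depending on the options,
//! the output var writes directly into user-specified memory, or the lite tensor
//! borrows the dynamically allocated dev tensor instead of copying it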
void NetworkImplDft::output_tensor_copy_optimize(
        Var var, std::shared_ptr<Tensor> tensor) {
    LITE_ASSERT(
            !(m_user_config->options.force_output_use_user_specified_memory &&
              m_user_config->options.force_output_dynamic_alloc),
            "Can't set force_output_use_user_specified_memory and "
            "force_output_dynamic_alloc at the same time.");
    if (m_user_config->options.force_output_use_user_specified_memory) {
        bool in_record = m_user_config->options.comp_node_seq_record_level > 0;
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) {
                    dft_tensor->device_share_host_memory();
                    auto dv = dft_tensor->dev_tensor().get();
                    dv->comp_node(var.node()->comp_node(), true);
                    var.node()->init_mem_plan(dv);
                    if (in_record) {
                        auto&& device_tensor = var.node()->mutable_dev_tensor();
                        device_tensor.only_reset_raw_storage(dv->storage());
                    } else {
                        var.node()->reset_dev_tensor_from_tensor(*dv);
                    }
                });
    }
    if (m_user_config->options.force_output_dynamic_alloc) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
                    if (dft_tensor->is_host()) {
                        auto host_tensor = dft_tensor->m_host_tensor;
                        *host_tensor =
                                HostTensorND::make_proxy(var.node()->dev_tensor());
                    } else {
                        auto dev_tensor = dft_tensor->m_dev_tensor;
                        *dev_tensor = var.node()->dev_tensor();
                    }
                });
    }
}

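//! look up an IO lite tensor by name; for the discrete input the per-sample
//! tensors must be fetched with get_discrete_tensors() instead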
std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                if (config_in.lite_tensor) {
                    return config_in.lite_tensor;
                } else {
                    LITE_THROW(mgb::ssprintf(
                            "the input tensor %s is in discrete mode, use "
                            "get_discrete_tensors to get this input.",
                            io_name.c_str()));
                    return nullptr;
                }
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "%s is not a valid tensor name: it must be an input tensor name, or a "
            "registered output tensor name when NetworkIO is set; when NetworkIO "
            "is not set, all the network output tensors are available.",
            io_name.c_str()));
    return nullptr;
}

std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_discrete_tensors(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name &&
                config_in.name == m_user_config->discrete_input_name) {
                return config_in.lite_tensors;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "%s is not the discrete input tensor name.", io_name.c_str()));
    return {};
}

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::vector<std::shared_ptr<Tensor>> NetworkImplDft::get_input_tensors(size_t index) {
    return get_discrete_tensors(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}

//! set the opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    if (static_cast<uint32_t>(dst_strategy) != 0)
        m_execution_policy = dst_strategy;

    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;

    if (m_execute_func) {
        LITE_WARN(
                "calling set_network_algo_policy after the network is loaded may "
                "cause errors!");
        modify_exection_policy();
    }
}

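//! apply the configured execution policy to all oprs the outputs depend on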
void NetworkImplDft::modify_exection_policy() {
    auto& vars = m_load_result.output_var_list;
    if (static_cast<uint32_t>(m_execution_policy) != 0) {
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
    }
}

//! set the opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get all output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get all input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of the graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is larger than the total number of outputs.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of the graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is larger than the total number of inputs.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}

//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disabled at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

void NetworkImplDft::enable_global_layout_transform() {
    m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;

    switch (m_user_config->device_type) {
        case LiteDeviceType::LITE_CPU:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
            break;
        default:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
            LITE_WARN(
                    "lite compnode type (enum value: %d) has no specialized target "
                    "for layout transform",
                    (int)(m_user_config->device_type));
    }
    m_set_layout_transform = true;
}

void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
    if (m_set_layout_transform) {
        auto out_file = mgb::serialization::OutputFile::make_fs(
                optimized_model_path.c_str(), 'w');
        using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
        DumpConfig config{1, false, false};
        auto dumper = mgb::serialization::GraphDumper::make(
                std::move(out_file), m_format.val());
        dumper->dump(m_load_result.output_var_list, config);
    } else {
        LITE_THROW(
                ssprintf("enable_global_layout_transform should be called before "
                         "dump_layout_transform_model"));
    }
}

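//! get the IO information of a model file without creating a Network: the file
//! is read into memory and forwarded to the in-memory overload below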
NetworkIO lite::get_model_io_info_dft(
        const std::string& model_path, const Config& config) {
    FILE* fin = fopen(model_path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    return get_model_io_info_dft(ptr, size, config);
}

NetworkIO lite::get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config) {
    std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
    auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
    auto format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
    if (!format.valid()) {
        LITE_THROW("invalid model format");
    }
    auto loader =
            mgb::serialization::GraphLoader::make(std::move(input_file), format.val());

    mgb::serialization::GraphLoadConfig load_config;
    load_config.comp_graph = mgb::ComputingGraph::make();
    if (config.has_compression) {
        load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }
    auto compnode_locator = to_compnode_locator(config.device_type);
    load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
        if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
            loc.type = compnode_locator.type;
        }
        loc.device = compnode_locator.device;
    };
    auto load_result = loader->load(load_config, true);

    NetworkIO IOs;
    for (auto&& in_tensor_iter : load_result.tensor_map) {
        IO in_io;
        in_io.name = in_tensor_iter.first;
        in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
        IOs.inputs.push_back(in_io);
    }
    auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
        auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
        using InferType = mgb::cg::static_infer::InferType;
        if (static_infer_mgr.get_infer_type(var.node()).shape &
            (InferType::CONST | InferType::RT_STATIC)) {
            return static_infer_mgr.infer_shape_fallible(var.node());
        } else {
            return nullptr;
        }
    };
    for (auto&& out : load_result.output_var_list) {
        IO out_io;
        out_io.name = out.node()->name();
        if (auto shape = infer_shape(out)) {
            out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
        } else {
            out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
        }
        IOs.outputs.push_back(out_io);
    }
    return IOs;
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}