
network_impl.cpp 41 kB

/**
 * \file src/mge/network_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "lite/network.h"
#include "memory_allocator.h"
#include "network_impl.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"

#if MGB_OPENCL
#include "megcore_opencl.h"
#endif

#include <fstream>
#include <memory>
#include <set>

using namespace lite;
using namespace mgb;

LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);

void NetworkImplDft::set_config(const Config& config) {
    m_user_config = std::make_unique<Config>();
    *m_user_config = config;
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}

void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(
            src_impl.m_loader,
            "Clone network must be called after the src network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);
    //! flag whether the model is a cross-compnode model
    cross_compnode_model_detect();
    //! update the IO of the network
    update_io();
    //! replace the IO when there is device input or output
    compile_graph();
}

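//! apply the user Config to the underlying MegEngine load config: map the
//! lite device/thread options onto the compnode locator, copy the graph
//! options (weight preprocess, record level, jit, ...) and install the
//! comp_node_mapper used while loading the model.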
void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    m_compnode_locator.device = m_user_config->device_id;
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.device = m_user_config->device_id;
    }
    //! model options
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(
            force_output_use_user_specified_memory,
            force_output_use_user_specified_memory);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit is only supported on cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);
#undef ConfigOption

#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform
    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }
    //! if the device is LITE_DEVICE_DEFAULT, the compnode information stored
    //! in the model (xpu in MegEngine) is used, so no mapper is installed
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
            if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
                loc.type = m_compnode_locator.type;
            }
            loc.device = m_compnode_locator.device;
            //! if the user set the thread number and the compnode is multithread
            if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
                m_nr_threads != 1) {
                loc.stream = m_nr_threads;
            } else {
                loc.stream = m_compnode_locator.stream;
            }
        };
    }
}

void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network, the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}

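//! cpu inplace mode: route the compnode to the default (inplace) CPU device
//! (DEVICE_CPU_DEFAULT / DEVICE_MULTITHREAD_DEFAULT below), so the kernels
//! run in the caller's thread instead of a dedicated dispatcher thread.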
void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only available on CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only available on CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}

void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multithread mode is only available on CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.nr_threads = nr_threads;
    }
}

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multithread mode is only available on CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}

void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(!m_is_cpu_inplace_mode, "cpu inplace mode does not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Now only cpu and cuda>10.0 support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}

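//! build the mgb output spec: for every configured output, register a
//! callback that copies (or, for LITE_IO_SHAPE, only records the shape of)
//! the computed DeviceTensorND into the corresponding lite tensor; in async
//! mode the callback of the last un-synced output schedules finish() on its
//! compnode.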
void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            //! if writing to user-specified memory, the CallbackCaller must be nullptr.
            if (m_user_config->options.force_output_use_user_specified_memory ||
                m_user_config->options.force_output_dynamic_alloc) {
                m_output_spec.emplace_back(load_out, nullptr);
            } else {
                m_output_spec.emplace_back(load_out, std::move(cb));
            }
        } else {
            LITE_THROW(ssprintf("no output named: %s in the model", out.name.c_str()));
        }
    }
}

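//! for device IO, rewrite the loaded graph so that Host2DeviceCopy inputs are
//! fed by VolatileSharedDeviceTensor oprs backed by the user-visible device
//! tensors in m_network_io, then patch every output var reference to point at
//! the rewritten graph.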
void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, so the user can directly change the dev tensor
    //! storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;
    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;
    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;
    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}

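//! walk all oprs feeding the outputs and count how many distinct lite device
//! types their outputs live on; m_nr_device_type > 1 marks a cross-compnode
//! model.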
void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}

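//! downgrade options that the loaded graph cannot honour; currently this
//! disables force_output_use_user_specified_memory when an output is produced
//! by an opr that forwards its input memory read-only (Reshape, Subtensor,
//! ...).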
void NetworkImplDft::adapt_option_valid() {
    auto&& options = m_load_config.comp_graph->options();
    if (m_user_config->options.force_output_use_user_specified_memory) {
        for (auto&& out : m_load_result.output_var_list) {
            auto opr = out.node()->owner_opr();
            //! all dest operators that inherit from ReadonlyFwdHelper can't
            //! support the force_output_use_user_specified_memory option
            if (opr->try_cast_final<mgb::opr::Reshape>() ||
                opr->try_cast_final<mgb::opr::Broadcast>() ||
                opr->try_cast_final<mgb::opr::Subtensor>() ||
                opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
                opr->try_cast_final<mgb::opr::Dimshuffle>()) {
                m_user_config->options.force_output_use_user_specified_memory = false;
                options.force_output_use_user_specified_memory = false;
                LITE_WARN(
                        "detected an unsupported dest operator %s when config "
                        "force_output_use_user_specified_memory, set "
                        "force_output_use_user_specified_memory to false\n",
                        opr->cname());
                break;
            }
        }
    }
}

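//! run mgb::gopt::layout_transform on the output vars when the user enabled
//! global layout transform, and remap every output var bookkeeping structure
//! to the transformed vars.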
void NetworkImplDft::global_layout_transform() {
    if (m_set_layout_transform) {
        mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
        auto output_var_array = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
        // replace symvar in output_var_list
        for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
            out_var_map[m_load_result.output_var_list[idx]] = output_var_array[idx];
            m_load_result.output_var_list[idx] = output_var_array[idx];
        }
        // replace symvar in output_var_map_id
        for (auto&& item : m_load_result.output_var_map_id) {
            item.second = out_var_map[item.second];
        }
        // replace symvar in output_var_map
        for (auto&& item : m_load_result.output_var_map) {
            item.second = out_var_map[item.second];
        }
    }
}

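//! load a model from memory: create the GraphLoader (once), apply the user
//! configuration, honour the per-model json options (device_id,
//! number_threads, enable_inplace_model, use_tensorrt), then run the
//! post-load passes in order: modify_exection_policy ->
//! global_layout_transform -> adapt_option_valid ->
//! cross_compnode_model_detect -> update_io -> compile_graph.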
void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!m_format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), m_format.val());
    }
    //! apply the user configuration to the mge model
    application_config();
    //! configure some flags taken from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }
    m_load_result = m_loader->load(m_load_config, true);
    modify_exection_policy();
    global_layout_transform();
    adapt_option_valid();
    cross_compnode_model_detect();
    //! update the IO of the network
    update_io();
    //! replace the IO when there is device input or output
    compile_graph();
}

void NetworkImplDft::compile_graph() {
    replace_dev_input_pass();
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}

void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    if (m_load_config.comp_graph &&
        m_user_config->options.comp_node_seq_record_level == 2) {
        m_load_config.comp_graph.reset();
    }
    LITE_ASSERT(m_execute_func, "forward must be called after the network is loaded.");
    m_execute_func->execute();
}

void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(m_async_callback, "The callback func must be set in async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
    m_network_io = std::make_unique<NetworkIOInner>();
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}

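//! try to derive the output layout through the static infer manager; if the
//! shape cannot be statically inferred (dynamic-shape model), leave the
//! layout untouched and forbid force_output_use_user_specified_memory.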
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
    if (var.node()->capable_shape_infer()) {
        auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite infer output shape failed, maybe the model is "
                    "dynamic shape.\n");
            LITE_ASSERT(
                    !m_user_config->options.force_output_use_user_specified_memory,
                    "force_output_use_user_specified_memory can't be used when output "
                    "shape can't be derived.");
            return;
        }
        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}

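//! sync m_network_io->inputs with the tensors actually present in the loaded
//! model: force host inputs on CPU and on cross-device models whose tensor
//! lives on CPU, create the lite tensors wrapping the loaded host/device
//! tensors, and drop configured inputs that do not exist in the model.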
void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! on CPU, all inputs and outputs are host tensors
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if it is a cross-compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if the tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the origin compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not come from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                TensorHelper::implement(config_in.lite_tensor)
                        ->cast_final_safe<TensorImplDft>()
                        .m_record_reset =
                        m_user_config->options.comp_node_seq_record_level > 0;
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    //! delete the IO that is not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not the network input, ignore it.", it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}

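//! sync m_network_io->outputs with the loaded output vars: drop configured
//! outputs not present in the model, then either keep only the configured
//! outputs (m_compute_configured_output_only) or expose every model output,
//! creating the lite tensors and inferring their layouts where possible.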
void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the output that is not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not the network output, ignore it.", out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! the user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute configured output only with no configured output.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
            output_tensor_copy_optimize(var, out_it->lite_tensor);
            TensorHelper::implement(out_it->lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
        //! the user did not set outputs, use the default outputs
    } else {
        for (auto&& out : m_load_result.output_var_list) {
            std::shared_ptr<Tensor> lite_tensor = nullptr;
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
                lite_tensor = it->lite_tensor;
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
                lite_tensor = output.lite_tensor;
            }
            output_tensor_copy_optimize(out, lite_tensor);
            TensorHelper::implement(lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
    }
}

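//! wire the zero-copy output paths: with force_output_use_user_specified_memory
//! the user's buffer is handed to the var's memory plan on every tensor reset,
//! and with force_output_dynamic_alloc the lite tensor lazily proxies the
//! var's dev_tensor instead of copying it.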
void NetworkImplDft::output_tensor_copy_optimize(
        Var var, std::shared_ptr<Tensor> tensor) {
    LITE_ASSERT(
            !(m_user_config->options.force_output_use_user_specified_memory &&
              m_user_config->options.force_output_dynamic_alloc),
            "Can't set force_output_use_user_specified_memory and "
            "force_output_dynamic_alloc at the same time.");
    if (m_user_config->options.force_output_use_user_specified_memory) {
        bool in_record = m_user_config->options.comp_node_seq_record_level > 0;
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) {
                    dft_tensor->device_share_host_memory();
                    auto dv = dft_tensor->dev_tensor().get();
                    dv->comp_node(var.node()->comp_node(), true);
                    var.node()->init_mem_plan(dv);
                    if (in_record) {
                        auto&& device_tensor = var.node()->mutable_dev_tensor();
                        device_tensor.only_reset_raw_storage(dv->storage());
                    } else {
                        var.node()->reset_dev_tensor_from_tensor(*dv);
                    }
                });
    }
    if (m_user_config->options.force_output_dynamic_alloc) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
                    if (dft_tensor->is_host()) {
                        auto host_tensor = dft_tensor->m_host_tensor;
                        *host_tensor =
                                HostTensorND::make_proxy(var.node()->dev_tensor());
                    } else {
                        auto dev_tensor = dft_tensor->m_dev_tensor;
                        *dev_tensor = var.node()->dev_tensor();
                    }
                });
    }
}

std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                return config_in.lite_tensor;
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "can not find io tensor named %s: the name must be an input tensor "
            "name or a registered output tensor name; if NetworkIO is not set, "
            "all the network output tensors are available, otherwise only the "
            "registered output tensors are.",
            io_name.c_str()));
    return nullptr;
}

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}

//! set opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    if (static_cast<uint32_t>(dst_strategy) != 0)
        m_execution_policy = dst_strategy;
    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;
    if (m_execute_func) {
        LITE_WARN(
                "set_network_algo_policy may cause errors when called after the "
                "network is loaded!");
        modify_exection_policy();
    }
}

void NetworkImplDft::modify_exection_policy() {
    auto& vars = m_load_result.output_var_list;
    if (static_cast<uint32_t>(m_execution_policy) != 0) {
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
    }
}

//! set the opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get all the output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get all the input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of the graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is larger than the total number of outputs.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of the graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is larger than the total number of inputs.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}

//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disabled at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

void NetworkImplDft::enable_global_layout_transform() {
    m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
    switch (m_user_config->device_type) {
        case LiteDeviceType::LITE_CPU:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
            break;
        default:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
            LITE_WARN(
                    "lite compnode type (enum value %d) has no specialized target "
                    "for layout transform",
                    (int)(m_user_config->device_type));
    }
    m_set_layout_transform = true;
}

void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
    if (m_set_layout_transform) {
        auto out_file = mgb::serialization::OutputFile::make_fs(
                optimized_model_path.c_str(), 'w');
        using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
        DumpConfig config{1, false, false};
        auto dumper = mgb::serialization::GraphDumper::make(
                std::move(out_file), m_format.val());
        dumper->dump(m_load_result.output_var_list, config);
    } else {
        LITE_THROW(ssprintf(
                "dump_layout_transform_model should be called after "
                "enable_global_layout_transform"));
    }
}

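//! read the whole model file into memory and forward to the in-memory
//! overload below to collect the model's input/output names and layouts
//! without building a full Network.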
NetworkIO lite::get_model_io_info_dft(
        const std::string& model_path, const Config& config) {
    FILE* fin = fopen(model_path.c_str(), "rb");
    LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
    fseek(fin, 0, SEEK_END);
    size_t size = ftell(fin);
    fseek(fin, 0, SEEK_SET);
    void* ptr = malloc(size);
    std::shared_ptr<void> buf{ptr, ::free};
    auto nr = fread(buf.get(), 1, size, fin);
    LITE_ASSERT(nr == size);
    fclose(fin);
    return get_model_io_info_dft(ptr, size, config);
}

NetworkIO lite::get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config) {
    std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
    auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
    auto format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
    if (!format.valid()) {
        LITE_THROW("invalid model format");
    }
    auto loader =
            mgb::serialization::GraphLoader::make(std::move(input_file), format.val());

    mgb::serialization::GraphLoadConfig load_config;
    load_config.comp_graph = mgb::ComputingGraph::make();
    if (config.has_compression) {
        load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }
    auto compnode_locator = to_compnode_locator(config.device_type);
    load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
        if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
            loc.type = compnode_locator.type;
        }
        loc.device = compnode_locator.device;
    };
    auto load_result = loader->load(load_config, true);

    NetworkIO IOs;
    for (auto&& in_tensor_iter : load_result.tensor_map) {
        IO in_io;
        in_io.name = in_tensor_iter.first;
        in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
        IOs.inputs.push_back(in_io);
    }
    auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
        auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
        using InferType = mgb::cg::static_infer::InferType;
        if (static_infer_mgr.get_infer_type(var.node()).shape &
            (InferType::CONST | InferType::RT_STATIC)) {
            return static_infer_mgr.infer_shape_fallible(var.node());
        } else {
            return nullptr;
        }
    };
    for (auto&& out : load_result.output_var_list) {
        IO out_io;
        out_io.name = out.node()->name();
        if (auto shape = infer_shape(out)) {
            out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
        } else {
            out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
        }
        IOs.outputs.push_back(out_io);
    }
    return IOs;
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}