/**
 * \file src/mge/network_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "lite/network.h"
#include "memory_allocator.h"
#include "network_impl.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/tensor.h"

#if MGB_OPENCL
#include "megcore_opencl.h"
#endif

#include <fstream>
#include <memory>
#include <set>

using namespace lite;
using namespace mgb;

LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);

void NetworkImplDft::set_config(const Config& config) {
    m_user_config = std::make_unique<Config>();
    *m_user_config = config;
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}
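
//! Share weights with a previously loaded network: reload the graph through
//! the source network's loader (so the weight tensors can be shared), then
//! redo cross-compnode detection, IO update and graph compilation for this
//! network.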
void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(
            src_impl.m_loader,
            "Clone network must be used after the source network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);
    //! flag whether the model is a cross-compnode model
    cross_compnode_model_detect();
    //! update the IO of the network
    update_io();
    //! replace the IO when there is device input or output
    compile_graph();
}

void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    m_compnode_locator.device = m_user_config->device_id;
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.device = m_user_config->device_id;
    }
    //! model options
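    //! ConfigOption copies a single field from the user-facing lite options
    //! into the corresponding MegEngine computing-graph option.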
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(
            force_output_use_user_specified_memory,
            force_output_use_user_specified_memory);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit is only supported on cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);
#undef ConfigOption

#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform

    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }

    //! if the device is LITE_DEVICE_DEFAULT, the compnode information stored
    //! in the model (or xpu in MegEngine) is used as-is
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
            if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
                loc.type = m_compnode_locator.type;
            }
            loc.device = m_compnode_locator.device;
            //! if the user set the thread number and the compnode is multithread
            if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
                m_nr_threads != 1) {
                loc.stream = m_nr_threads;
            } else {
                loc.stream = m_compnode_locator.stream;
            }
        };
    }
}
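
//! wrap the user-provided allocator and register it as the computing graph's
//! device memory allocator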
void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network, the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}

void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only available in CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only available in CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}
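
//! a thread number greater than one switches the compnode to MULTITHREAD;
//! with a single thread the default CPU compnode is kept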
void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi threads mode is only available in CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.nr_threads = nr_threads;
    }
}

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi threads mode is only available in CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}

void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(
            !m_is_cpu_inplace_mode, "cpu inplace mode does not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Now only cpu and cuda>10.0 support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}
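
//! Build the output spec consumed by graph_compile(): each requested output
//! var gets a callback that copies the produced device tensor into the lite
//! tensor (or only records its layout for LITE_IO_SHAPE). In async mode the
//! callback of the last finished output triggers finish() through a compnode
//! callback. When the output memory is user specified or dynamically
//! allocated, no callback is installed.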
void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            //! if writing to user-specified memory, the CallbackCaller must be nullptr.
            if (m_user_config->options.force_output_use_user_specified_memory ||
                m_user_config->options.force_output_dynamic_alloc) {
                m_output_spec.emplace_back(load_out, nullptr);
            } else {
                m_output_spec.emplace_back(load_out, std::move(cb));
            }
        } else {
            LITE_THROW(ssprintf("no output named : %s in the model", out.name.c_str()));
        }
    }
}

void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, so the user can directly change the dev tensor
    //! storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;

    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;

    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }

    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;

    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}
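
//! walk every operator feeding the outputs and record the distinct device
//! types in use; m_nr_device_type > 1 marks the model as a cross-compnode
//! model (see update_input)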
void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}

void NetworkImplDft::adapt_option_valid() {
    auto&& options = m_load_config.comp_graph->options();
    if (m_user_config->options.force_output_use_user_specified_memory) {
        for (auto&& out : m_load_result.output_var_list) {
            auto opr = out.node()->owner_opr();
            //! dest operators that inherit from ReadonlyFwdHelper can't
            //! support the force_output_use_user_specified_memory option
            if (opr->try_cast_final<mgb::opr::Reshape>() ||
                opr->try_cast_final<mgb::opr::Broadcast>() ||
                opr->try_cast_final<mgb::opr::Subtensor>() ||
                opr->try_cast_final<mgb::opr::AxisAddRemove>() ||
                opr->try_cast_final<mgb::opr::Dimshuffle>()) {
                m_user_config->options.force_output_use_user_specified_memory = false;
                options.force_output_use_user_specified_memory = false;
                LITE_WARN(
                        "detected unsupported dest operator %s when configuring "
                        "force_output_use_user_specified_memory, setting "
                        "force_output_use_user_specified_memory to false\n",
                        opr->cname());
                break;
            }
        }
    }
}

void NetworkImplDft::global_layout_transform() {
    if (m_set_layout_transform) {
        m_load_result.output_var_list = mgb::gopt::layout_transform(
                m_load_result.output_var_list, m_layout_transform_target);
    }
}
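
//! Load a model from memory: create the graph loader on first use, apply the
//! user configuration and the per-model flags parsed from the json config
//! (device_id, number_threads, enable_inplace_model, use_tensorrt), then run
//! layout transform, option adaptation and cross-compnode detection before
//! updating the IO and compiling the graph.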
void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        m_format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!m_format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), m_format.val());
    }

    //! apply the user configuration to the mge model
    application_config();

    //! configure some flags read from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }

    m_load_result = m_loader->load(m_load_config, false);

    global_layout_transform();
    adapt_option_valid();
    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}

void NetworkImplDft::compile_graph() {
    modify_exection_policy();
    replace_dev_input_pass();
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}
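
//! invoke the user start callback (if any) with a snapshot of the configured
//! input IO and their lite tensors before each forward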
void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    if (m_load_config.comp_graph &&
        m_user_config->options.comp_node_seq_record_level == 2) {
        m_load_config.comp_graph.reset();
    }
    LITE_ASSERT(m_execute_func, "forward must be called after the network is loaded.");
    m_execute_func->execute();
}
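
//! in synchronous mode block on the executed function before calling
//! finish(); in async mode finish() is called without waiting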
void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(m_async_callback, "The async callback must be set in async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}

void NetworkImplDft::set_io(const NetworkIO& network_io) {
    m_network_io = std::make_unique<NetworkIOInner>();
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}
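
//! try to derive the output layout through the static infer manager before
//! execution; if the shape cannot be inferred (e.g. dynamic-shape models)
//! only warn, unless the output memory must be user specified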
void NetworkImplDft::try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var) {
    if (var.node()->capable_shape_infer()) {
        auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite infer output shape failed, maybe the model has dynamic "
                    "shape.\n");
            LITE_ASSERT(
                    !m_user_config->options.force_output_use_user_specified_memory,
                    "force_output_use_user_specified_memory can't be used when the "
                    "output shape can't be derived.");
            return;
        }
        Layout layout = to_lite_layout(TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}
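
//! Create the lite tensors for every model input: host inputs share the
//! loader's host tensor directly, device inputs only get their layout set;
//! inputs configured by the user but absent from the model are dropped.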
void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! if the device is CPU, all inputs and outputs are host tensors
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if it is a cross-compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if the tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the origin compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not come from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                TensorHelper::implement(config_in.lite_tensor)
                        ->cast_final_safe<TensorImplDft>()
                        .m_record_reset =
                        m_user_config->options.comp_node_seq_record_level > 0;
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    //! delete the IO that is not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not the network input, ignore it.", it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}
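
//! Create the lite tensors for the outputs: outputs configured by the user
//! but absent from the model are dropped; if only the configured outputs are
//! computed, tensors are built just for those, otherwise every model output
//! gets a (host by default) lite tensor.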
void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the outputs that are not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(), [out_it](const SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not the network output, ignore it.", out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! the user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute configured output only is set, but no output is configured.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
            output_tensor_copy_optimize(var, out_it->lite_tensor);
            TensorHelper::implement(out_it->lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
        //! the user did not set outputs, use the default outputs
    } else {
        for (auto&& out : m_load_result.output_var_list) {
            std::shared_ptr<Tensor> lite_tensor = nullptr;
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
                lite_tensor = it->lite_tensor;
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
                lite_tensor = output.lite_tensor;
            }
            output_tensor_copy_optimize(out, lite_tensor);
            TensorHelper::implement(lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_record_reset =
                    m_user_config->options.comp_node_seq_record_level > 0;
        }
    }
}
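
//! Avoid an extra output copy where possible: with
//! force_output_use_user_specified_memory the lite tensor's reset callback
//! hands the user buffer directly to the graph output var; with
//! force_output_dynamic_alloc the lite tensor's get-memory callback simply
//! proxies the var's dev tensor.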
void NetworkImplDft::output_tensor_copy_optimize(
        Var var, std::shared_ptr<Tensor> tensor) {
    LITE_ASSERT(
            !(m_user_config->options.force_output_use_user_specified_memory &&
              m_user_config->options.force_output_dynamic_alloc),
            "Can't set force_output_use_user_specified_memory and "
            "force_output_dynamic_alloc at the same time.");
    if (m_user_config->options.force_output_use_user_specified_memory) {
        bool in_record = m_user_config->options.comp_node_seq_record_level > 0;
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_reset_callback([var, in_record](TensorImplDft* dft_tensor) {
                    dft_tensor->device_share_host_memory();
                    auto dv = dft_tensor->dev_tensor().get();
                    dv->comp_node(var.node()->comp_node(), true);
                    var.node()->init_mem_plan(dv);
                    if (in_record) {
                        auto&& device_tensor = var.node()->mutable_dev_tensor();
                        device_tensor.only_reset_raw_storage(dv->storage());
                    } else {
                        var.node()->reset_dev_tensor_from_tensor(*dv);
                    }
                });
    }
    if (m_user_config->options.force_output_dynamic_alloc) {
        TensorHelper::implement(tensor)
                ->cast_final_safe<TensorImplDft>()
                .set_get_memory_callback([var](TensorImplDft* dft_tensor) {
                    if (dft_tensor->is_host()) {
                        auto host_tensor = dft_tensor->m_host_tensor;
                        *host_tensor =
                                HostTensorND::make_proxy(var.node()->dev_tensor());
                    } else {
                        auto dev_tensor = dft_tensor->m_dev_tensor;
                        *dev_tensor = var.node()->dev_tensor();
                    }
                });
    }
}

std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                return config_in.lite_tensor;
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "tensor named %s must be an input tensor name or a registered "
            "output tensor name when NetworkIO is set; when NetworkIO is not set, "
            "every network output tensor is available, otherwise only the "
            "registered output tensors are.",
            io_name.c_str()));
    return nullptr;
}

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}

//! set opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    m_execution_policy = dst_strategy;

    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;

    if (m_execute_func) {
        LITE_WARN(
                "set_network_algo_policy may cause errors when called after the "
                "network is loaded!");
        modify_exection_policy();
    }
}
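
//! apply the stored execution policy in place to all operators feeding the
//! current output spec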
void NetworkImplDft::modify_exection_policy() {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    if (static_cast<uint32_t>(m_execution_policy) != 0)
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
}

//! set the opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get the output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get the input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of the graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is larger than the total number of outputs.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of the graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is larger than the total number of inputs.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}

//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disabled at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

void NetworkImplDft::enable_global_layout_transform() {
    m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
    switch (m_user_config->device_type) {
        case LiteDeviceType::LITE_CPU:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
            break;
        default:
            m_layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
            LITE_WARN(
                    "lite compnode type (enum value %d) has no specialized layout "
                    "transform target, falling back to UNSPEC",
                    (int)(m_user_config->device_type));
    }
    m_set_layout_transform = true;
}

void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_path) {
    if (m_set_layout_transform) {
        auto out_file = mgb::serialization::OutputFile::make_fs(
                optimized_model_path.c_str(), 'w');
        using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
        DumpConfig config{1, false, false};
        auto dumper = mgb::serialization::GraphDumper::make(
                std::move(out_file), m_format.val());
        dumper->dump(m_load_result.output_var_list, config);
    } else {
        LITE_THROW(ssprintf(
                "enable_global_layout_transform should be called before "
                "dump_layout_transform_model"));
    }
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}