
network_impl.cpp

/**
 * \file src/mge/network_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "lite_build_config.h"

#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "lite/network.h"
#include "memory_allocator.h"
#include "network_impl.h"
#include "parse_info/parse_info_base.h"
#include "parse_model/model_parser.h"

#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/io.h"
#include "megbrain/tensor.h"

#if MGB_OPENCL
#include "megcore_opencl.h"
#endif

#include <fstream>
#include <memory>
#include <set>

using namespace lite;
using namespace mgb;

LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
void NetworkImplDft::set_config(const Config& config) {
    m_user_config = std::make_unique<Config>();
    *m_user_config = config;
    m_load_config.comp_graph = mgb::ComputingGraph::make();
    m_compnode_locator = to_compnode_locator(m_user_config->device_type);
    m_compnode_locator.device = config.device_id;
}

void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
    application_config();
    const auto& src_impl = src_network->cast_final_safe<NetworkImplDft>();
    LITE_ASSERT(
            src_impl.m_loader,
            "Clone network must be done after the network is loaded.");
    m_load_result = src_impl.m_loader->load(m_load_config, true);

    //! flag whether the model is a cross-compnode model
    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}
void NetworkImplDft::application_config() {
    auto device_type = m_user_config->device_type;
    m_compnode_locator.type = to_compnode_locator(device_type).type;
    m_compnode_locator.device = m_user_config->device_id;
    if (m_nr_threads > 1 && device_type == LiteDeviceType::LITE_CPU) {
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.device = m_user_config->device_id;
    }
    //! model options
#define ConfigOption(mge_name, lite_name) \
    options.mge_name = m_user_config->options.lite_name;

    auto&& options = m_load_config.comp_graph->options();
    ConfigOption(graph_opt.weight_preprocess, weight_preprocess);
    ConfigOption(graph_opt.fuse_preprocess, fuse_preprocess);
    ConfigOption(fake_next_exec, fake_next_exec);
    ConfigOption(var_sanity_check_first_run, var_sanity_check_first_run);
    m_load_config.const_var_shape = m_user_config->options.const_shape;
    ConfigOption(force_dynamic_alloc, force_dynamic_alloc);
    ConfigOption(force_output_dynamic_alloc, force_output_dynamic_alloc);
    ConfigOption(no_profiling_on_shape_change, no_profiling_on_shape_change);
    LITE_ASSERT(
            m_user_config->options.jit_level == 0 ||
                    (m_user_config->options.jit_level > 0 &&
                     device_type == LiteDeviceType::LITE_CUDA),
            "jit is only supported on cuda device.");
    ConfigOption(graph_opt.jit, jit_level);
    ConfigOption(comp_node_seq_record_level, comp_node_seq_record_level);
    ConfigOption(graph_opt_level, graph_opt_level);
    ConfigOption(async_exec_level, async_exec_level);
#undef ConfigOption

#define ConfigOptionLayoutTransform(name) \
    if (m_user_config->options.name) {    \
        options.graph_opt.name();         \
    }
    ConfigOptionLayoutTransform(enable_nchw44);
    ConfigOptionLayoutTransform(enable_nchw44_dot);
    ConfigOptionLayoutTransform(enable_nchw88);
    ConfigOptionLayoutTransform(enable_nhwcd4);
    ConfigOptionLayoutTransform(enable_nchw4);
    ConfigOptionLayoutTransform(enable_nchw32);
    ConfigOptionLayoutTransform(enable_nchw64);
#undef ConfigOptionLayoutTransform

    if (m_user_config->has_compression) {
        m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
    }

    //! if device is LITE_NONE, the compnode information is stored in the model
    if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
        //! currently the Locator type is not set, because an atlas mgb model is a
        //! cross-compnode graph
        if (device_type == LiteDeviceType::LITE_ATLAS) {
            m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
                if (loc.type == mgb::CompNode::DeviceType::ATLAS) {
                    loc.device = m_compnode_locator.device;
                    loc.stream = m_compnode_locator.stream;
                } else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
                    loc.stream = m_nr_threads;
                }
            };
        } else {
            m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
                loc = m_compnode_locator;
            };
        }
    }
}
void NetworkImplDft::set_memory_allocator(std::shared_ptr<Allocator> user_allocator) {
    auto allocator = std::make_shared<UserStaticMemAlloc>(user_allocator);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->set_device_memory_allocator(allocator);
}

//! share the runtime memory with another network, the weights are not shared
void NetworkImplDft::share_runtime_memory_with(Network::NetworkImplBase* network_impl) {
    LITE_ASSERT(network_impl);
    LITE_ASSERT(m_load_config.comp_graph);
    m_load_config.comp_graph->share_device_memory_with(*(
            network_impl->cast_final_safe<NetworkImplDft>().m_load_config.comp_graph));
}
void NetworkImplDft::set_cpu_inplace_mode() {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "cpu inplace mode is only available on CPU.");
    m_is_cpu_inplace_mode = true;
    if (m_compnode_locator.type == mgb::CompNode::DeviceType::CPU) {
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
    } else {
        LITE_ASSERT(
                m_compnode_locator.type == CompNode::DeviceType::MULTITHREAD,
                "cpu inplace mode is only available on CPU.");
        m_compnode_locator.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
    }
}

void NetworkImplDft::set_cpu_threads_number(size_t nr_threads) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    if (nr_threads > 1) {
        m_nr_threads = nr_threads;
        m_compnode_locator.type = mgb::CompNode::DeviceType::MULTITHREAD;
        m_compnode_locator.nr_threads = nr_threads;
    }
}

void NetworkImplDft::set_runtime_thread_affinity(
        const ThreadAffinityCallback& thread_affinity_callback) {
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU,
            "multi-thread mode is only available on CPU.");
    mgb::CompNode::Locator loc;
    m_load_config.comp_node_mapper(loc);
    auto cn = mgb::CompNode::load(loc);
    if (m_nr_threads > 1) {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().set_affinity(
                thread_affinity_callback);
    } else {
        mgb::CompNodeEnv::from_comp_node(cn).cpu_env().dispatch(
                [thread_affinity_callback](void) { thread_affinity_callback(0); });
    }
}
void NetworkImplDft::set_device_id(int device_id) {
    m_compnode_locator.device = device_id;
    m_user_config->device_id = device_id;
}

void NetworkImplDft::set_stream_id(int stream_id) {
    m_compnode_locator.stream = stream_id;
}

void NetworkImplDft::use_tensorrt() {
    auto&& options = m_load_config.comp_graph->options();
    options.graph_opt.tensorrt = true;
}

//! set the callback used in async mode
void NetworkImplDft::set_async_callback(const AsyncCallback& callback) {
    LITE_ASSERT(
            !m_is_cpu_inplace_mode, "cpu inplace mode does not support async mode");
    LITE_ASSERT(
            m_user_config->device_type == LiteDeviceType::LITE_CPU ||
                    m_user_config->device_type == LiteDeviceType::LITE_CUDA,
            "Now only cpu and cuda>10.0 support async mode");
    m_async = true;
    m_async_callback = std::move(callback);
}
void NetworkImplDft::make_output_spec() {
    m_output_spec.clear();
    for (auto&& out : m_network_io->outputs) {
        if (m_load_result.output_var_map.count(out.name)) {
            auto&& load_out = m_load_result.output_var_map[out.name];
            auto cb = [&out, this](const mgb::DeviceTensorND& dv) mutable {
                mgb::CompNode comp_node = dv.comp_node();
                if (out.io_type == LiteIOType::LITE_IO_SHAPE) {
                    auto mgb_layout = dv.layout();
                    out.lite_tensor->set_layout(to_lite_layout(mgb_layout));
                } else {
                    TensorHelper::implement(out.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .copy_from_mge_tensor(dv);
                    out.lite_tensor->update_from_implement();
                }
                if (m_async) {
                    out.have_sync = true;
                    bool need_exec_cb = true;
                    for (auto&& j : m_network_io->outputs) {
                        if (!j.have_sync) {
                            need_exec_cb = false;
                        }
                    }
                    if (need_exec_cb) {
                        for (auto&& j : m_network_io->outputs) {
                            j.have_sync = false;
                        }
                        comp_node.add_callback([this]() { finish(); });
                    }
                }
            };
            m_output_spec.emplace_back(load_out, std::move(cb));
        } else {
            LITE_THROW(ssprintf("no output named %s in the model", out.name.c_str()));
        }
    }
}
void NetworkImplDft::replace_dev_input_pass() {
    mgb::CompNode::Locator locator;
    m_load_config.comp_node_mapper(locator);
    //! CPU does not need to use device input
    if (locator.type == mgb::CompNode::DeviceType::CPU) {
        return;
    }
    //! replace the H2D with VolatileSharedDeviceTensor, and keep the dev tensor
    //! in m_network_io.input, user can directly change the dev tensor
    //! storage through m_network_io.input.lite_tensor->reset() before forward
    using DeviceTensorMap =
            std::unordered_map<std::string, std::shared_ptr<mgb::DeviceTensorND>>;
    DeviceTensorMap name2dev_tensor;

    mgb::ThinHashMap<mgb::HostTensorND*, mgb::SymbolVar> host_val2var;

    //! construct host_val2var that maps from host tensor to corresponding var
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        if (opr->same_type<mgb::opr::Host2DeviceCopy>()) {
            mgb::HostTensorND* tensor =
                    opr->cast_final<mgb::opr::Host2DeviceCopy>().host_data().get();
            host_val2var[tensor] = opr->output(0);
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }

    mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> inp_var_map, out_var_map;

    mgb::SmallVector<std::string> to_clear;
    for (auto&& config_in : m_network_io->inputs) {
        if (!config_in.is_host) {
            auto host_val = m_load_result.tensor_map[config_in.name];
            auto dev_val = TensorHelper::implement(config_in.lite_tensor)
                                   ->cast_final_safe<TensorImplDft>()
                                   .m_dev_tensor;
            auto dev_var = mgb::opr::VolatileSharedDeviceTensor::make(
                    *m_load_result.graph, dev_val, {config_in.name});
            inp_var_map[host_val2var.at(host_val.get())] = dev_var;
            name2dev_tensor[config_in.name] = dev_val;
        }
    }
    auto new_ovar = mgb::cg::replace_vars(m_load_result.output_var_list, inp_var_map);
    for (size_t i = 0; i < new_ovar.size(); ++i) {
        out_var_map[m_load_result.output_var_list[i]] = new_ovar[i];
    }
    for (auto&& i : m_load_result.output_var_map) {
        i.second = out_var_map.at(i.second);
    }
    for (auto&& i : m_load_result.output_var_map_id) {
        i.second = out_var_map.at(i.second);
    }
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        new_ovar[i].rename(m_load_result.output_var_list[i].node()->name());
    }
    m_load_result.output_var_list = std::move(new_ovar);
}
void NetworkImplDft::cross_compnode_model_detect() {
    mgb::ThinHashSet<LiteDeviceType> nr_used_device_type;
    auto on_opr = [&](mgb::cg::OperatorNodeBase* opr) {
        for (auto j : opr->output()) {
            if (j->comp_node() != mgb::CompNode::default_cpu()) {
                nr_used_device_type.insert(
                        get_device_from_locator(j->comp_node().locator()));
            }
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (auto i : m_load_result.output_var_list) {
        dep_iter.add(i.node()->owner_opr());
    }
    m_nr_device_type = nr_used_device_type.size();
}
void NetworkImplDft::load_model(
        std::shared_ptr<void> model_mem, size_t size,
        std::unordered_map<std::string, LiteAny> separate_config_map) {
    if (!m_loader) {
        m_input_file =
                mgb::serialization::InputFile::make_mem_proxy(model_mem, size, false);
        auto format = mgb::serialization::GraphLoader::identify_graph_dump_format(
                *m_input_file);
        if (!format.valid()) {
            LITE_THROW("invalid model format");
        }
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(m_input_file), format.val());
    }

    //! apply the user configuration to the mge model
    application_config();

    //! configure some flags obtained from the json config file
    if (separate_config_map.find("device_id") != separate_config_map.end()) {
        set_device_id(separate_config_map["device_id"].safe_cast<int>());
    }
    if (separate_config_map.find("number_threads") != separate_config_map.end() &&
        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
        set_cpu_threads_number(
                separate_config_map["number_threads"].safe_cast<uint32_t>());
    }
    if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
        set_cpu_inplace_mode();
    }
    if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
        use_tensorrt();
    }

    m_load_result = m_loader->load(m_load_config, true);

    cross_compnode_model_detect();

    //! update the IO of the network
    update_io();

    //! replace the IO when there is device input or output
    compile_graph();
}
void NetworkImplDft::compile_graph() {
    modify_exection_policy();
    replace_dev_input_pass();
    make_output_spec();
    m_execute_func = m_load_result.graph_compile(m_output_spec);
}
void NetworkImplDft::start() const {
    if (m_start_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                input_io_map;
        for (auto&& io_inner : m_network_io->inputs) {
            input_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_start_callback(input_io_map);
    }
}

void NetworkImplDft::forward() {
    start();
    LITE_ASSERT(m_execute_func, "forward must be called after the network is loaded.");
    m_execute_func->execute();
}

void NetworkImplDft::wait() {
    if (!m_async) {
        m_execute_func->wait();
    }
    finish();
}

void NetworkImplDft::finish() const {
    if (m_async) {
        LITE_ASSERT(
                m_async_callback, "The callback function must be set in async mode.");
        m_async_callback();
    }
    if (m_finish_callback) {
        std::unordered_map<std::string, std::pair<IO, std::shared_ptr<Tensor>>>
                output_io_map;
        for (auto&& io_inner : m_network_io->outputs) {
            output_io_map[io_inner.name] = {
                    IO{io_inner.name, io_inner.is_host, io_inner.io_type,
                       io_inner.config_layout},
                    io_inner.lite_tensor};
        }
        m_finish_callback(output_io_map);
    }
    output_plugin_result();
}
void NetworkImplDft::set_io(const NetworkIO& network_io) {
    m_network_io = std::make_unique<NetworkIOInner>();
    for (auto&& in : network_io.inputs) {
        m_network_io->inputs.emplace_back(in);
    }
    for (auto&& out : network_io.outputs) {
        m_network_io->outputs.emplace_back(out);
    }
}

void NetworkImplDft::try_infer_tensor_layout(
        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
    auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
    auto infer_trait = var.node()->get_static_infer_trait();
    if (std::get<0>(infer_trait)) {
        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
        if (!shape) {
            LITE_WARN(
                    "Lite infer output shape failed, maybe the model has dynamic "
                    "shape.\n");
            return;
        }
        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
        tensor->set_layout(layout);
    }
}

void NetworkImplDft::update_io() {
    update_input();
    update_output();
}
void NetworkImplDft::update_input() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    //! if cpu, all inputs and outputs are host
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& in : m_network_io->inputs) {
            in.is_host = true;
        }
    }
    //! if cross-compnode model, modify the device input if it is not valid
    if (m_nr_device_type > 1) {
        for (auto&& in_tensor_iter : m_load_result.tensor_map) {
            for (auto&& config_in : m_network_io->inputs) {
                //! if the tensor is set to device input
                if (in_tensor_iter.first == config_in.name && !config_in.is_host) {
                    //! if the origin compnode of the tensor is not the device,
                    //! set the input to host
                    if (get_device_from_locator(
                                in_tensor_iter.second->comp_node().locator()) ==
                        LiteDeviceType::LITE_CPU) {
                        config_in.is_host = true;
                        LITE_WARN(
                                "The input tensor %s of the cross device model "
                                "should not come from device.",
                                config_in.name.c_str());
                    }
                }
            }
        }
    }
    for (auto&& in_tensor_iter : m_load_result.tensor_map) {
        bool found = false;
        for (auto&& config_in : m_network_io->inputs) {
            if (in_tensor_iter.first == config_in.name) {
                found = true;
                if (config_in.is_host) {
                    config_in.lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                    TensorHelper::implement(config_in.lite_tensor)
                            ->cast_final_safe<TensorImplDft>()
                            .m_host_tensor = in_tensor_iter.second;
                    config_in.lite_tensor->update_from_implement();
                } else {
                    config_in.lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                    config_in.lite_tensor->set_layout(
                            to_lite_layout(in_tensor_iter.second->layout()));
                }
                if (config_in.config_layout.ndim &&
                    !(config_in.config_layout == config_in.lite_tensor->get_layout())) {
                    config_in.lite_tensor->set_layout(config_in.config_layout);
                }
            }
        }
        if (!found) {
            IOInner io_in;
            io_in.name = in_tensor_iter.first;
            io_in.lite_tensor =
                    std::make_shared<Tensor>(device_id, stream_id, device_type, true);
            TensorHelper::implement(io_in.lite_tensor)
                    ->cast_final_safe<TensorImplDft>()
                    .m_host_tensor = in_tensor_iter.second;
            io_in.lite_tensor->update_from_implement();
            m_network_io->inputs.push_back(io_in);
        }
    }
    //! delete the IO that is not in the network
    for (auto it = m_network_io->inputs.begin(); it != m_network_io->inputs.end();) {
        if (it->lite_tensor == nullptr) {
            LITE_LOG("%s is not an input of the network, ignore it.",
                     it->name.c_str());
            it = m_network_io->inputs.erase(it);
        } else {
            it++;
        }
    }
}
void NetworkImplDft::update_output() {
    auto device_type = m_user_config->device_type;
    auto device_id = m_compnode_locator.device;
    auto stream_id = m_compnode_locator.stream;
    if (device_type == LiteDeviceType::LITE_CPU) {
        for (auto&& out : m_network_io->outputs) {
            out.is_host = true;
        }
    }
    //! delete the output that is not in the network
    for (auto out_it = m_network_io->outputs.begin();
         out_it != m_network_io->outputs.end();) {
        if (std::find_if(
                    m_load_result.output_var_list.begin(),
                    m_load_result.output_var_list.end(),
                    [out_it](const mgb::SymbolVar var) {
                        return var.node()->name() == out_it->name;
                    }) == m_load_result.output_var_list.end()) {
            LITE_LOG("%s is not an output of the network, ignore it.",
                     out_it->name.c_str());
            out_it = m_network_io->outputs.erase(out_it);
        } else {
            out_it++;
        }
    }
    //! user configured the output tensors, so only compute the configured outputs
    if (m_compute_configured_output_only) {
        LITE_ASSERT(
                m_network_io->outputs.size() > 0,
                "compute configured output only, but no output is configured.");
        for (auto out_it = m_network_io->outputs.begin();
             out_it != m_network_io->outputs.end(); out_it++) {
            //! use pinned memory to copy from device
            if (out_it->is_host) {
                out_it->lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
            } else {
                out_it->lite_tensor =
                        std::make_shared<Tensor>(device_id, stream_id, device_type);
            }
            mgb::SymbolVar var;
            for (auto&& out_var : m_load_result.output_var_list) {
                if (out_var.node()->name() == out_it->name) {
                    var = out_var;
                    break;
                }
            }
            try_infer_tensor_layout(out_it->lite_tensor, var);
        }
        //! user did not set the outputs, use the default outputs
    } else {
        for (auto&& out : m_load_result.output_var_list) {
            auto it = std::find_if(
                    m_network_io->outputs.begin(), m_network_io->outputs.end(),
                    [&out](const IOInner io) { return io.name == out.node()->name(); });
            if (it != m_network_io->outputs.end()) {
                if (it->is_host) {
                    it->lite_tensor = std::make_shared<Tensor>(
                            device_id, stream_id, device_type, true);
                } else {
                    it->lite_tensor =
                            std::make_shared<Tensor>(device_id, stream_id, device_type);
                }
                try_infer_tensor_layout(it->lite_tensor, out);
            } else {
                IOInner output;
                output.name = out.node()->name();
                output.lite_tensor = std::make_shared<Tensor>(
                        device_id, stream_id, device_type, true);
                m_network_io->outputs.push_back({output});
                try_infer_tensor_layout(output.lite_tensor, out);
            }
        }
    }
}
std::shared_ptr<Tensor> NetworkImplDft::get_io_tensor(
        std::string io_name, LiteTensorPhase phase) {
    if (phase == LiteTensorPhase::LITE_INPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_in : m_network_io->inputs) {
            if (io_name == config_in.name) {
                return config_in.lite_tensor;
            }
        }
    }
    if (phase == LiteTensorPhase::LITE_OUTPUT || phase == LiteTensorPhase::LITE_IO) {
        for (auto&& config_out : m_network_io->outputs) {
            if (io_name == config_out.name) {
                config_out.lite_tensor->update_from_implement();
                return config_out.lite_tensor;
            }
        }
    }
    LITE_THROW(mgb::ssprintf(
            "tensor name %s must be an input tensor name or a registered "
            "output tensor name if NetworkIO is set; if NetworkIO is not set, "
            "all the network output tensors are available, otherwise only the "
            "registered output tensors are.",
            io_name.c_str()));
    return nullptr;
}

std::shared_ptr<Tensor> NetworkImplDft::get_input_tensor(size_t index) {
    return get_io_tensor(get_input_name(index));
}

std::shared_ptr<Tensor> NetworkImplDft::get_output_tensor(size_t index) {
    return get_io_tensor(get_output_name(index));
}
//! set opr algorithm selection strategy in the network
void NetworkImplDft::set_network_algo_policy(
        LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
        bool binary_equal_between_batch) {
    using S = megdnn::param::ExecutionPolicy::Strategy;
    auto dst_strategy = static_cast<S>(0);
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_HEURISTIC) {
        dst_strategy = dst_strategy | S::HEURISTIC;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_PROFILE) {
        dst_strategy = dst_strategy | S::PROFILE;
    }
    if (static_cast<uint32_t>(strategy) &
        LiteAlgoSelectStrategy::LITE_ALGO_REPRODUCIBLE) {
        dst_strategy = dst_strategy | S::REPRODUCIBLE;
    }
    if (static_cast<uint32_t>(strategy) & LiteAlgoSelectStrategy::LITE_ALGO_OPTIMIZED) {
        dst_strategy = dst_strategy | S::OPTIMIZED;
    }
    m_execution_policy = dst_strategy;

    auto&& fast_run_config = m_load_config.comp_graph->options().fast_run_config;
    fast_run_config.binary_equal_between_batch = binary_equal_between_batch;
    fast_run_config.shared_batch_size = shared_batch_size;

    if (m_execute_func) {
        LITE_WARN(
                "set_network_algo_policy may cause errors when called after the "
                "network is loaded!");
        modify_exection_policy();
    }
}

void NetworkImplDft::modify_exection_policy() {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    if (static_cast<uint32_t>(m_execution_policy) != 0)
        mgb::gopt::modify_opr_algo_strategy_inplace(vars, m_execution_policy);
}
//! set opr algorithm workspace limit in the network
void NetworkImplDft::set_network_algo_workspace_limit(size_t workspace_limit) {
    mgb::SymbolVarArray vars;
    for (auto i : m_output_spec) {
        vars.push_back(i.first);
    }
    mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit);
}

//! get the output tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_output_name() const {
    std::vector<const char*> output_names;
    for (auto& output : m_network_io->outputs) {
        output_names.push_back(output.name.c_str());
    }
    return output_names;
}

//! get the input tensor names in the order of the graph
std::vector<const char*> NetworkImplDft::get_all_input_name() const {
    std::vector<const char*> input_names;
    for (auto& input : m_load_result.tensor_map) {
        input_names.push_back(input.first.c_str());
    }
    return input_names;
}

//! get the output tensor name in the order of the graph
const char* NetworkImplDft::get_output_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.output_var_list.size(),
            "The output tensor index is larger than the total number of outputs.");
    return m_load_result.output_var_list[index].node()->name().c_str();
}

//! get the input tensor name in the order of the graph
const char* NetworkImplDft::get_input_name(size_t index) const {
    LITE_ASSERT(
            index < m_load_result.tensor_map.size(),
            "The input tensor index is larger than the total number of inputs.");
    size_t i = 0;
    for (auto& input : m_load_result.tensor_map) {
        if (i == index) {
            return input.first.c_str();
        }
        i++;
    }
    LITE_THROW(ssprintf("no input tensor of index %zu.", index));
}
//! Plugin part
void NetworkImplDft::enable_profile_performance(std::string profile_json_file) {
#if MGB_ENABLE_JSON
#if MGB_OPENCL
    mgb::CompNode::enable_opencl_profile(true);
#endif
    m_profiler = std::make_unique<mgb::GraphProfiler>(m_load_config.comp_graph.get());
    m_profiler_output_file = profile_json_file;
#else
    LITE_MARK_USED_VAR(profile_json_file);
    LITE_THROW("JSON is disabled at compile time.");
#endif
}

void NetworkImplDft::enable_io_txt_dump(std::string io_txt_out_file) {
    auto iodump = std::make_unique<mgb::TextOprIODump>(
            m_load_config.comp_graph.get(), io_txt_out_file.c_str());
    iodump->print_addr(false);
    m_iodump = std::move(iodump);
}

void NetworkImplDft::enable_io_bin_dump(std::string io_bin_out_dir) {
    m_iodump = std::make_unique<mgb::BinaryOprIODump>(
            m_load_config.comp_graph.get(), io_bin_out_dir.c_str());
}

void inline NetworkImplDft::output_plugin_result() const {
#if MGB_ENABLE_JSON
    if (m_profiler && m_execute_func) {
        m_profiler->to_json_full(m_execute_func.get())
                ->writeto_fpath(m_profiler_output_file);
    }
#endif
}

void NetworkImplDft::get_static_memory_alloc_info(const std::string& log_dir) const {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
    m_execute_func->get_static_memory_alloc_info(log_dir);
    return;
#endif
#endif
    LITE_MARK_USED_VAR(log_dir);
}

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
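
For context, the implementation above backs the public lite::Network class declared in "lite/network.h" (set_config, load_model, get_io_tensor, forward, wait and friends map onto the NetworkImplDft methods in this file). Below is a minimal usage sketch, not part of this source file: it assumes the public lite::Network and lite::Tensor wrappers forward load_model, get_input_tensor/get_output_tensor, forward, wait and get_memory_ptr as the method names here suggest, and the model path and data handling are placeholders.

#include <memory>
#include "lite/network.h"
#include "lite/tensor.h"

int main() {
    // Device selection mirrors the checks in application_config(); CPU is the default path.
    lite::Config config;
    config.device_type = LiteDeviceType::LITE_CPU;

    // Create the network and load a serialized MegEngine model (placeholder path).
    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./model.mge");

    // Obtain the first input tensor; its host storage would be filled with user data here,
    // e.g. via get_memory_ptr() (omitted, since the input size depends on the model).
    auto input = network->get_input_tensor(0);
    (void)input;

    // Run the compiled graph; wait() blocks until the outputs are ready
    // (see forward()/wait() in the implementation above).
    network->forward();
    network->wait();

    // Read back the first output tensor.
    auto output = network->get_output_tensor(0);
    void* out_ptr = output->get_memory_ptr();
    (void)out_ptr;
    return 0;
}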

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, visit the MegStudio platform.