You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

network.cpp 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. /**
  2. * \file lite-c/src/network.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/network.h"
  12. #include "common.h"
  13. #include "lite-c/network_c.h"
  14. #include "../../src/network_impl_base.h"
  15. #include <memory>
  16. #include <mutex>
  17. #include <unordered_map>
  18. #include <string.h>
//! define a default Options
//! this aggregate is handed to C callers through default_config(); the
//! field values here are the out-of-the-box option set (graph_opt_level 2,
//! async_exec_level 1, sanity check on first run, all layout transforms off).
const LiteOptions default_option = {
        .weight_preprocess = false,
        .fuse_preprocess = false,
        .fake_next_exec = false,
        .var_sanity_check_first_run = true,
        .const_shape = false,
        .force_dynamic_alloc = false,
        .force_output_dynamic_alloc = false,
        .no_profiling_on_shape_change = false,
        .jit_level = 0,
        .comp_node_seq_record_level = 0,
        .graph_opt_level = 2,
        .async_exec_level = 1,
        //! layout transform options (all disabled by default)
        .enable_nchw44 = 0,
        .enable_nchw44_dot = 0,
        .enable_nchw88 = 0,
        .enable_nhwcd4 = 0,
        .enable_nchw4 = 0,
        .enable_nchw32 = 0,
        .enable_nchw64 = 0,
};
//! define a default config
//! NOTE(review): intentionally non-const — callers obtain it via
//! default_config() and may mutate the shared instance; confirm callers
//! expect these mutations to be process-global.
LiteConfig default_config_t = {.has_compression = false,
                               .device_id = -1,
                               .device_type = LiteDeviceType::LITE_CPU,
                               .backend = LiteBackend::LITE_DEFAULT,
                               .bare_model_cryption_name = nullptr,
                               .options = default_option};
//! \return pointer to the global default LiteConfig instance
LiteConfig* default_config() {
    return &default_config_t;
}
//! define a default IO: host-side value IO with no name and default layout
const LiteIO default_io = {.name = nullptr,
                           .is_host = true,
                           .io_type = LiteIOType::LITE_IO_VALUE,
                           .config_layout = default_layout};
//! define a default NetworkIO: empty input/output description
LiteNetworkIO default_network_io_t = {.inputs = nullptr,
                                      .outputs = nullptr,
                                      .input_size = 0,
                                      .output_size = 0};
//! \return pointer to the global default LiteNetworkIO instance
LiteNetworkIO* default_network_io() {
    return &default_network_io_t;
}
namespace {
//! Map from the raw lite::Network pointer handed out as a C handle back to
//! the owning shared_ptr; keeps each network alive until
//! LITE_destroy_network erases its entry.
//! NOTE(review): the holder is thread_local, so a handle created on one
//! thread cannot be found (or destroyed) from another thread — confirm this
//! per-thread ownership is intended.  The function name preserves the
//! original spelling ("gloabl") used throughout this file.
std::unordered_map<void*, std::shared_ptr<lite::Network>>&
get_gloabl_network_holder() {
    static thread_local std::unordered_map<void*,
                                           std::shared_ptr<lite::Network>>
            network_holder;
    return network_holder;
}
/*!
 * \brief A user-implemented allocator interface
 *
 * Adapts a pair of C callbacks (allocate/free) to the lite::Allocator
 * interface so user memory hooks can be installed via
 * LITE_set_memory_allocator.  Both callbacks must be non-null.
 */
class UserAllocator : public lite::Allocator {
public:
    //! \param allocate_func C callback used to obtain memory
    //! \param free_func C callback used to release memory from allocate_func
    UserAllocator(LiteAllocate allocate_func, LiteFree free_func)
            : m_allocator(allocate_func), m_free(free_func) {
        LITE_ASSERT(m_allocator && m_free);
    }
    //! allocate memory of size in the given device with the given align
    void* allocate(LiteDeviceType device_type, int device_id, size_t size,
                   size_t align) override {
        return m_allocator(device_type, device_id, size, align);
    }
    //! free the memory pointed by ptr in the given device
    void free(LiteDeviceType device_type, int device_id, void* ptr) override {
        m_free(device_type, device_id, ptr);
    }

private:
    LiteAllocate m_allocator;  //! user allocate callback, never null
    LiteFree m_free;           //! user free callback, never null
};
}  // namespace
  96. //! convert c config to lite::config
  97. lite::Config convert_to_lite_config(const LiteConfig c_config) {
  98. lite::Config lite_config;
  99. lite_config.device_type = c_config.device_type;
  100. if (c_config.bare_model_cryption_name) {
  101. lite_config.bare_model_cryption_name =
  102. c_config.bare_model_cryption_name;
  103. }
  104. lite_config.backend = c_config.backend;
  105. lite_config.has_compression = c_config.has_compression;
  106. lite_config.device_id = c_config.device_id;
  107. lite_config.options.weight_preprocess = c_config.options.weight_preprocess;
  108. lite_config.options.fuse_preprocess = c_config.options.fuse_preprocess;
  109. lite_config.options.fake_next_exec = c_config.options.fake_next_exec;
  110. lite_config.options.var_sanity_check_first_run =
  111. c_config.options.var_sanity_check_first_run;
  112. lite_config.options.const_shape = c_config.options.const_shape;
  113. lite_config.options.force_dynamic_alloc = c_config.options.const_shape;
  114. lite_config.options.force_output_dynamic_alloc =
  115. c_config.options.force_output_dynamic_alloc;
  116. lite_config.options.no_profiling_on_shape_change =
  117. c_config.options.no_profiling_on_shape_change;
  118. lite_config.options.jit_level = c_config.options.jit_level;
  119. lite_config.options.comp_node_seq_record_level =
  120. c_config.options.comp_node_seq_record_level;
  121. lite_config.options.graph_opt_level = c_config.options.graph_opt_level;
  122. lite_config.options.async_exec_level = c_config.options.async_exec_level;
  123. lite_config.options.enable_nchw44 = c_config.options.enable_nchw44;
  124. lite_config.options.enable_nchw44_dot = c_config.options.enable_nchw44_dot;
  125. lite_config.options.enable_nchw88 = c_config.options.enable_nchw88;
  126. lite_config.options.enable_nchw4 = c_config.options.enable_nchw4;
  127. lite_config.options.enable_nhwcd4 = c_config.options.enable_nhwcd4;
  128. lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
  129. lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
  130. return lite_config;
  131. }
  132. //! convert C NetworkIO io to lite::NetworkIO
  133. lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
  134. lite::NetworkIO network_io;
  135. for (size_t i = 0; i < c_network_io.input_size; i++) {
  136. LiteIO* c_io = c_network_io.inputs + i;
  137. LITE_ASSERT(c_io->name, "input name of io tensor must set.");
  138. network_io.inputs.push_back(
  139. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  140. convert_to_layout(c_io->config_layout)});
  141. }
  142. for (size_t i = 0; i < c_network_io.output_size; i++) {
  143. LiteIO* c_io = c_network_io.outputs + i;
  144. LITE_ASSERT(c_io->name, "output name of io tensor must set.");
  145. network_io.outputs.push_back(
  146. {c_io->name, static_cast<bool>(c_io->is_host), c_io->io_type,
  147. convert_to_layout(c_io->config_layout)});
  148. }
  149. return network_io;
  150. }
  151. int LITE_make_default_network(LiteNetwork* network) {
  152. LITE_CAPI_BEGIN();
  153. LITE_ASSERT(network, "The network pass to LITE api is null");
  154. auto lite_network = std::make_shared<lite::Network>();
  155. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  156. *network = lite_network.get();
  157. LITE_CAPI_END();
  158. }
  159. int LITE_make_network(LiteNetwork* network, const LiteConfig config,
  160. const LiteNetworkIO network_io) {
  161. LITE_CAPI_BEGIN();
  162. LITE_ASSERT(network, "The network pass to LITE api is null");
  163. auto lite_network = std::make_shared<lite::Network>(
  164. convert_to_lite_config(config), convert_to_lite_io(network_io));
  165. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  166. *network = lite_network.get();
  167. LITE_CAPI_END();
  168. }
  169. int LITE_make_network_config(LiteNetwork* network, const LiteConfig config) {
  170. LITE_CAPI_BEGIN();
  171. LITE_ASSERT(network, "The network pass to LITE api is null");
  172. auto lite_network =
  173. std::make_shared<lite::Network>(convert_to_lite_config(config));
  174. get_gloabl_network_holder()[lite_network.get()] = lite_network;
  175. *network = lite_network.get();
  176. LITE_CAPI_END();
  177. }
  178. int LITE_load_model_from_mem(LiteNetwork network, void* model_mem,
  179. size_t size) {
  180. LITE_CAPI_BEGIN();
  181. LITE_ASSERT(network, "The network pass to LITE api is null");
  182. LITE_ASSERT(model_mem, "The model memory pass to LITE api is null");
  183. static_cast<lite::Network*>(network)->load_model(model_mem, size);
  184. LITE_CAPI_END();
  185. }
  186. int LITE_load_model_from_path(LiteNetwork network, const char* model_path) {
  187. LITE_CAPI_BEGIN();
  188. LITE_ASSERT(network, "The network pass to LITE api is null");
  189. LITE_ASSERT(model_path, "The model path pass to LITE api is null");
  190. static_cast<lite::Network*>(network)->load_model(model_path);
  191. LITE_CAPI_END();
  192. }
  193. int LITE_destroy_network(LiteNetwork network) {
  194. LITE_CAPI_BEGIN();
  195. LITE_ASSERT(network, "The network pass to LITE api is null");
  196. get_gloabl_network_holder().erase(network);
  197. LITE_CAPI_END();
  198. }
  199. int LITE_forward(const LiteNetwork network) {
  200. LITE_CAPI_BEGIN();
  201. LITE_ASSERT(network, "The network pass to LITE api is null");
  202. static_cast<lite::Network*>(network)->forward();
  203. LITE_CAPI_END();
  204. }
  205. int LITE_wait(const LiteNetwork network) {
  206. LITE_CAPI_BEGIN();
  207. LITE_ASSERT(network, "The network pass to LITE api is null");
  208. static_cast<lite::Network*>(network)->wait();
  209. LITE_CAPI_END();
  210. }
  211. int LITE_get_io_tensor(LiteNetwork network, const char* io_name,
  212. LiteTensorPhase phase, LiteTensor* tensor) {
  213. LITE_CAPI_BEGIN();
  214. LITE_ASSERT(network, "The network pass to LITE api is null");
  215. auto io_tensor =
  216. static_cast<lite::Network*>(network)->get_io_tensor(io_name, phase);
  217. *tensor = io_tensor.get();
  218. LITE_CAPI_END();
  219. }
  220. int LITE_get_input_name(const LiteNetwork network, size_t index,
  221. const char** name) {
  222. LITE_CAPI_BEGIN();
  223. LITE_ASSERT(network && name, "The network pass to LITE api is null");
  224. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  225. ->get_input_name(index);
  226. LITE_CAPI_END();
  227. }
  228. int LITE_get_output_name(const LiteNetwork network, size_t index,
  229. const char** name) {
  230. LITE_CAPI_BEGIN();
  231. LITE_ASSERT(network, "The network pass to LITE api is null");
  232. LITE_ASSERT(name, "The name ptr pass to LITE api is null");
  233. *name = lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  234. ->get_output_name(index);
  235. LITE_CAPI_END();
  236. }
  237. int LITE_get_all_input_name(const LiteNetwork network, size_t* size,
  238. const char** name) {
  239. LITE_CAPI_BEGIN();
  240. LITE_ASSERT(network, "The network pass to LITE api is null");
  241. auto&& names =
  242. lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  243. ->get_all_input_name();
  244. if (size)
  245. *size = names.size();
  246. if (name) {
  247. for (auto in_name : names) {
  248. *name = in_name;
  249. name++;
  250. }
  251. }
  252. LITE_CAPI_END();
  253. }
  254. int LITE_get_all_output_name(const LiteNetwork network, size_t* size,
  255. const char** name) {
  256. LITE_CAPI_BEGIN();
  257. LITE_ASSERT(network, "The network pass to LITE api is null");
  258. auto&& names =
  259. lite::NetworkHelper::implement(static_cast<lite::Network*>(network))
  260. ->get_all_output_name();
  261. if (size)
  262. *size = names.size();
  263. if (name) {
  264. for (auto in_name : names) {
  265. *name = in_name;
  266. name++;
  267. }
  268. }
  269. LITE_CAPI_END();
  270. }
  271. int LITE_set_device_id(LiteNetwork network, int device_id) {
  272. LITE_CAPI_BEGIN();
  273. LITE_ASSERT(network, "The network pass to LITE api is null");
  274. static_cast<lite::Network*>(network)->set_device_id(device_id);
  275. LITE_CAPI_END();
  276. }
  277. int LITE_get_device_id(const LiteNetwork network, int* device_id) {
  278. LITE_CAPI_BEGIN();
  279. LITE_ASSERT(network, "The network pass to LITE api is null");
  280. LITE_ASSERT(device_id, "The device_id pass to LITE api is null");
  281. *device_id = static_cast<lite::Network*>(network)->get_device_id();
  282. LITE_CAPI_END();
  283. }
  284. int LITE_set_stream_id(LiteNetwork network, int stream_id) {
  285. LITE_CAPI_BEGIN();
  286. LITE_ASSERT(network, "The network pass to LITE api is null");
  287. static_cast<lite::Network*>(network)->set_stream_id(stream_id);
  288. LITE_CAPI_END();
  289. }
  290. int LITE_get_stream_id(const LiteNetwork network, int* stream_id) {
  291. LITE_CAPI_BEGIN();
  292. LITE_ASSERT(network, "The network pass to LITE api is null");
  293. LITE_ASSERT(stream_id, "The stream_id pass to LITE api is null");
  294. *stream_id = static_cast<lite::Network*>(network)->get_stream_id();
  295. LITE_CAPI_END();
  296. }
  297. int LITE_get_model_extra_info(const LiteNetwork network, const char** info,
  298. int* info_size) {
  299. LITE_CAPI_BEGIN();
  300. LITE_ASSERT(network, "The network pass to LITE api is null");
  301. LITE_ASSERT(info_size, "The info and info_size are all null");
  302. auto& extra_info =
  303. static_cast<lite::Network*>(network)->get_model_extra_info();
  304. *info_size = extra_info.size();
  305. *info = extra_info.c_str();
  306. LITE_MARK_USED_VAR(info);
  307. LITE_CAPI_END();
  308. }
  309. int LITE_get_device_type(const LiteNetwork network,
  310. LiteDeviceType* device_type) {
  311. LITE_CAPI_BEGIN();
  312. LITE_ASSERT(network, "The network pass to LITE api is null");
  313. LITE_ASSERT(device_type, "The device_type pass to LITE api is null");
  314. *device_type = static_cast<lite::Network*>(network)->get_device_type();
  315. LITE_CAPI_END();
  316. }
  317. int LITE_set_async_callback(LiteNetwork network,
  318. const LiteAsyncCallback async_callback) {
  319. LITE_CAPI_BEGIN();
  320. LITE_ASSERT(network, "The network pass to LITE api is null");
  321. LITE_ASSERT(async_callback, "The ptr pass to LITE api is null");
  322. static_cast<lite::Network*>(network)->set_async_callback(
  323. std::move(async_callback));
  324. LITE_CAPI_END();
  325. }
  326. int LITE_set_start_callback(LiteNetwork network,
  327. const LiteStartCallback start_callback) {
  328. LITE_CAPI_BEGIN();
  329. LITE_ASSERT(network, "The network pass to LITE api is null");
  330. auto lite_start_callback =
  331. [start_callback](
  332. const std::unordered_map<
  333. std::string,
  334. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  335. inputs_map) -> void {
  336. std::vector<LiteIO> ios;
  337. std::vector<LiteTensor> io_tensors;
  338. size_t nr_io = 0;
  339. for (const auto& io : inputs_map) {
  340. nr_io++;
  341. auto&& lite_io = io.second.first;
  342. ios.push_back({lite_io.name.c_str(), lite_io.is_host,
  343. lite_io.io_type,
  344. convert_to_clayout(lite_io.config_layout)});
  345. io_tensors.push_back(io.second.second.get());
  346. }
  347. start_callback(ios.data(), io_tensors.data(), nr_io);
  348. };
  349. static_cast<lite::Network*>(network)->set_start_callback(
  350. lite_start_callback);
  351. LITE_CAPI_END();
  352. }
  353. int LITE_set_finish_callback(LiteNetwork network,
  354. const LiteFinishCallback finish_callback) {
  355. LITE_CAPI_BEGIN();
  356. LITE_ASSERT(network, "The network pass to LITE api is null");
  357. auto lite_finish_callback =
  358. [finish_callback](
  359. const std::unordered_map<
  360. std::string,
  361. std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>&
  362. outputs_map) -> void {
  363. std::vector<LiteIO> ios;
  364. std::vector<LiteTensor> io_tensors;
  365. size_t nr_io = 0;
  366. for (const auto& io : outputs_map) {
  367. nr_io++;
  368. auto&& lite_io = io.second.first;
  369. ios.push_back({lite_io.name.c_str(), lite_io.is_host,
  370. lite_io.io_type,
  371. convert_to_clayout(lite_io.config_layout)});
  372. io_tensors.push_back(io.second.second.get());
  373. }
  374. finish_callback(ios.data(), io_tensors.data(), nr_io);
  375. };
  376. static_cast<lite::Network*>(network)->set_finish_callback(
  377. lite_finish_callback);
  378. LITE_CAPI_END();
  379. }
  380. int LITE_enable_profile_performance(LiteNetwork network,
  381. const char* profile_json_file_path) {
  382. LITE_CAPI_BEGIN();
  383. LITE_ASSERT(network, "The network pass to LITE api is null");
  384. static_cast<lite::Network*>(network)->enable_profile_performance(
  385. profile_json_file_path);
  386. LITE_CAPI_END();
  387. }
  388. int LITE_is_cpu_inplace_mode(const LiteNetwork network,
  389. int* is_cpu_inplace_mode) {
  390. LITE_CAPI_BEGIN();
  391. LITE_ASSERT(network && is_cpu_inplace_mode,
  392. "The network pass to LITE api is null");
  393. std::shared_ptr<lite::Network> network_shared{
  394. static_cast<lite::Network*>(network), [](void*) {}};
  395. *is_cpu_inplace_mode = lite::Runtime::is_cpu_inplace_mode(network_shared);
  396. LITE_CAPI_END();
  397. }
  398. int LITE_get_cpu_threads_number(const LiteNetwork network, size_t* nr_threads) {
  399. LITE_CAPI_BEGIN();
  400. LITE_ASSERT(network, "The network pass to LITE api is null");
  401. LITE_ASSERT(nr_threads, "The ptr pass to LITE api is null");
  402. std::shared_ptr<lite::Network> network_shared{
  403. static_cast<lite::Network*>(network), [](void*) {}};
  404. *nr_threads = lite::Runtime::get_cpu_threads_number(network_shared);
  405. LITE_CAPI_END();
  406. }
  407. int LITE_set_cpu_inplace_mode(LiteNetwork network) {
  408. LITE_CAPI_BEGIN();
  409. LITE_ASSERT(network, "The network pass to LITE api is null");
  410. std::shared_ptr<lite::Network> network_shared{
  411. static_cast<lite::Network*>(network), [](void*) {}};
  412. lite::Runtime::set_cpu_inplace_mode(network_shared);
  413. LITE_CAPI_END();
  414. }
  415. int LITE_use_tensorrt(LiteNetwork network){
  416. LITE_CAPI_BEGIN();
  417. LITE_ASSERT(network, "The network pass to LITE api is null");
  418. std::shared_ptr<lite::Network> network_shared{
  419. static_cast<lite::Network*>(network), [](void*) {}};
  420. lite::Runtime::use_tensorrt(network_shared);
  421. LITE_CAPI_END();
  422. }
  423. int LITE_set_cpu_threads_number(LiteNetwork network, size_t nr_threads) {
  424. LITE_CAPI_BEGIN();
  425. LITE_ASSERT(network, "The network pass to LITE api is null");
  426. std::shared_ptr<lite::Network> network_shared{
  427. static_cast<lite::Network*>(network), [](void*) {}};
  428. lite::Runtime::set_cpu_threads_number(network_shared, nr_threads);
  429. LITE_CAPI_END();
  430. }
  431. int LITE_set_network_algo_policy(LiteNetwork network,
  432. LiteAlgoSelectStrategy strategy) {
  433. LITE_CAPI_BEGIN();
  434. LITE_ASSERT(network, "The network pass to LITE api is null");
  435. std::shared_ptr<lite::Network> network_shared{
  436. static_cast<lite::Network*>(network), [](void*) {}};
  437. lite::Runtime::set_network_algo_policy(network_shared, strategy);
  438. LITE_CAPI_END();
  439. }
  440. int LITE_set_network_algo_fastrun_config(LiteNetwork network,
  441. unsigned int shared_batch_size,
  442. int binary_equal_between_batch) {
  443. LITE_CAPI_BEGIN();
  444. LITE_ASSERT(network, "The network pass to LITE api is null");
  445. std::shared_ptr<lite::Network> network_shared{
  446. static_cast<lite::Network*>(network), [](void*) {}};
  447. lite::Runtime::set_network_algo_policy(
  448. network_shared, LiteAlgoSelectStrategy(0), shared_batch_size,
  449. binary_equal_between_batch);
  450. LITE_CAPI_END();
  451. }
  452. int LITE_set_network_algo_workspace_limit(LiteNetwork network,
  453. size_t workspace_limit) {
  454. LITE_CAPI_BEGIN();
  455. LITE_ASSERT(network, "The network pass to LITE api is null");
  456. std::shared_ptr<lite::Network> network_shared{
  457. static_cast<lite::Network*>(network), [](void*) {}};
  458. lite::Runtime::set_network_algo_workspace_limit(network_shared,
  459. workspace_limit);
  460. LITE_CAPI_END();
  461. }
  462. int LITE_set_runtime_thread_affinity(
  463. LiteNetwork network,
  464. const LiteThreadAffinityCallback thread_affinity_callback) {
  465. LITE_CAPI_BEGIN();
  466. LITE_ASSERT(network, "The network pass to LITE api is null");
  467. std::shared_ptr<lite::Network> network_shared{
  468. static_cast<lite::Network*>(network), [](void*) {}};
  469. lite::Runtime::set_runtime_thread_affinity(
  470. network_shared, std::move(thread_affinity_callback));
  471. LITE_CAPI_END();
  472. }
  473. int LITE_set_memory_allocator(LiteNetwork network,
  474. const LiteAllocate allocate_fun,
  475. const LiteFree free_fun) {
  476. LITE_CAPI_BEGIN();
  477. LITE_ASSERT(network && allocate_fun && free_fun,
  478. "The ptr pass to LITE api is null");
  479. std::shared_ptr<lite::Network> network_shared{
  480. static_cast<lite::Network*>(network), [](void*) {}};
  481. lite::Runtime::set_memory_allocator(
  482. network_shared,
  483. std::make_shared<UserAllocator>(allocate_fun, free_fun));
  484. LITE_CAPI_END();
  485. }
  486. int LITE_enable_io_txt_dump(LiteNetwork network, const char* io_txt_out_file) {
  487. LITE_CAPI_BEGIN();
  488. LITE_ASSERT(network, "The network pass to LITE api is null");
  489. std::shared_ptr<lite::Network> network_shared{
  490. static_cast<lite::Network*>(network), [](void*) {}};
  491. lite::Runtime::enable_io_txt_dump(network_shared, io_txt_out_file);
  492. LITE_CAPI_END();
  493. }
  494. int LITE_enable_io_bin_dump(LiteNetwork network, const char* io_bin_out_dir) {
  495. LITE_CAPI_BEGIN();
  496. LITE_ASSERT(network, "The network pass to LITE api is null");
  497. std::shared_ptr<lite::Network> network_shared{
  498. static_cast<lite::Network*>(network), [](void*) {}};
  499. lite::Runtime::enable_io_bin_dump(network_shared, io_bin_out_dir);
  500. LITE_CAPI_END();
  501. }
  502. int LITE_shared_weight_with_network(LiteNetwork dst_network,
  503. const LiteNetwork src_network) {
  504. LITE_CAPI_BEGIN();
  505. LITE_ASSERT(dst_network && src_network,
  506. "The network pass to LITE api is null");
  507. const std::shared_ptr<lite::Network> src_shared_net{
  508. static_cast<lite::Network*>(src_network), [](void*) {}};
  509. std::shared_ptr<lite::Network> dst_shared_net{
  510. static_cast<lite::Network*>(dst_network), [](void*) {}};
  511. lite::Runtime::shared_weight_with_network(dst_shared_net, src_shared_net);
  512. LITE_CAPI_END();
  513. }
  514. int LITE_share_runtime_memroy(LiteNetwork dst_network,
  515. LiteNetwork src_network) {
  516. LITE_CAPI_BEGIN();
  517. LITE_ASSERT(src_network && dst_network,
  518. "The network pass to LITE api is null");
  519. std::shared_ptr<lite::Network> src_shared{
  520. static_cast<lite::Network*>(src_network), [](void*) {}};
  521. std::shared_ptr<lite::Network> dst_shared{
  522. static_cast<lite::Network*>(dst_network), [](void*) {}};
  523. lite::Runtime::share_runtime_memory_with(dst_shared, src_shared);
  524. LITE_CAPI_END();
  525. }
  526. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台