/**
 * \file include/lite/network.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "macro.h"
#include "tensor.h"

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

namespace lite {

LITE_API inline LiteAlgoSelectStrategy operator|(
        LiteAlgoSelectStrategy x, LiteAlgoSelectStrategy y) {
    return static_cast<LiteAlgoSelectStrategy>(
            static_cast<uint32_t>(x) | static_cast<uint32_t>(y));
}

/*!
 * \brief the inference options which will be translated to MegEngine
 *
 * \param weight_preprocess the option which optimizes inference performance
 * by preprocessing the const weights
 *
 * \param fuse_preprocess fuse preprocess patterns, like astype + pad_channel +
 * dimshuffle
 *
 * \param fake_next_exec whether only to perform non-computing tasks (like
 * memory allocation and queue initialization) for the next exec. This would be
 * reset to false when the graph is executed.
 *
 * \param var_sanity_check_first_run whether to run var sanity check on the
 * first run. Var sanity check is enabled on the first-time execution by
 * default, and can be used to find some potential memory access errors in the
 * operator implementation.
 *
 * \param const_shape This can be used to reduce memory usage since some
 * static inference data structures can be omitted.
 *
 * \param force_dynamic_alloc force dynamic memory alloc for all vars
 *
 * \param force_output_dynamic_alloc force dynamic memory alloc for output vars
 * which are used as CallbackCaller input when calling the compile() function
 *
 * \param no_profiling_on_shape_change do not re-profile to select the best
 * impl algo when input shape changes (use the previous algo)
 *
 * \param jit_level Execute supported operators with JIT (supports MLIR,
 * NVRTC). Can only be used on Nvidia GPUs; this value indicates the JIT level:
 * 1 for basic elemwise opr;
 * 2 for including the reduce operator
 *
 * \param comp_node_seq_record_level flag to optimize inference performance by
 * recording the kernel tasks in the first run; afterwards inference only needs
 * to execute the recorded tasks.
 * level = 0 means normal inference,
 * level = 1 means use record inference,
 * level = 2 means record inference and free the extra memory
 *
 * \param graph_opt_level optimization level:
 * 0: disable
 * 1: level-1: inplace arith transformations during graph construction
 * 2: level-2: level-1, plus global optimization before graph compiling
 * 3: also enable JIT
 * <0: corresponding level, with result check for debug
 *
 * \param async_exec_level exec: dispatch on separate threads for different
 * comp_node.
 * 0: do not perform async dispatch
 * 1: dispatch async if there is more than one comp node with limited queue
 * mask 0b10: async if there are multiple comp nodes
 * mask 0b100: always async
 */
struct LITE_API Options {
    bool weight_preprocess = false;
    bool fuse_preprocess = false;
    bool fake_next_exec = false;
    bool var_sanity_check_first_run = true;
    bool const_shape = false;
    bool force_dynamic_alloc = false;
    bool force_output_dynamic_alloc = false;
    bool force_output_use_user_specified_memory = false;
    bool no_profiling_on_shape_change = false;
    uint8_t jit_level = 0;
    uint8_t comp_node_seq_record_level = 0;
    uint8_t graph_opt_level = 2;
    uint16_t async_exec_level = 1;

    //! layout transform options
    bool enable_nchw44 = false;
    bool enable_nchw44_dot = false;
    bool enable_nchw88 = false;
    bool enable_nhwcd4 = false;
    bool enable_nchw4 = false;
    bool enable_nchw32 = false;
    bool enable_nchw64 = false;
};
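
/*!
 * Illustrative sketch (not part of the API): a typical way to tweak a few of
 * the Options fields declared above. The chosen values are assumptions for
 * demonstration only.
 *
 * \code
 *     lite::Options opts;
 *     opts.weight_preprocess = true;         // preprocess const weights once
 *     opts.comp_node_seq_record_level = 1;   // record kernel tasks on the first run
 *     opts.graph_opt_level = 2;              // level-2 graph optimization (default)
 *     opts.enable_nchw44 = true;             // request NCHW44 layout transform
 * \endcode
 */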

/*!
 * \brief Configuration used when loading and compiling the graph
 *
 * \param bare_model_cryption_name the name of the cryption method of a bare
 * model; a bare model does not pack JSON info inside
 *
 * \param has_compression flag of whether the model is compressed; the
 * compression method will be read from the model
 */
struct LITE_API Config {
    bool has_compression = false;
    int device_id = 0;
    LiteDeviceType device_type = LiteDeviceType::LITE_CPU;
    LiteBackend backend = LiteBackend::LITE_DEFAULT;
    std::string bare_model_cryption_name = {};
    Options options = {};
};
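
/*!
 * Illustrative sketch (not part of the API): building a Config for a CPU
 * deployment. The cryption name shown is a placeholder assumption; use the
 * name that matches how the model was actually encrypted, or leave it empty
 * for a plain model.
 *
 * \code
 *     lite::Config config;
 *     config.device_type = LiteDeviceType::LITE_CPU;
 *     config.device_id = 0;
 *     config.bare_model_cryption_name = "AES_default";  // placeholder name
 *     config.options.weight_preprocess = true;
 * \endcode
 */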

/*!
 * \brief configuration of a network input or output item
 */
struct LITE_API IO {
    //! the tensor name in the graph corresponding to the IO
    std::string name;

    //! Used to mark where the input tensor comes from and where the output is
    //! copied to. If is_host is true, the input is from host and the output is
    //! copied to host, otherwise device. Sometimes the input is from device
    //! and the output does not need to be copied to host. Default is true.
    bool is_host = true;

    //! The IO type, it can be SHAPE or VALUE. When SHAPE is set, the input or
    //! output tensor value is invalid, only the shape will be set. Default is
    //! VALUE.
    LiteIOType io_type = LiteIOType::LITE_IO_VALUE;

    //! The layout configured by the user. If another layout is set before
    //! forward, or obtained after forward by input tensor reset, this layout
    //! will be bypassed; if no other layout is set before forward, this layout
    //! will work; if this layout is not set, the model will forward with its
    //! original layout. For outputs, it will be used for checking.
    Layout config_layout = {};
};

/*!
 * \brief the input and output information used when loading the network;
 * the NetworkIO will remain in the network until the network is destroyed
 */
struct LITE_API NetworkIO {
    std::vector<IO> inputs = {};
    std::vector<IO> outputs = {};
};
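
/*!
 * Illustrative sketch (not part of the API): describing one device-resident
 * input before loading. The tensor name "data" is an assumption; it must match
 * a tensor name in the actual model.
 *
 * \code
 *     lite::IO input;
 *     input.name = "data";       // graph tensor name (model dependent)
 *     input.is_host = false;     // input tensor already lives on the device
 *
 *     lite::NetworkIO network_io;
 *     network_io.inputs.push_back(input);
 *
 *     lite::Network network(lite::Config{}, network_io);
 * \endcode
 */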

/*!
 * \brief A user-implemented allocator interface
 */
class LITE_API Allocator {
public:
    virtual ~Allocator() = default;

    //! allocate memory of size in the given device with the given align
    virtual void* allocate(
            LiteDeviceType device_type, int device_id, size_t size, size_t align) = 0;

    //! free the memory pointed to by ptr in the given device
    virtual void free(LiteDeviceType device_type, int device_id, void* ptr) = 0;
};
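
/*!
 * Illustrative sketch (not part of the API): a minimal CPU-only allocator a
 * user might plug in via Runtime::set_memory_allocator below. It assumes a
 * POSIX platform for aligned allocation; device memory is not handled.
 *
 * \code
 *     #include <cstdlib>
 *
 *     class AlignedCpuAllocator : public lite::Allocator {
 *     public:
 *         void* allocate(LiteDeviceType device_type, int device_id, size_t size,
 *                        size_t align) override {
 *             (void)device_type;
 *             (void)device_id;
 *             void* ptr = nullptr;
 *             // posix_memalign returns 0 on success and fills ptr
 *             return posix_memalign(&ptr, align, size) == 0 ? ptr : nullptr;
 *         }
 *         void free(LiteDeviceType device_type, int device_id, void* ptr) override {
 *             (void)device_type;
 *             (void)device_id;
 *             ::free(ptr);
 *         }
 *     };
 * \endcode
 */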

/*!
 * \brief the thread affinity callback type
 * \param thread_id thread_id is a number from 0 to (nr_threads - 1);
 * thread_id of (nr_threads - 1) is the main worker thread.
 */
using ThreadAffinityCallback = std::function<void(int thread_id)>;

using AsyncCallback = std::function<void(void)>;

/*!
 * \brief the start/finish callback function
 * \param unordered_map map from the io tensor name to a pair of the
 * corresponding user-configured IO and the real input or output tensor.
 */
using StartCallback =
        std::function<void(const std::unordered_map<
                           std::string, std::pair<IO, std::shared_ptr<Tensor>>>&)>;
using FinishCallback =
        std::function<void(const std::unordered_map<
                           std::string, std::pair<IO, std::shared_ptr<Tensor>>>&)>;
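
/*!
 * Illustrative sketch (not part of the API): a StartCallback that logs the
 * name of every io tensor before forward runs. It relies only on the map type
 * declared above; pass it to Network::set_start_callback below.
 *
 * \code
 *     #include <cstdio>
 *
 *     lite::StartCallback log_inputs =
 *             [](const std::unordered_map<
 *                     std::string,
 *                     std::pair<lite::IO, std::shared_ptr<lite::Tensor>>>& io_map) {
 *                 for (auto&& item : io_map) {
 *                     std::printf("io tensor: %s\n", item.first.c_str());
 *                 }
 *             };
 * \endcode
 */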

/*!
 * \brief The network is constructed from a model; it implements model load,
 * init, forward, and displays some model information
 */
class LITE_API Network {
public:
    class NetworkImplBase;

    ~Network();

    Network(const Config& config = {}, const NetworkIO& networkio = {});
    Network(const NetworkIO& networkio, const Config& config = {});

    //! load the model from memory
    void load_model(void* model_mem, size_t size);

    //! load the model from a model path
    void load_model(std::string model_path);

    //! only compute the output tensors configured by the user
    void compute_only_configured_output();

    //! get the network input or output tensor, the layout of which is
    //! synced from the mge tensor; when the names of an input and an output
    //! tensor are the same, use LiteTensorPhase to separate them
    std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO);

    //! get the network input tensor by index
    std::shared_ptr<Tensor> get_input_tensor(size_t index);

    //! get the network output tensor by index
    std::shared_ptr<Tensor> get_output_tensor(size_t index);
    //! set the network to forward in async mode and set the async callback
    //! function
    Network& set_async_callback(const AsyncCallback& async_callback);

    //! set the start forward callback function, which will be executed before
    //! forward; this can be used to check network input or dump model inputs
    //! for debug
    Network& set_start_callback(const StartCallback& start_callback);

    //! set the finish forward callback function, which will be executed after
    //! forward; this can be used to dump model outputs for debug
    Network& set_finish_callback(const FinishCallback& finish_callback);

    //! forward the network with the filled input data and fill the output
    //! data to the output tensor
    void forward();

    //! wait until forward finishes in sync mode
    void wait();

    //! get the input tensor name in the order returned by load
    std::string get_input_name(size_t index) const;

    //! get the output tensor name in the order returned by load
    std::string get_output_name(size_t index) const;

    //! get all the input tensor names in the order returned by load
    std::vector<std::string> get_all_input_name() const;

    //! get all the output tensor names in the order returned by load
    std::vector<std::string> get_all_output_name() const;

    //! set/get device id, default device id = 0
    Network& set_device_id(int device_id);
    int get_device_id() const;

    //! set/get stream id, default stream id = 0
    Network& set_stream_id(int stream_id);
    int get_stream_id() const;

    //! enable profiling of the network; a file will be generated
    void enable_profile_performance(std::string profile_file_path);

    //! get model extra info
    const std::string& get_model_extra_info();

    //! get device type
    LiteDeviceType get_device_type() const;

    //! get static peak memory info shown by graph visualization
    void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;

public:
    friend class NetworkHelper;

private:
    //! update members from the implementation
    void update_from_implement();

    //! decrypt and parse the model file
    void prase_model(std::shared_ptr<void> model_data, size_t size);

private:
    bool m_loaded = false;
    Config m_config;
    NetworkIO m_network_io;
    std::unique_ptr<NetworkImplBase> m_impl;
    std::string m_extra_info;
};
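
/*!
 * Illustrative sketch (not part of the API): a minimal synchronous inference
 * loop. The model path and the input tensor name "data" are assumptions;
 * filling the input buffer goes through the Tensor API declared in tensor.h
 * and is not shown here.
 *
 * \code
 *     lite::Config config;
 *     config.device_type = LiteDeviceType::LITE_CPU;
 *
 *     lite::Network network(config);
 *     network.load_model("./model.mge");                     // assumed path
 *
 *     std::shared_ptr<lite::Tensor> input =
 *             network.get_io_tensor("data");                 // assumed tensor name
 *     // ... fill `input` with data via the Tensor API ...
 *
 *     network.forward();
 *     network.wait();                                        // sync mode
 *
 *     std::shared_ptr<lite::Tensor> output = network.get_output_tensor(0);
 * \endcode
 */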

/*********************** MGE special network functions ***************/
class LITE_API Runtime {
public:
    //! When the device is CPU, this interface will set the to-be-loaded model
    //! to run in multi-thread mode with the given thread number.
    static void set_cpu_threads_number(
            std::shared_ptr<Network> dst_network, size_t nr_threads);
    static size_t get_cpu_threads_number(std::shared_ptr<Network> dst_network);

    //! set the thread affinity callback
    static void set_runtime_thread_affinity(
            std::shared_ptr<Network> network,
            const ThreadAffinityCallback& thread_affinity_callback);

    //! Set cpu inplace mode when the device is CPU; on some low-computation
    //! or single-core devices, this mode will give good performance
    static void set_cpu_inplace_mode(std::shared_ptr<Network> dst_network);
    static bool is_cpu_inplace_mode(std::shared_ptr<Network> dst_network);

    //! Set use of TensorRT forward
    static void use_tensorrt(std::shared_ptr<Network> dst_network);

    //! set the opr algorithm selection strategy in the network
    //! shared_batch_size: the batch size used by fastrun;
    //!                    a non-zero value means fastrun uses this batch size
    //!                    regardless of the batch size of the model, zero means
    //!                    fastrun uses the batch size of the model
    //! binary_equal_between_batch: if the content of each input batch is binary
    //!                             equal, whether the content of each output
    //!                             batch is promised to be equal
    static void set_network_algo_policy(
            std::shared_ptr<Network> dst_network, LiteAlgoSelectStrategy strategy,
            uint32_t shared_batch_size = 0, bool binary_equal_between_batch = false);

    //! set workspace_limit for oprs with multiple algorithms; setting a
    //! workspace limitation can save memory but may influence performance
    static void set_network_algo_workspace_limit(
            std::shared_ptr<Network> dst_network, size_t workspace_limit);

    //! set the network memory allocator; the allocator is defined by the user
    static void set_memory_allocator(
            std::shared_ptr<Network> dst_network,
            std::shared_ptr<Allocator> user_allocator);

    //! share the runtime memory with another network; the weights are not shared
    static void share_runtime_memory_with(
            std::shared_ptr<Network> dst_network, std::shared_ptr<Network> src_network);

    //! Dump input/output values of all internal variables to an output
    //! file, in txt format
    static void enable_io_txt_dump(
            std::shared_ptr<Network> dst_network, std::string io_txt_out_file);

    //! Dump input/output values of all internal variables to an output
    //! directory, in binary format
    static void enable_io_bin_dump(
            std::shared_ptr<Network> dst_network, std::string io_bin_out_dir);

    //! load a new network which will share weights with the src network
    static void shared_weight_with_network(
            std::shared_ptr<Network> dst_network,
            const std::shared_ptr<Network> src_network);

    //! set global layout transform optimization for the network
    static void enable_global_layout_transform(std::shared_ptr<Network> network);

    //! dump the network after global layout transform optimization
    static void dump_layout_transform_model(
            std::shared_ptr<Network> network, std::string optimized_model_path);
};
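
/*!
 * Illustrative sketch (not part of the API): typical Runtime tuning calls for
 * a CPU deployment. The thread count, workspace limit, and shared_ptr-managed
 * Network are assumptions for demonstration; the affinity lambda body is left
 * to the platform.
 *
 * \code
 *     auto network = std::make_shared<lite::Network>();
 *
 *     lite::Runtime::set_cpu_threads_number(network, 4);
 *     lite::Runtime::set_runtime_thread_affinity(
 *             network, [](int thread_id) {
 *                 // bind `thread_id` to a core with a platform-specific call
 *             });
 *     lite::Runtime::set_network_algo_workspace_limit(network, 64 * 1024 * 1024);
 *
 *     network->load_model("./model.mge");  // assumed path
 * \endcode
 */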

}  // namespace lite

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}