
network_impl.h

#pragma once

#include "lite_build_config.h"
#include "megbrain/graph.h"

#if LITE_BUILD_WITH_MGE
#include "lite/network.h"
#include "network_impl_base.h"
#include "tensor_impl.h"

#include <memory>
#include <unordered_map>

#include "megbrain/gopt/inference.h"
#include "megbrain/graph/bases.h"
#include "megbrain/plugin/opr_io_dump.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/extern_c_opr.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/thin/hash_table.h"
namespace lite {

/*!
 * \brief implementation of Network that contains the mgb related members
 */
class NetworkImplDft final : public Network::NetworkImplBase {
    LITE_DYN_TYPE_OBJ_FINAL_DECL;

public:
    NetworkImplDft() {
        m_load_config.comp_graph = mgb::ComputingGraph::make();
        m_user_config = std::make_unique<Config>();
        m_network_io = std::make_unique<NetworkIOInner>();
    }
    using S = megdnn::param::ExecutionPolicy::Strategy;
    using Var = mgb::cg::SymbolVar;
    //! set the config of the network, including:
    //! the inference device
    //! other inference options, such as record_level, weight_preprocess...
    void set_config(const Config& config) override;

    //! set the special IO information; if not set, the default IO tensors
    //! will be used. This is needed when the input/output is not a host
    //! tensor; by default the input/output tensors are host tensors
    void set_io(const NetworkIO& network_io) override;
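    //! A minimal usage sketch (hypothetical calling code; the IO fields are
    //! those declared in lite/network.h, the tensor name is made up):
    //! \code
    //! NetworkIO io;
    //! IO input;
    //! input.name = "data";       // must match an input name in the model
    //! input.is_host = false;     // this input will be a device tensor
    //! io.inputs.push_back(input);
    //! network_impl->set_io(io);  // call before load_model()
    //! \endcode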
    //! only compute the output tensors configured by the user
    void compute_only_configured_output() override {
        m_compute_configured_output_only = true;
    }

    //! get a network input or output tensor, whose layout is synced from
    //! the mge tensor
    std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_IO) override;
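    //! A minimal forward sketch (hypothetical calling code; "data" and
    //! "prob" stand in for the real tensor names of the loaded model):
    //! \code
    //! auto in = network_impl->get_io_tensor("data");
    //! memcpy(in->get_memory_ptr(), src, in->get_tensor_total_size_in_byte());
    //! network_impl->forward();
    //! network_impl->wait();
    //! auto out = network_impl->get_io_tensor("prob");
    //! \endcode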
    //! get the network input tensors for an input that consists of multiple
    //! discrete tensors, each with layout (1, c, h, w)
    std::vector<std::shared_ptr<Tensor>> get_discrete_tensors(
            std::string io_name,
            LiteTensorPhase phase = LiteTensorPhase::LITE_INPUT) override;

    //! get the input tensor by its index in the load_result tensormap
    std::shared_ptr<Tensor> get_input_tensor(size_t index) override;

    //! get, by index, the network input tensors for an input that consists
    //! of multiple discrete tensors
    std::vector<std::shared_ptr<Tensor>> get_input_tensors(size_t index) override;

    //! get the output tensor by its index in the load_result output_var_list
    std::shared_ptr<Tensor> get_output_tensor(size_t index) override;

    //! get all the input tensor names, in the order of the load result
    std::vector<const char*> get_all_input_name() const override;

    //! get all the output tensor names, in the order of the load result
    std::vector<const char*> get_all_output_name() const override;

    //! get the input tensor name at the given index of the load result
    const char* get_input_name(size_t index) const override;

    //! get the output tensor name at the given index of the load result
    const char* get_output_name(size_t index) const override;

    //! set the callback used in async mode
    void set_async_callback(const AsyncCallback& callback) override;
    //! set the start callback which will execute before network forward
    void set_start_callback(const StartCallback& callback) override {
        m_start_callback = std::move(callback);
    }

    //! set the finish callback which will execute after network forward
    void set_finish_callback(const FinishCallback& callback) override {
        m_finish_callback = std::move(callback);
    }
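    //! A callback sketch (hypothetical calling code; generic lambdas are
    //! used so they adapt to the exact IO-map parameter type of
    //! StartCallback/FinishCallback declared in lite/network.h):
    //! \code
    //! network_impl->set_start_callback([](const auto& name_to_io_map) {
    //!     printf("forward starts with %zu io tensors\n", name_to_io_map.size());
    //! });
    //! network_impl->set_finish_callback([](const auto& name_to_io_map) {
    //!     printf("forward finished\n");
    //! });
    //! \endcode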
    //! load the model and fill m_load_result
    void load_model(
            std::shared_ptr<void> model_mem, size_t size,
            std::unordered_map<std::string, LiteAny> separate_config_map = {}) override;
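    //! A load sketch (hypothetical calling code; how the model bytes are
    //! obtained is up to the caller):
    //! \code
    //! std::shared_ptr<void> mem;  // model bytes, e.g. read from "model.mge"
    //! size_t size;                // byte size of the model
    //! network_impl->set_config(config);
    //! network_impl->load_model(mem, size);
    //! \endcode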
    //! forward the network with filled input data and fill the output data
    //! to the output tensor
    void forward() override;

    //! in sync mode, wait until the inference finishes
    void wait() override;

    virtual LiteDeviceType get_device_type() const override {
        return m_user_config->device_type;
    }

    //! set the cpu inplace mode when the device is CPU; on low computation
    //! or single core devices, this mode gives good performance
    void set_cpu_inplace_mode();
    bool is_cpu_inplace_mode() const { return m_is_cpu_inplace_mode; }

    //! when the device is CPU, make the model to be loaded run in
    //! multi-thread mode with the given number of threads
    void set_cpu_threads_number(size_t nr_threads);
    size_t get_cpu_threads_number() const { return m_nr_threads; }
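    //! A CPU configuration sketch (hypothetical calling code; call these
    //! before load_model() so they take effect on the loaded graph):
    //! \code
    //! network_impl->set_cpu_inplace_mode();     // single-core, no dispatch thread
    //! // or, alternatively:
    //! network_impl->set_cpu_threads_number(4);  // multi-thread mode
    //! \endcode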
    //! set the device id, default device id = 0
    void set_device_id(int device_id) override;
    int get_device_id() const override { return m_compnode_locator.device; };

    LiteBackend get_backend_type() const override { return LiteBackend::LITE_DEFAULT; }

    //! set the stream id, default stream id = 0
    void set_stream_id(int stream_id) override;
    int get_stream_id() const override { return m_compnode_locator.stream; };

    //! enable tensorrt
    void use_tensorrt();

    //! enable profiling the network; a JSON format file will be generated
    void enable_profile_performance(std::string profile_json_file_path) override;

    /********************** mge special function ************************/

    //! load a new network which will share weights with the src network
    void shared_weight_with(const NetworkImplBase* src_network);

    //! share the runtime memory with another network; the weights are not
    //! shared
    void share_runtime_memory_with(NetworkImplBase* network);
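    //! A sharing sketch (hypothetical calling code; net_a is an already
    //! loaded NetworkImplDft and net_b is a fresh one; note that networks
    //! sharing runtime memory should not run inference concurrently):
    //! \code
    //! net_b->shared_weight_with(net_a.get());         // reuse net_a's weights
    //! net_b->share_runtime_memory_with(net_a.get());  // reuse workspace memory
    //! \endcode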
    //! set the thread affinity callback
    void set_runtime_thread_affinity(
            const ThreadAffinityCallback& thread_affinity_callback);

    //! set the network memory allocator; the allocator is defined by the user
    void set_memory_allocator(std::shared_ptr<Allocator> user_allocator);

    //! set the opr algorithm selection strategy in the network
    void set_network_algo_policy(
            LiteAlgoSelectStrategy strategy, uint32_t shared_batch_size,
            bool binary_equal_between_batch);
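    //! An algo policy sketch (hypothetical calling code; the strategy value
    //! is assumed to be one of the LiteAlgoSelectStrategy flags declared in
    //! lite/network.h):
    //! \code
    //! network_impl->set_network_algo_policy(
    //!         LiteAlgoSelectStrategy::LITE_ALGO_PROFILE,
    //!         1,      // shared_batch_size; 0: use the model's own batch size
    //!         false); // binary_equal_between_batch
    //! \endcode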
    //! set the workspace_limit for oprs with multiple algorithms; setting a
    //! workspace limitation can save memory but may influence the performance
    void set_network_algo_workspace_limit(size_t workspace_limit);

    //! dump the input/output values of all internal variables to an output
    //! file, in text format
    void enable_io_txt_dump(std::string io_txt_out_file);

    //! dump the input/output values of all internal variables to an output
    //! directory, in binary format
    void enable_io_bin_dump(std::string io_bin_out_dir);

    //! get the static peak memory info shown by graph visualization
    void get_static_memory_alloc_info(
            const std::string& log_dir = "logs/test") const override;

    //! set the global layout transform optimization for the network
    void enable_global_layout_transform();

    //! dump the network after the global layout transform optimization
    void dump_layout_transform_model(std::string optimized_model_path);
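    //! A layout transform sketch (hypothetical calling code; the call order
    //! follows the method comments above: enable before load, dump after):
    //! \code
    //! network_impl->enable_global_layout_transform();
    //! network_impl->load_model(mem, size);
    //! network_impl->dump_layout_transform_model("optimized.mge");
    //! \endcode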
    mgb::serialization::GraphLoader::LoadResult get_load_result() {
        return m_load_result;
    }

private:
    //! construct the output spec according to m_network_io, and set the
    //! callback to the output spec
    void make_output_spec();

    //! do the layout transform for the given platform target: either the
    //! global layout optimization, or heuristically choosing the best layout
    //! according to the device information
    void layout_transform_optimization();

    //! modify the execution policy
    void modify_exection_policy();

    //! if the input is a dev tensor, this pass will replace the H2D opr
    //! with a VolatileSharedDeviceTensor opr
    void replace_dev_input_pass();

    //! if an input to the network is a list of tensors, this pass will
    //! replace the oprs that support a list of input tensors with the
    //! corresponding version; currently WarpPerspective is supported
    void replace_src_discrete_input_opr_pass();

    //! check whether the model crosses compnodes
    void cross_compnode_model_detect();

    //! when the model has been loaded, update the IO; if no NetworkIO was
    //! set, update the NetworkIO with the IO of the loaded model
    void update_io();

    void update_input();
    void update_output();

    //! initialize the lite tensors when an input is composed of multiple
    //! discrete tensors
    void update_input_lite_tensors();

    //! when the model info has been loaded, update the config according to
    //! the model info, and finally use it in the compute graph
    void application_config();

    //! after finishing forwarding the network, output the plugin results to file
    void output_plugin_result() const;

    //! called when forwarding the network has finished
    void finish() const;

    //! called before forwarding the network
    void start() const;

    //! compile the graph to get the execute function
    void compile_graph();

    //! try to infer the output tensor layout
    void try_infer_tensor_layout(std::shared_ptr<Tensor> tensor, Var var);

    //! optimize the output tensor copy
    void output_tensor_copy_optimize(Var var, std::shared_ptr<Tensor> tensor);

    //! configure and optimize the network after it is loaded
    void configure_after_loaded();
private:
    bool m_async = false;
    bool m_is_cpu_inplace_mode = false;
    int m_nr_device_type = 0;
    size_t m_nr_threads = 1;
    bool m_compute_configured_output_only = false;
    bool m_set_layout_transform = false;
    mgb::CompNode::Locator m_compnode_locator;
    AsyncCallback m_async_callback = nullptr;
    std::unique_ptr<NetworkIOInner> m_network_io;
    std::unique_ptr<Config> m_user_config;
    std::unique_ptr<mgb::cg::AsyncExecutable> m_execute_func;

    //! the model load related data
    S m_execution_policy = static_cast<S>(0);
    std::unique_ptr<mgb::serialization::InputFile> m_input_file;
    mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
    mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
    mgb::serialization::GraphLoadConfig m_load_config;
    mgb::serialization::GraphLoader::LoadResult m_load_result;
    mgb::ComputingGraph::OutputSpec m_output_spec;
    std::shared_ptr<mgb::serialization::GraphLoader> m_loader;

    //! start and finish callbacks
    StartCallback m_start_callback = nullptr;
    FinishCallback m_finish_callback = nullptr;

    //! profile and io dump related data
#if MGB_ENABLE_JSON
    std::unique_ptr<mgb::GraphProfiler> m_profiler;
    std::string m_profiler_output_file;
#endif
    std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};
//! get the model IO information before the model is loaded by Network,
//! from the model path
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);

//! get the model IO information before the model is loaded by Network,
//! from the model memory and size
NetworkIO get_model_io_info_dft(
        const void* model_mem, size_t size, const Config& config);
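//! A pre-load inspection sketch (hypothetical calling code; "model.mge"
//! stands in for a real model path):
//! \code
//! Config config;
//! NetworkIO io = get_model_io_info_dft("model.mge", config);
//! for (auto& input : io.inputs)
//!     printf("input: %s\n", input.name.c_str());
//! \endcode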
}  // namespace lite

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}