
network_impl_base.h

/**
 * \file src/network_impl_base.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "lite/network.h"
#include "misc.h"
#include "tensor_impl_base.h"
#include "type_info.h"

#include <unordered_map>

namespace lite {

/*!
 * \brief the inner IO data struct, which adds some internal data to IO
 */
class IOInner : public IO {
public:
    //! flags whether the corresponding lite_tensor has been filled: when the
    //! value of lite_tensor is filled, have_sync is true, otherwise false.
    //! This is used in async mode.
    bool have_sync = false;
    //! real input and output data location
    std::shared_ptr<Tensor> lite_tensor = nullptr;

    IOInner() = default;
    IOInner(const IO& io) {
        name = io.name;
        is_host = io.is_host;
        io_type = io.io_type;
        config_layout = io.config_layout;
    }
};

/*!
 * \brief the real network IO info used when the network runs
 */
struct NetworkIOInner {
    std::vector<IOInner> inputs;
    std::vector<IOInner> outputs;
};

/*!
 * \brief implementation of the Network, contains the mgb related members
 */
class Network::NetworkImplBase : public DynTypeObj {
public:
    virtual ~NetworkImplBase() = default;

    //! set the config of the network, including:
    //! the inference device
    //! the other inference options, such as record_level, weight_preprocess...
    virtual void set_config(const Config& config) = 0;

    //! set the special io information; if not set, the default io tensors will
    //! be used. This is for the case where the input/output is not a host
    //! tensor; by default the input/output tensors are host tensors
    virtual void set_io(const NetworkIO& network_io) = 0;

    //! only compute the output tensors configured by the user
    virtual void compute_only_configured_output() = 0;

    //! get the network input or output tensor, whose layout is synced from the
    //! mge tensor
    virtual std::shared_ptr<Tensor> get_io_tensor(
            std::string io_name, LiteTensorPhase phase = LiteTensorPhase::LITE_IO) = 0;

    //! get the input tensor by index in the load_result tensormap
    virtual std::shared_ptr<Tensor> get_input_tensor(size_t index) = 0;

    //! get the output tensor by index in the load_result output_var_list
    virtual std::shared_ptr<Tensor> get_output_tensor(size_t index) = 0;

    //! get all the input tensor names in the order returned by load
    virtual std::vector<const char*> get_all_input_name() const = 0;

    //! get all the output tensor names in the order returned by load
    virtual std::vector<const char*> get_all_output_name() const = 0;

    //! get the input tensor name in the order returned by load
    virtual const char* get_input_name(size_t index) const = 0;

    //! get the output tensor name in the order returned by load
    virtual const char* get_output_name(size_t index) const = 0;

    //! set the callback in async mode
    virtual void set_async_callback(const AsyncCallback& callback) = 0;

    //! set the start callback which will execute before network forward
    virtual void set_start_callback(const StartCallback& callback) = 0;

    //! set the finish callback which will execute after network forward
    virtual void set_finish_callback(const FinishCallback& callback) = 0;

    //! load the model and get the m_load_result
    virtual void load_model(
            std::shared_ptr<void> model_mem, size_t size,
            std::unordered_map<std::string, LiteAny> separate_config_map = {}) = 0;

    //! forward the network with the filled input data and fill the output data
    //! to the output tensor
    virtual void forward() = 0;

    //! in sync mode, wait until the inference finishes
    virtual void wait() = 0;

    //! set device id, default device id = 0
    virtual void set_device_id(int device_id) = 0;
    virtual int get_device_id() const = 0;
    virtual LiteBackend get_backend_type() const = 0;

    //! set stream id, default stream id = 0
    virtual void set_stream_id(int stream_id) = 0;
    virtual int get_stream_id() const = 0;

    virtual LiteDeviceType get_device_type() const = 0;

    //! enable profiling of the network, a profile file will be generated
    virtual void enable_profile_performance(std::string profile_file_path) = 0;

    //! get static peak memory info shown by graph visualization
    virtual void get_static_memory_alloc_info(const std::string& log_dir) const {
        LITE_MARK_USED_VAR(log_dir);
        LITE_THROW(
                "This network impl doesn't support get_static_memory_alloc_info() "
                "function.");
    }
};

/******************************** friend class *****************************/
/*!
 * \brief friend class of Network, for conveniently accessing the Network members
 */
class NetworkHelper {
public:
    static bool loaded(const std::shared_ptr<Network> network) {
        LITE_ASSERT(network);
        return network->m_loaded;
    }
    static void loaded(const std::shared_ptr<Network> network, bool loaded) {
        LITE_ASSERT(network);
        network->m_loaded = loaded;
    }
    static Network::NetworkImplBase* implement(const Network* network) {
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
    static Network::NetworkImplBase* implement(const std::shared_ptr<Network> network) {
        LITE_ASSERT(network);
        return network->m_impl.get();
    }
    static void implement(
            const std::shared_ptr<Network> network,
            std::unique_ptr<Network::NetworkImplBase> impl) {
        LITE_ASSERT(network);
        network->m_impl = std::move(impl);
    }
};

}  // namespace lite

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
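
To make the interface above concrete, here is a minimal usage sketch of how caller code might drive NetworkImplBase through the NetworkHelper friend class: read the serialized model into memory, hand it to load_model(), run forward()/wait(), and fetch IO tensors by name or index. This is a sketch under stated assumptions, not code from the repository: it assumes a concrete implementation has already been installed into the Network (e.g. by its constructor), that lite::Network is constructible with std::make_shared<Network>(), that this file is includable as "network_impl_base.h", and that the model path "model.mge" and input name "data" are placeholders.

    // Hypothetical usage sketch of NetworkImplBase via NetworkHelper.
    #include <fstream>
    #include <memory>

    #include "network_impl_base.h"

    int main() {
        // Read the serialized model into a shared buffer; load_model() takes
        // raw memory plus its size.
        std::ifstream file("model.mge", std::ios::binary | std::ios::ate);
        if (!file)
            return 1;
        size_t size = static_cast<size_t>(file.tellg());
        file.seekg(0);
        std::shared_ptr<void> model_mem(
                new char[size], [](void* p) { delete[] static_cast<char*>(p); });
        file.read(static_cast<char*>(model_mem.get()),
                  static_cast<std::streamsize>(size));

        // Reach the implementation behind a Network through the friend helper.
        auto network = std::make_shared<lite::Network>();
        lite::Network::NetworkImplBase* impl =
                lite::NetworkHelper::implement(network);

        impl->load_model(model_mem, size);
        lite::NetworkHelper::loaded(network, true);

        // Fill the input tensor, run inference, and wait for completion
        // (in sync mode, wait() blocks until the inference finishes).
        std::shared_ptr<lite::Tensor> input = impl->get_io_tensor("data");
        // ... copy input data into `input` here ...
        impl->forward();
        impl->wait();

        std::shared_ptr<lite::Tensor> output = impl->get_output_tensor(0);
        return 0;
    }

In an asynchronous setup, set_async_callback() would take the place of the explicit wait(), per the comments in the header above.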

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has a GPU device and the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.