You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

optimize_options.h 6.9 kB

Line numbers 1–241
  1. #pragma once
  2. #include <gflags/gflags.h>
  3. #include "helpers/common.h"
  4. #include "models/model.h"
  5. #include "option_base.h"
  6. DECLARE_bool(enable_fuse_preprocess);
  7. DECLARE_bool(weight_preprocess);
  8. DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
  9. DECLARE_bool(enable_fuse_conv_bias_with_z);
  10. DECLARE_bool(const_shape);
  11. DECLARE_bool(fake_first);
  12. DECLARE_bool(no_sanity_check);
  13. DECLARE_bool(record_comp_seq);
  14. DECLARE_bool(record_comp_seq2);
  15. DECLARE_bool(disable_mem_opt);
  16. DECLARE_uint64(workspace_limit);
  17. DECLARE_bool(enable_jit);
  18. #if MGB_ENABLE_TENSOR_RT
  19. DECLARE_bool(tensorrt);
  20. DECLARE_string(tensorrt_cache);
  21. #endif
  22. namespace lar {
  23. ///////////////////////// fuse_preprocess optimize options //////////////
  24. class FusePreprocessOption final : public OptionBase {
  25. public:
  26. static bool is_valid();
  27. static std::shared_ptr<OptionBase> create_option();
  28. void config_model(
  29. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  30. std::string option_name() const override { return m_option_name; };
  31. static void set_valid(bool val) { m_valid = val; }
  32. OptionValMap* get_option() override { return &m_option; }
  33. void update() override;
  34. private:
  35. FusePreprocessOption() = default;
  36. template <typename ModelImpl>
  37. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  38. std::string m_option_name;
  39. bool enable_fuse_preprocess;
  40. static bool m_valid;
  41. OptionValMap m_option;
  42. };
  43. ///////////////////////// weight preprocess optimize options //////////////
  44. class WeightPreprocessOption final : public OptionBase {
  45. public:
  46. static bool is_valid();
  47. static std::shared_ptr<OptionBase> create_option();
  48. void config_model(
  49. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  50. std::string option_name() const override { return m_option_name; };
  51. static void set_valid(bool val) { m_valid = val; };
  52. OptionValMap* get_option() override { return &m_option; }
  53. void update() override;
  54. private:
  55. WeightPreprocessOption() = default;
  56. template <typename ModelImpl>
  57. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  58. std::string m_option_name;
  59. bool weight_preprocess;
  60. static bool m_valid;
  61. OptionValMap m_option;
  62. };
  63. /////////////// fuse_conv_bias_nonlinearity optimize options ///////////////
  64. class FuseConvBiasNonlinearOption final : public OptionBase {
  65. public:
  66. static bool is_valid();
  67. static std::shared_ptr<OptionBase> create_option();
  68. void config_model(
  69. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  70. std::string option_name() const override { return m_option_name; };
  71. static void set_valid(bool val) { m_valid = val; }
  72. OptionValMap* get_option() override { return &m_option; }
  73. void update() override;
  74. private:
  75. FuseConvBiasNonlinearOption() = default;
  76. template <typename ModelImpl>
  77. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  78. std::string m_option_name;
  79. bool enable_fuse_conv_bias_nonlinearity;
  80. static bool m_valid;
  81. OptionValMap m_option;
  82. };
  83. ///////////////////////// fuse_conv_bias_with_z optimize options //////////////
  84. class FuseConvBiasElemwiseAddOption final : public OptionBase {
  85. public:
  86. static bool is_valid();
  87. static std::shared_ptr<OptionBase> create_option();
  88. void config_model(
  89. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  90. std::string option_name() const override { return m_option_name; };
  91. static void set_valid(bool val) { m_valid = val; }
  92. OptionValMap* get_option() override { return &m_option; }
  93. void update() override;
  94. private:
  95. FuseConvBiasElemwiseAddOption() = default;
  96. template <typename ModelImpl>
  97. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  98. std::string m_option_name;
  99. bool enable_fuse_conv_bias_with_z;
  100. static bool m_valid;
  101. OptionValMap m_option;
  102. };
  103. ///////////////////////// graph record options ///////////////////////////
  104. class GraphRecordOption final : public OptionBase {
  105. public:
  106. static bool is_valid();
  107. static std::shared_ptr<OptionBase> create_option();
  108. void config_model(
  109. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  110. std::string option_name() const override { return m_option_name; };
  111. static void set_valid(bool val) { m_valid = val; }
  112. OptionValMap* get_option() override { return &m_option; }
  113. void update() override;
  114. private:
  115. GraphRecordOption() = default;
  116. template <typename ModelImpl>
  117. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  118. std::string m_option_name;
  119. size_t m_record_comp_seq;
  120. bool const_shape;
  121. bool fake_first;
  122. bool no_sanity_check;
  123. static bool m_valid;
  124. OptionValMap m_option;
  125. };
  126. ///////////////////////// memory optimize options /////////////////////////
  127. class MemoryOptimizeOption final : public OptionBase {
  128. public:
  129. static bool is_valid();
  130. static std::shared_ptr<OptionBase> create_option();
  131. void config_model(
  132. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  133. std::string option_name() const override { return m_option_name; };
  134. void update() override;
  135. private:
  136. MemoryOptimizeOption() = default;
  137. template <typename ModelImpl>
  138. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  139. std::string m_option_name;
  140. bool disable_mem_opt;
  141. uint64_t workspace_limit;
  142. };
  143. ///////////////////////// other options for optimization /////////////////
  144. class JITOption final : public OptionBase {
  145. public:
  146. static bool is_valid();
  147. static std::shared_ptr<OptionBase> create_option();
  148. void config_model(
  149. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  150. std::string option_name() const override { return m_option_name; };
  151. void update() override;
  152. private:
  153. JITOption() = default;
  154. template <typename ModelImpl>
  155. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  156. std::string m_option_name;
  157. bool enable_jit;
  158. };
  ///////////////////////// TensorRT options for optimization /////////////////
  // Only compiled when MGB_ENABLE_TENSOR_RT is enabled, matching the guarded
  // `tensorrt` / `tensorrt_cache` flag declarations.
  #if MGB_ENABLE_TENSOR_RT
  /// TensorRT option. Presumably driven by the `tensorrt` and `tensorrt_cache`
  /// flags (NOTE(review): inferred from member names; confirm in update()).
  ///
  /// This class has no set_valid() and does not override get_option(). The
  /// constructor is private; use create_option().
  class TensorRTOption final : public OptionBase {
  public:
      /// Defined out of line; presumably reports whether this option should
      /// be created — confirm in the .cpp.
      static bool is_valid();

      /// Factory; the only public way to obtain an instance because the
      /// constructor is private. Defined out of line.
      static std::shared_ptr<OptionBase> create_option();

      /// Applies this option to `model`; defined out of line.
      void config_model(
              RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

      /// Returns a copy of the stored option name.
      std::string option_name() const override { return m_option_name; }

      /// Refreshes the option state; defined out of line.
      void update() override;

  private:
      TensorRTOption() = default;

      /// Generic fallback that does nothing. Per-model behaviour is presumably
      /// supplied by explicit specializations in the implementation file.
      template <typename ModelImpl>
      void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}

      std::string m_option_name;
      // Brace-initialized: the defaulted constructor would otherwise leave this
      // bool indeterminate until something assigns it.
      bool enable_tensorrt{false};
      // Presumably a cache path from the `tensorrt_cache` flag; empty by default.
      std::string tensorrt_cache;
  };
  #endif
  178. } // namespace lar