#pragma once #include #include "helpers/common.h" #include "models/model.h" #include "option_base.h" DECLARE_bool(enable_fuse_preprocess); DECLARE_bool(weight_preprocess); DECLARE_bool(enable_fuse_conv_bias_nonlinearity); DECLARE_bool(enable_fuse_conv_bias_with_z); DECLARE_bool(const_shape); DECLARE_bool(fake_first); DECLARE_bool(no_sanity_check); DECLARE_bool(record_comp_seq); DECLARE_bool(record_comp_seq2); DECLARE_bool(disable_mem_opt); DECLARE_uint64(workspace_limit); DECLARE_bool(enable_jit); #if MGB_ENABLE_TENSOR_RT DECLARE_bool(tensorrt); DECLARE_string(tensorrt_cache); #endif namespace lar { ///////////////////////// fuse_preprocess optimize options ////////////// class FusePreprocessOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; static void set_valid(bool val) { m_valid = val; } OptionValMap* get_option() override { return &m_option; } void update() override; private: FusePreprocessOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool enable_fuse_preprocess; static bool m_valid; OptionValMap m_option; }; ///////////////////////// weight preprocess optimize options ////////////// class WeightPreprocessOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; static void set_valid(bool val) { m_valid = val; }; OptionValMap* get_option() override { return &m_option; } void update() override; private: WeightPreprocessOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool weight_preprocess; static bool m_valid; OptionValMap m_option; }; /////////////// fuse_conv_bias_nonlinearity optimize options /////////////// class FuseConvBiasNonlinearOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; static void set_valid(bool val) { m_valid = val; } OptionValMap* get_option() override { return &m_option; } void update() override; private: FuseConvBiasNonlinearOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool enable_fuse_conv_bias_nonlinearity; static bool m_valid; OptionValMap m_option; }; ///////////////////////// fuse_conv_bias_with_z optimize options ////////////// class FuseConvBiasElemwiseAddOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; static void set_valid(bool val) { m_valid = val; } OptionValMap* get_option() override { return &m_option; } void update() override; private: FuseConvBiasElemwiseAddOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool enable_fuse_conv_bias_with_z; static bool m_valid; OptionValMap m_option; }; ///////////////////////// graph record options /////////////////////////// class GraphRecordOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; static void set_valid(bool val) { m_valid = val; } OptionValMap* get_option() override { return &m_option; } void update() override; private: GraphRecordOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; size_t m_record_comp_seq; bool const_shape; bool fake_first; bool no_sanity_check; static bool m_valid; OptionValMap m_option; }; ///////////////////////// memory optimize options ///////////////////////// class MemoryOptimizeOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; void update() override; private: MemoryOptimizeOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool disable_mem_opt; uint64_t workspace_limit; }; ///////////////////////// other options for optimization ///////////////// class JITOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; void update() override; private: JITOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool enable_jit; }; ///////////////////////// TensorRT options for optimization ///////////////// #if MGB_ENABLE_TENSOR_RT class TensorRTOption final : public OptionBase { public: static bool is_valid(); static std::shared_ptr create_option(); void config_model( RuntimeParam& runtime_param, std::shared_ptr model) override; std::string option_name() const override { return m_option_name; }; void update() override; private: TensorRTOption() = default; template void config_model_internel(RuntimeParam&, std::shared_ptr){}; std::string m_option_name; bool enable_tensorrt; std::string tensorrt_cache; }; #endif } // namespace lar