You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

layout_trans_options.h 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. #pragma once
  2. #include <gflags/gflags.h>
  3. #include "megbrain/gopt/inference.h"
  4. #include "models/model.h"
  5. #include "option_base.h"
  6. DECLARE_string(layout_transform);
  7. DECLARE_int32(layout_transform_batch_size);
  8. DECLARE_string(layout_transform_dump);
  9. namespace lar {
  10. class GoptLayoutOption final : public OptionBase {
  11. public:
  12. //! get condition for construct FastRunOption
  13. static bool is_valid();
  14. //! creat option using condition from cmdline args
  15. static std::shared_ptr<OptionBase> create_option();
  16. //! configure model for different runtime_param
  17. void config_model(
  18. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  19. //! get options name for quickly search
  20. std::string option_name() const override { return m_option_name; }
  21. static void set_valid(bool val) { m_valid = val; }
  22. OptionValMap* get_option() override { return &m_option; }
  23. void update() override;
  24. private:
  25. GoptLayoutOption() = default;
  26. //! config template for different model
  27. template <typename ModelImpl>
  28. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
  29. bool m_layout_transform;
  30. std::string m_option_name;
  31. std::string m_layout_transform_dump_file;
  32. mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
  33. static bool m_valid;
  34. OptionValMap m_option;
  35. int32_t m_force_batch_size;
  36. };
  37. } // namespace lar