You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

io_options.h 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #pragma once
  2. #include <gflags/gflags.h>
  3. #include "helpers/outdumper.h"
  4. #include "megbrain/plugin/opr_io_dump.h"
  5. #include "models/model.h"
  6. #include "option_base.h"
  7. DECLARE_string(input);
  8. DECLARE_string(io_dump);
  9. DECLARE_bool(io_dump_stdout);
  10. DECLARE_bool(io_dump_stderr);
  11. DECLARE_string(bin_io_dump);
  12. DECLARE_string(bin_out_dump);
  13. DECLARE_bool(copy_to_host);
  14. DECLARE_int32(batch_size);
  15. namespace lar {
  16. /*!
  17. * \brief: input option for --input set
  18. */
  19. class InputOption final : public OptionBase {
  20. public:
  21. //! static function for registe options
  22. static bool is_valid() { return !FLAGS_input.empty() || FLAGS_batch_size > 0; };
  23. static std::shared_ptr<OptionBase> create_option();
  24. void config_model(
  25. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  26. //! interface implement from OptionBase
  27. std::string option_name() const override { return m_option_name; };
  28. void update() override;
  29. private:
  30. InputOption() = default;
  31. template <typename ModelImpl>
  32. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  33. std::string m_option_name;
  34. std::vector<std::string> data_path; // data string or data file path
  35. int32_t m_force_batch_size;
  36. };
  37. class IOdumpOption : public OptionBase {
  38. public:
  39. static bool is_valid();
  40. static std::shared_ptr<OptionBase> create_option();
  41. //! config the model, if different has different configure code, then
  42. //! dispatch
  43. void config_model(
  44. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  45. std::string option_name() const override { return m_option_name; };
  46. void update() override;
  47. private:
  48. IOdumpOption() = default;
  49. template <typename ModelImpl>
  50. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  51. bool enable_io_dump;
  52. bool enable_io_dump_stdout;
  53. bool enable_io_dump_stderr;
  54. bool enable_bin_io_dump;
  55. bool enable_bin_out_dump;
  56. bool enable_copy_to_host;
  57. std::string m_option_name;
  58. std::string dump_path;
  59. std::unique_ptr<mgb::OprIODumpBase> io_dumper;
  60. std::unique_ptr<OutputDumper> out_dumper;
  61. };
  62. } // namespace lar