You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

io_options.h 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. /**
  2. * \file lite/load_and_run/src/options/io_options.h
  3. *
  4. * This file is part of MegEngine, a deep learning framework developed by
  5. * Megvii.
  6. *
  7. * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  8. */
  9. #pragma once
  10. #include <gflags/gflags.h>
  11. #include "helpers/outdumper.h"
  12. #include "megbrain/plugin/opr_io_dump.h"
  13. #include "models/model.h"
  14. #include "option_base.h"
  15. DECLARE_string(input);
  16. DECLARE_string(io_dump);
  17. DECLARE_bool(io_dump_stdout);
  18. DECLARE_bool(io_dump_stderr);
  19. DECLARE_string(bin_io_dump);
  20. DECLARE_string(bin_out_dump);
  21. DECLARE_bool(copy_to_host);
  22. namespace lar {
  23. /*!
  24. * \brief: input option for --input set
  25. */
  26. class InputOption final : public OptionBase {
  27. public:
  28. //! static function for registe options
  29. static bool is_valid() { return !FLAGS_input.empty(); };
  30. static std::shared_ptr<OptionBase> create_option();
  31. void config_model(
  32. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  33. //! interface implement from OptionBase
  34. std::string option_name() const override { return m_option_name; };
  35. private:
  36. InputOption();
  37. template <typename ModelImpl>
  38. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  39. std::string m_option_name;
  40. std::vector<std::string> data_path; // data string or data file path
  41. };
  42. class IOdumpOption : public OptionBase {
  43. public:
  44. static bool is_valid();
  45. static std::shared_ptr<OptionBase> create_option();
  46. //! config the model, if different has different configure code, then
  47. //! dispatch
  48. void config_model(
  49. RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
  50. std::string option_name() const override { return m_option_name; };
  51. private:
  52. IOdumpOption();
  53. template <typename ModelImpl>
  54. void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
  55. bool enable_io_dump;
  56. bool enable_io_dump_stdout;
  57. bool enable_io_dump_stderr;
  58. bool enable_bin_io_dump;
  59. bool enable_bin_out_dump;
  60. bool enable_copy_to_host;
  61. std::string m_option_name;
  62. std::string dump_path;
  63. std::unique_ptr<mgb::OprIODumpBase> io_dumper;
  64. std::unique_ptr<OutputDumper> out_dumper;
  65. };
  66. } // namespace lar

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。