You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serializer.cpp 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #include "megbrain/serialization/serializer.h"
  2. #include "megbrain/gopt/inference.h"
  3. #include "megbrain/opr/utility.h"
  4. namespace mgb {
  5. namespace serialization {
  /* ====================== helper impls ====================== */
  //! Destructor of LoadResult, defined out of line (and defaulted) in this
  //! translation unit.
  //! NOTE(review): presumably kept out of the header so it is not emitted at
  //! every include site -- confirm against serializer.h.
  GraphLoader::LoadResult::~LoadResult() noexcept = default;
  8. std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
  9. const ComputingGraph::OutputSpec& outspec) {
  10. auto ret = graph->compile(outspec);
  11. if (graph->options().comp_node_seq_record_level == 2) {
  12. ComputingGraph::assert_destroy(graph);
  13. }
  14. return ret;
  15. }
  16. void GraphLoader::LoadResult::update_output_var_list(
  17. const SymbolVarArray& output_var_array) {
  18. mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
  19. mgb_assert(output_var_array.size() == output_var_list.size());
  20. // replace symvar in output_var_list
  21. for (size_t idx = 0; idx < output_var_array.size(); ++idx) {
  22. out_var_map[output_var_list[idx]] = output_var_array[idx];
  23. output_var_list[idx] = output_var_array[idx];
  24. }
  25. // replace symvar in output_var_map_id
  26. for (auto&& item : output_var_map_id) {
  27. item.second = out_var_map[item.second];
  28. }
  29. // replace symvar in output_var_map
  30. for (auto&& item : output_var_map) {
  31. item.second = out_var_map[item.second].rename(item.first);
  32. }
  33. }
  34. void GraphLoader::LoadResult::graph_compile_ahead() {
  35. //! when force_output_use_user_specified_memory is set, the output var may
  36. //! be changed by gopt, then the var in LoadResult can not exist, so here
  37. //! just do basic optimize_for_inference ahead, and replace the var in
  38. //! LoadResult
  39. if (graph->options().force_output_use_user_specified_memory) {
  40. auto options = gopt::OptimizeForInferenceOptions{};
  41. auto new_vars = gopt::optimize_for_inference(output_var_list, options);
  42. output_var_list = new_vars;
  43. output_var_map.clear();
  44. for (auto& var : new_vars) {
  45. output_var_map[var.node()->cname()] = var;
  46. }
  47. std::unordered_map<size_t, SymbolVar> var_map_id;
  48. for (auto& var : new_vars) {
  49. bool found = false;
  50. for (auto& old_var_it : output_var_map_id) {
  51. if (old_var_it.second.node()->name() == var.node()->name()) {
  52. found = true;
  53. var_map_id[old_var_it.first] = var;
  54. }
  55. }
  56. mgb_assert(
  57. found, "can't find var name %s when optimize_for_inference. ",
  58. var.node()->cname());
  59. }
  60. }
  61. }
  62. GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
  63. SharedTensorNameMap ret;
  64. for (auto&& i : shared_tensor_id_map()) {
  65. mgb_assert(!i.first.empty(), "name stripped during graph dump");
  66. auto ins = ret.emplace(i.first, &i.second);
  67. mgb_assert(ins.second);
  68. }
  69. return ret;
  70. }
  71. std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
  72. std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
  73. std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
  74. std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
  75. std::unique_ptr<OutputFile> file, int version);
  76. bool is_fbs_file(InputFile& file);
  77. bool is_fbs_v2_file(InputFile& file);
  78. bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
  79. #if MGB_ENABLE_GRAD
  80. return opr->same_type<opr::SetGrad>();
  81. #else
  82. return false;
  83. #endif
  84. }
  85. std::unique_ptr<GraphDumper> GraphDumper::make(
  86. std::unique_ptr<OutputFile> file, GraphDumpFormat format, int version) {
  87. switch (format) {
  88. case GraphDumpFormat::FLATBUFFERS:
  89. #if MGB_ENABLE_FBS_SERIALIZATION
  90. return make_fbs_dumper(std::move(file));
  91. #endif
  92. MGB_FALLTHRU
  93. case GraphDumpFormat::FLATBUFFERS_V2:
  94. #if MGB_ENABLE_FBS_SERIALIZATION
  95. return make_fbs_v2_dumper(std::move(file), version);
  96. #endif
  97. MGB_FALLTHRU
  98. default:
  99. mgb_throw(SerializationError, "unsupported serialization format requested");
  100. }
  101. mgb_assert(false, "unreachable");
  102. }
  103. std::unique_ptr<GraphLoader> GraphLoader::make(
  104. std::unique_ptr<InputFile> file, GraphDumpFormat format) {
  105. switch (format) {
  106. case GraphDumpFormat::FLATBUFFERS:
  107. #if MGB_ENABLE_FBS_SERIALIZATION
  108. return make_fbs_loader(std::move(file));
  109. #endif
  110. MGB_FALLTHRU
  111. case GraphDumpFormat::FLATBUFFERS_V2:
  112. #if MGB_ENABLE_FBS_SERIALIZATION
  113. return make_fbs_v2_loader(std::move(file));
  114. #endif
  115. MGB_FALLTHRU
  116. default:
  117. mgb_throw(SerializationError, "unsupported serialization format requested");
  118. }
  119. mgb_assert(false, "unreachable");
  120. }
  121. Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) {
  122. #if MGB_ENABLE_FBS_SERIALIZATION
  123. if (is_fbs_file(file)) {
  124. return GraphDumpFormat::FLATBUFFERS;
  125. }
  126. if (is_fbs_v2_file(file)) {
  127. return GraphDumpFormat::FLATBUFFERS_V2;
  128. }
  129. #endif
  130. return {};
  131. }
  132. } // namespace serialization
  133. } // namespace mgb