You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

serializer.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. /**
  2. * \file src/serialization/impl/serializer.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/serialization/serializer.h"
  12. #include "megbrain/gopt/inference.h"
  13. #include "megbrain/opr/utility.h"
  14. namespace mgb {
  15. namespace serialization {
  16. /* ====================== helper impls ====================== */
  17. GraphLoader::LoadResult::~LoadResult() noexcept = default;
  18. std::unique_ptr<cg::AsyncExecutable> GraphLoader::LoadResult::graph_compile(
  19. const ComputingGraph::OutputSpec& outspec) {
  20. auto ret = graph->compile(outspec);
  21. if (graph->options().comp_node_seq_record_level == 2) {
  22. ComputingGraph::assert_destroy(graph);
  23. }
  24. return ret;
  25. }
  26. void GraphLoader::LoadResult::graph_compile_ahead() {
  27. //! when force_output_use_user_specified_memory is set, the output var may
  28. //! be changed by gopt, then the var in LoadResult can not exist, so here
  29. //! just do basic optimize_for_inference ahead, and replace the var in
  30. //! LoadResult
  31. if (graph->options().force_output_use_user_specified_memory) {
  32. auto options = gopt::OptimizeForInferenceOptions{};
  33. auto new_vars = gopt::optimize_for_inference(output_var_list, options);
  34. output_var_list = new_vars;
  35. output_var_map.clear();
  36. for (auto& var : new_vars) {
  37. output_var_map[var.node()->cname()] = var;
  38. }
  39. std::unordered_map<size_t, SymbolVar> var_map_id;
  40. for (auto& var : new_vars) {
  41. bool found = false;
  42. for (auto& old_var_it : output_var_map_id) {
  43. if (old_var_it.second.node()->name() == var.node()->name()) {
  44. found = true;
  45. var_map_id[old_var_it.first] = var;
  46. }
  47. }
  48. mgb_assert(
  49. found, "can't find var name %s when optimize_for_inference. ",
  50. var.node()->cname());
  51. }
  52. }
  53. }
  54. GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
  55. SharedTensorNameMap ret;
  56. for (auto&& i : shared_tensor_id_map()) {
  57. mgb_assert(!i.first.empty(), "name stripped during graph dump");
  58. auto ins = ret.emplace(i.first, &i.second);
  59. mgb_assert(ins.second);
  60. }
  61. return ret;
  62. }
  63. std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
  64. std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
  65. bool is_fbs_file(InputFile& file);
  66. bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
  67. #if MGB_ENABLE_GRAD
  68. return opr->same_type<opr::SetGrad>();
  69. #else
  70. return false;
  71. #endif
  72. }
  73. std::unique_ptr<GraphDumper> GraphDumper::make(
  74. std::unique_ptr<OutputFile> file, GraphDumpFormat format) {
  75. switch (format) {
  76. case GraphDumpFormat::FLATBUFFERS:
  77. #if MGB_ENABLE_FBS_SERIALIZATION
  78. return make_fbs_dumper(std::move(file));
  79. #endif
  80. MGB_FALLTHRU
  81. default:
  82. mgb_throw(SerializationError, "unsupported serialization format requested");
  83. }
  84. mgb_assert(false, "unreachable");
  85. }
  86. std::unique_ptr<GraphLoader> GraphLoader::make(
  87. std::unique_ptr<InputFile> file, GraphDumpFormat format) {
  88. switch (format) {
  89. case GraphDumpFormat::FLATBUFFERS:
  90. #if MGB_ENABLE_FBS_SERIALIZATION
  91. return make_fbs_loader(std::move(file));
  92. #endif
  93. MGB_FALLTHRU
  94. default:
  95. mgb_throw(SerializationError, "unsupported serialization format requested");
  96. }
  97. mgb_assert(false, "unreachable");
  98. }
  99. Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file) {
  100. #if MGB_ENABLE_FBS_SERIALIZATION
  101. if (is_fbs_file(file)) {
  102. return GraphDumpFormat::FLATBUFFERS;
  103. }
  104. #endif
  105. return {};
  106. }
  107. } // namespace serialization
  108. } // namespace mgb