You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_mdl.cpp 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #include "model_mdl.h"
  2. #include <gflags/gflags.h>
  3. #include <iostream>
  4. DECLARE_bool(share_param_mem);
  5. using namespace lar;
  6. ModelMdl::ModelMdl(const std::string& path) : model_path(path) {
  7. mgb_log("creat mdl model use XPU as default comp node");
  8. m_load_config.comp_graph = mgb::ComputingGraph::make();
  9. m_load_config.comp_graph->options().graph_opt_level = 0;
  10. testcase_num = 0;
  11. }
  12. void ModelMdl::load_model() {
  13. //! read dump file
  14. if (share_model_mem) {
  15. mgb_log("enable share model memory");
  16. FILE* fin = fopen(model_path.c_str(), "rb");
  17. mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
  18. fseek(fin, 0, SEEK_END);
  19. size_t size = ftell(fin);
  20. fseek(fin, 0, SEEK_SET);
  21. void* ptr = malloc(size);
  22. std::shared_ptr<void> buf{ptr, free};
  23. auto nr = fread(buf.get(), 1, size, fin);
  24. mgb_assert(nr == size, "read model file failed");
  25. fclose(fin);
  26. m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size);
  27. } else {
  28. m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
  29. }
  30. //! get dump_with_testcase model testcase number
  31. char magic[8];
  32. m_model_file->read(magic, sizeof(magic));
  33. if (strncmp(magic, "mgbtest0", 8)) {
  34. m_model_file->rewind();
  35. } else {
  36. m_model_file->read(&testcase_num, sizeof(testcase_num));
  37. }
  38. m_format =
  39. mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
  40. mgb_assert(
  41. m_format.valid(),
  42. "invalid format, please make sure model is dumped by GraphDumper");
  43. //! load computing graph of model
  44. m_loader = mgb::serialization::GraphLoader::make(
  45. std::move(m_model_file), m_format.val());
  46. m_load_result = m_loader->load(m_load_config, false);
  47. m_load_config.comp_graph.reset();
  48. // get testcase input generated by dump_with_testcase.py
  49. if (testcase_num) {
  50. for (auto&& i : m_load_result.tensor_map) {
  51. test_input_tensors.emplace_back(i.first, i.second.get());
  52. }
  53. std::sort(test_input_tensors.begin(), test_input_tensors.end());
  54. }
  55. // initialize output callback
  56. for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
  57. mgb::ComputingGraph::Callback cb;
  58. m_callbacks.push_back(cb);
  59. }
  60. }
  61. void ModelMdl::make_output_spec() {
  62. for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
  63. auto item = m_load_result.output_var_list[i];
  64. m_output_spec.emplace_back(item, std::move(m_callbacks[i]));
  65. }
  66. m_asyc_exec = m_load_result.graph_compile(m_output_spec);
  67. }
  68. std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
  69. std::unique_ptr<mgb::serialization::InputFile> input_file) {
  70. if (input_file) {
  71. m_loader = mgb::serialization::GraphLoader::make(
  72. std::move(input_file), m_loader->format());
  73. } else {
  74. m_loader = mgb::serialization::GraphLoader::make(
  75. m_loader->reset_file(), m_loader->format());
  76. }
  77. return m_loader;
  78. }
  79. void ModelMdl::run_model() {
  80. mgb_assert(
  81. m_asyc_exec != nullptr,
  82. "empty asychronous function to execute after graph compiled");
  83. m_asyc_exec->execute();
  84. }
  85. void ModelMdl::wait() {
  86. m_asyc_exec->wait();
  87. }
  88. #if MGB_ENABLE_JSON
  89. std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() {
  90. std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
  91. std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();
  92. auto get_dtype = [&](megdnn::DType data_type) {
  93. std::map<megdnn::DTypeEnum, std::string> type_map = {
  94. {mgb::dtype::Float32().enumv(), "float32"},
  95. {mgb::dtype::Int32().enumv(), "int32"},
  96. {mgb::dtype::Int16().enumv(), "int16"},
  97. {mgb::dtype::Uint16().enumv(), "uint16"},
  98. {mgb::dtype::Int8().enumv(), "int8"},
  99. {mgb::dtype::Uint8().enumv(), "uint8"}};
  100. return type_map[data_type.enumv()];
  101. };
  102. auto make_shape = [](mgb::TensorShape& shape_) {
  103. std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
  104. shape;
  105. for (size_t i = 0; i < shape_.ndim; ++i) {
  106. std::string lable = "dim";
  107. lable += std::to_string(shape_.ndim - i - 1);
  108. shape.push_back(
  109. {mgb::json::String(lable),
  110. mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])});
  111. }
  112. return shape;
  113. };
  114. for (auto&& i : m_load_result.tensor_map) {
  115. std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
  116. json_inp;
  117. auto shape_ = i.second->shape();
  118. json_inp.push_back(
  119. {mgb::json::String("shape"),
  120. mgb::json::Object::make(make_shape(shape_))});
  121. json_inp.push_back(
  122. {mgb::json::String("dtype"),
  123. mgb::json::String::make(get_dtype(i.second->dtype()))});
  124. json_inp.push_back(
  125. {mgb::json::String("name"), mgb::json::String::make(i.first)});
  126. inputs->add(mgb::json::Object::make(json_inp));
  127. }
  128. for (auto&& i : m_load_result.output_var_list) {
  129. std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
  130. json_out;
  131. auto shape_ = i.shape();
  132. json_out.push_back(
  133. {mgb::json::String("shape"),
  134. mgb::json::Object::make(make_shape(shape_))});
  135. json_out.push_back(
  136. {mgb::json::String("dtype"),
  137. mgb::json::String::make(get_dtype(i.dtype()))});
  138. json_out.push_back(
  139. {mgb::json::String("name"), mgb::json::String::make(i.node()->name())});
  140. outputs->add(mgb::json::Object::make(json_out));
  141. }
  142. return mgb::json::Object::make(
  143. {{"IO",
  144. mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
  145. }
  146. #endif
  147. std::vector<uint8_t> ModelMdl::get_model_data() {
  148. std::vector<uint8_t> out_data;
  149. auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data);
  150. using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
  151. DumpConfig config{1, false, false};
  152. auto dumper =
  153. mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val());
  154. dumper->dump(m_load_result.output_var_list, config);
  155. return out_data;
  156. }