You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_lite.cpp 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  #include "model_lite.h"

  #include <cerrno>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <map>
  #include <memory>

  #include <gflags/gflags.h>

  #include "misc.h"

  DECLARE_bool(share_param_mem);
  using namespace lar;
  8. ModelLite::ModelLite(const std::string& path) : model_path(path) {
  9. LITE_LOG("creat lite model use CPU as default comp node");
  10. };
  11. void ModelLite::create_network() {
  12. m_network = std::make_shared<lite::Network>(config, IO);
  13. }
  14. void ModelLite::load_model() {
  15. if (share_model_mem) {
  16. //! WARNNING:maybe not right to share param memmory for this
  17. LITE_LOG("enable share model memory");
  18. FILE* fin = fopen(model_path.c_str(), "rb");
  19. LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
  20. fseek(fin, 0, SEEK_END);
  21. size_t size = ftell(fin);
  22. fseek(fin, 0, SEEK_SET);
  23. void* ptr = malloc(size);
  24. std::shared_ptr<void> buf{ptr, free};
  25. auto nr = fread(buf.get(), 1, size, fin);
  26. LITE_ASSERT(nr == size, "read model file failed");
  27. fclose(fin);
  28. m_network->load_model(buf.get(), size);
  29. } else {
  30. m_network->load_model(model_path);
  31. }
  32. }
  33. void ModelLite::run_model() {
  34. m_network->forward();
  35. }
  36. void ModelLite::wait() {
  37. m_network->wait();
  38. }
  #if MGB_ENABLE_JSON
  //! Describe the network's input and output tensors as JSON.
  //!
  //! Result layout:
  //!   {"IO": {"outputs": [entry...], "inputs": [entry...]}}
  //! where each entry is {"shape": {"dim<k>": size_k, ...},
  //!                      "dtype": "<name>", "name": "<tensor name>"}.
  //! dtype names follow the table below; unlisted lite dtypes yield "".
  std::shared_ptr<mgb::json::Object> ModelLite::get_io_info() {
      using JsonField = std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>;

      //! Map a layout's lite dtype to a numpy-style name ("" if unknown).
      auto get_dtype = [](const lite::Layout& layout) -> std::string {
          static const std::map<LiteDataType, std::string> type_map = {
                  {LiteDataType::LITE_FLOAT, "float32"},
                  {LiteDataType::LITE_HALF, "float16"},
                  {LiteDataType::LITE_INT64, "int64"},
                  {LiteDataType::LITE_INT, "int32"},
                  {LiteDataType::LITE_UINT, "uint32"},
                  {LiteDataType::LITE_INT16, "int16"},
                  {LiteDataType::LITE_UINT16, "uint16"},
                  {LiteDataType::LITE_INT8, "int8"},
                  {LiteDataType::LITE_UINT8, "uint8"}};
          auto it = type_map.find(layout.data_type);
          return it == type_map.end() ? std::string() : it->second;
      };

      //! Shape as {"dim<k>": shapes[k]} fields, emitted from the last axis
      //! down to axis 0 (same order as before).
      auto make_shape = [](const lite::Layout& layout) {
          std::vector<JsonField> shape;
          shape.reserve(layout.ndim);
          for (size_t i = layout.ndim; i > 0; --i) {
              size_t dim = i - 1;
              shape.push_back(
                      {mgb::json::String("dim" + std::to_string(dim)),
                       mgb::json::NumberInt::make(layout.shapes[dim])});
          }
          return shape;
      };

      //! Build the JSON description of the IO tensor called `name`.
      auto make_io_entry = [&](const std::string& name) {
          auto layout = m_network->get_io_tensor(name)->get_layout();
          std::vector<JsonField> entry;
          entry.push_back(
                  {mgb::json::String("shape"),
                   mgb::json::Object::make(make_shape(layout))});
          entry.push_back(
                  {mgb::json::String("dtype"),
                   mgb::json::String::make(get_dtype(layout))});
          entry.push_back({mgb::json::String("name"), mgb::json::String::make(name)});
          return mgb::json::Object::make(entry);
      };

      std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
      for (const auto& name : m_network->get_all_input_name()) {
          inputs->add(make_io_entry(name));
      }

      // Outputs go into their own array (they were previously appended to
      // `inputs`, leaving "outputs" always empty).
      std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();
      for (const auto& name : m_network->get_all_output_name()) {
          outputs->add(make_io_entry(name));
      }

      return mgb::json::Object::make(
              {{"IO",
                mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
  }
  #endif
  101. std::vector<uint8_t> ModelLite::get_model_data() {
  102. std::vector<uint8_t> out_data;
  103. LITE_THROW("unsupported interface: ModelLite::get_model_data() \n");
  104. return out_data;
  105. }