serializer_oss_v2.cpp

#if MGB_ENABLE_FBS_SERIALIZATION

#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
#include "megdnn/tensor_format.h"

#include "serializer_oss_common.h"

#include "megbrain/gopt/framework.h"

namespace mgb {
namespace serialization {

/*!
 * \brief replace oprs that have a replace_opr method in OprLoadDumpImplV2
 */
class PassConvertToCompatible : public gopt::Pass {
    ThinHashMap<
            Typeinfo*, thin_function<cg::OperatorNodeBase*(
                               cg::OperatorNodeBase*, const VarNodeArray&)>>
            m_opr_replace_func;
    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
            gopt::VarReplaceCheckFlag::CHECK_ALL;

public:
    const char* name() const override { return "PassConvertToCompatible"; }

    PassConvertToCompatible& set_var_replace_check_flag(
            gopt::VarReplaceCheckFlag flag) {
        m_var_replace_check_flag = flag;
        return *this;
    }

    void apply(gopt::OptState& state) const override {
        state.set_var_replace_check_flag(m_var_replace_check_flag);
        auto rewriter = state.graph().make_rewriter();

        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
            if (it != m_opr_replace_func.end()) {
                VarNodeArray new_inp;
                new_inp.reserve(opr->input().size());
                for (auto i : opr->input()) {
                    new_inp.push_back(rewriter.get_var(i));
                }
                auto new_opr = (it->second)(opr, new_inp);

                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
                     i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            } else {
                rewriter.auto_replace_outputs(opr);
            }
        };
        state.graph().iter(on_opr);
        rewriter.apply_inplace();
    }

    static std::unique_ptr<PassConvertToCompatible> make(
            const SymbolVarArray& output_vars) {
        auto ret = std::make_unique<PassConvertToCompatible>();
        // iterate oprs to init
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (!GraphDumper::should_remove_in_dump(opr)) {
                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                        opr->dyn_typeinfo(), CURRENT_VERSION);
                mgb_throw_if(
                        !registry,
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}
                                .make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s, typeinfo %p",
                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
                if (registry->converter) {
                    ret->m_opr_replace_func[opr->dyn_typeinfo()] =
                            registry->converter;
                }
            }
        };
        cg::DepOprIter dep_opr_iter{on_opr};
        for (auto i : output_vars) {
            dep_opr_iter.add(i.node()->owner_opr());
        }
        return ret;
    }
};
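
// For reference, a converter registered in OprRegistryV2 is a callable with
// the signature cg::OperatorNodeBase* (cg::OperatorNodeBase*, const
// VarNodeArray&). A minimal sketch of such a converter (OldOpr/NewOpr are
// hypothetical operator types, not part of this file):
//
//     registry->converter = [](cg::OperatorNodeBase* opr,
//                              const VarNodeArray& inputs) {
//         auto&& param = opr->cast_final_safe<OldOpr>().param();
//         // rebuild the opr on the rewritten inputs so the rewriter can
//         // splice it into the graph in place of the old one
//         return NewOpr::make(inputs[0], param).node()->owner_opr();
//     };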

namespace {
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
        case Type::IMAGE2D_PACK4:
            return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
        case Type::LOWBITS_ALIGNED_TO_BYTE:
            return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
        default:
            mgb_throw(
                    SerializationError,
                    "invalid tensor format type in serialization.");
    }
}
}  // namespace

flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}

flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
        case Type::IMAGE2D_PACK4:
            return fbs::v2::CreateImage2DPackedTensorFormat(
                           m_builder,
                           format.as_impl<megdnn::Image2DPack4TensorFormat>()
                                   .align_axis())
                    .Union();
        case Type::LOWBITS_ALIGNED_TO_BYTE: {
            auto size_nbits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .size_nbits();
            auto align_size_in_bits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .align_size_in_bits();
            return fbs::v2::CreateLowbitsAlignedTensorFormat(
                           m_builder, size_nbits, align_size_in_bits)
                    .Union();
        }
        default:
            mgb_throw(
                    SerializationError,
                    "invalid tensor format type in serialization.");
    }
}

flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
        const SymbolVar var) {
    mgb_assert(var.node());
    auto fbname = m_builder.CreateSharedString(var.node()->name());
    flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
    if (var.node()->dev_tensor_valid()) {
        auto layout = var.node()->layout();
        auto fshape =
                m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
        auto fcomp_node = fbs::v2::CreateCompNode(
                m_builder, m_builder.CreateSharedString(
                                   var.node()->comp_node().to_string_logical()));
        auto fdtype = build_dtype(layout.dtype);
        auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
        auto fformat = build_tensor_format(layout.format);
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
                m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
    } else {
        // no valid device tensor: record only the name, so the layout info
        // built above is not overwritten when it is available
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
    }
    return serialized_middle_tensor;
}

flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
        const SymbolVar var) {
    auto out_node = var.node();
    if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
        mgb_assert(
                m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
        out_node = m_var_remove_in_dump[var.node()];
    }
    return fbs::v2::CreateOutputVar(
            m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
}

void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // copy the input ID to the outputs
            for (auto i : opr->output()) {
                if (m_var_remove_in_dump.find(opr->input(0)) !=
                    m_var_remove_in_dump.end()) {
                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
                } else {
                    m_var_remove_in_dump[i] = opr->input(0);
                }
            }
        } else {
            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                    opr->dyn_typeinfo(), m_version);
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}
                                .make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            mgb_assert(
                    registry->version <= m_version,
                    "the operator version should not exceed the model version");
            m_oprs_to_dump.emplace_back(opr, registry);
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}

flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::v2::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_optimize_options(metadata.optimize_options);
    builder.add_user_info(user_info);
    return builder.Finish();
}

flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;
    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
                v.emplace_back(m_var2midtensor_id.at(inp));
            } else {
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "broken var dependency when dumping the model.");
                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
            }
        }
        inputs = m_builder.CreateVector(v);
    }

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);

    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::v2::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }

    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }

    auto output_dtype = build_dtype(config.output_dtype());

    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->output().size());
        for (auto out : m_cur_opr->output()) {
            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto fbs_out = build_middle_tensor(out);
                m_model_middle_tensors.push_back(fbs_out);
                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
                v.emplace_back(m_var2midtensor_id.at(out));
            }
        }
        outputs = m_builder.CreateVector(v);
    }

    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);

    //! the blobs are used as custom data;
    //! m_blobs is filled by the operator dumper function
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);

    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }

    auto opr_type = m_builder.CreateSharedString(registry->name);

    fbs::v2::OperatorBuilder builder(m_builder);
    builder.add_type(opr_type);
    builder.add_type_id(registry->type_id);
    builder.add_inputs(inputs);
    builder.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_opr_version(registry->get_version());
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_custom_data(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}

SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
        const SymbolVarArray& output_vars) {
    gopt::GraphOptimizer optimizer;
    VarNodeArray rets_var;
    for (auto& symbolvar : output_vars) {
        rets_var.push_back(symbolvar.node());
    }
    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
    optimizer.apply_inplace(rets_var);

    SymbolVarArray dst_vars;
    for (auto& var : rets_var) {
        dst_vars.push_back({var});
    }
    return dst_vars;
}

GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "can't dump empty graph");
    auto new_output_vars = output_vars;
    if (!config.no_change_graph) {
        new_output_vars = converter_all_opr_to_compatiable(output_vars);
    }
    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // dump metadata
    auto fbmeta = build_metadata(metadata);

    // dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        record_opr_dumped(i.second->type_id, i.second->name, i.second->version);
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);

    std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias;
    if (m_config.alias_name_map.size() > 0) {
        for (auto&& pair : m_config.alias_name_map) {
            std::string name;
            SymbolVar var;
            std::tie(name, var) = pair;
            auto fbs_name = m_builder.CreateSharedString(name);
            output_vars_alias.push_back(
                    fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name));
        }
    }
    auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);

    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_model_version(m_version);
    model.add_oprs(fb_oprs);
    model.add_middle_tensors(fb_mid_tensor);
    model.add_output_vars_idx(fb_output_vars);
    model.add_output_alias(fbs_output_alias);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());

    // write the serialized fbs::v2::Model
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
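
// Illustrative dump-side usage, a minimal sketch only: it assumes
// OutputFile::make_fs() from the serialization file API, and "model.mge" /
// output_var are placeholders.
//
//     auto file = OutputFile::make_fs("model.mge");
//     auto dumper = make_fbs_v2_dumper(std::move(file), CURRENT_VERSION);
//     GraphDumper::DumpConfig dump_config;
//     Metadata metadata;
//     auto rst = dumper->dump({output_var}, dump_config, metadata);
//     // rst.outputs / rst.tot_bytes describe what was written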

void GraphDumperOSSV2::dump_tensor(
        const std::string& name, const HostTensorND& tensor,
        TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty for non Meth::VALUE_ANONYMOUS tensors");

    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }

    auto& layout = tensor.layout();
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does "
                    "not support a user tensor value dumper callback.");
        }
        data = m_builder.CreateVector(
                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto fdtype = build_dtype(layout.dtype);
    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
    auto fformat = build_tensor_format(layout.format);
    auto serialized_tensor = fbs::v2::CreateTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}

void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::v2::CreateBlob(
            m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}
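
// note: blobs written here end up in the operator's custom_data field; the
// matching read on the loader side is expected to be the OprLoadContext
// buffer-reading counterpart (load_buf_with_len), which lives outside this
// file.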

// ----------------------------- Loader --------------------------------------

CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
        const fbs::v2::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}

TensorFormat load_tensor_format(
        const fbs::v2::TensorFormat fformat_type, const void* fformat,
        const CompNode& comp_node) {
    switch (fformat_type) {
        case fbs::v2::TensorFormat_DefaultTensorFormat:
            return megdnn::DefaultTensorFormat::make();
        case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
            auto image_format =
                    static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
            auto handle =
                    MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
            return megdnn::Image2DPack4TensorFormat::make(
                    image_format->align_axis(), handle);
        }
        case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
            auto lowbit_format =
                    static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
            return megdnn::LowbitsAlignedToBytesTensorFormat::make(
                    lowbit_format->size_nbits());
        }
        default:
            mgb_throw(
                    SerializationError,
                    "invalid tensor format type in serialization.");
    }
}

TensorLayout load_tensor_layout(
        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    if (tensor->format() && tensor->format_type()) {
        layout.format =
                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
    }
    layout.init_contiguous_stride();
    return layout;
}
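
// note: strides are recomputed by init_contiguous_stride() above, so the v2
// format effectively represents contiguous tensor values only.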

//! the opr loader should ensure that the tensors exist and that their number
//! is correct; here we just assert it.
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);

    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    if (tensor->data() && tensor->data()->size() > 0) {
        if (loader) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does "
                    "not support a user tensor value loader callback.");
        }
        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        bool has_value = false;
        if (tensor && tensor->data()) {
            has_value = tensor->data()->size() != 0;
        }
        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
    }
    return ret;
}

std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
        load_tensor_shared() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    mgb_assert(tensor->data());
    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
    if (shared_tensor_ref) {
        if (shared_tensor_ref->comp_node() == comp_node)
            return shared_tensor_ref;
        // same mem node but different comp node: change comp node and share
        // the value
        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
        ret->comp_node(comp_node);
        return ret;
    }
    if (tensor->name()) {
        shared_pair.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = std::make_shared<DeviceTensorND>();
        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return shared_tensor_ref;
}

Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_model->metadata();
    Metadata ret;
    if (fbmeta) {
        ret.is_valid = fbmeta->is_valid();
        ret.graph_modified = fbmeta->graph_modified();
        if (fbmeta->user_info()) {
            ret.user_info = fbmeta->user_info()->str();
            ret.has_user_info = true;
        }
        if (fbmeta->optimize_options()) {
            ret.optimize_options = fbmeta->optimize_options();
            ret.optimized_for_inference = true;
        }
    }
    return ret;
}

void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }

    //! the opr version must exist
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s and version %d.",
            fbopr->type()->str().c_str(), opr_version);

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s", opr ? opr->dyn_typeinfo()->name : nullptr,
            registry->type->name);

    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->outputs()) {
                auto id = fbopr->outputs()->Get(i);
                mgb_assert(
                        m_id2varnode.size() - 1 == id,
                        "id2var is %zu, fbs get id is %d\n", m_id2varnode.size() - 1,
                        id);
                if (m_middle_tensors.size() > id) {
                    auto name = m_middle_tensors[id]->name()->str();
                    ovar->name(name);
                }
            }
            i++;
        }
    }
    opr->node_prop().attribute().priority = fbopr->priority();
}

GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load,
        // so the loaded graph matches the one that was dumped;
        // see test TestSerializer2.LOGEXP for an example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading of device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_model->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}

void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
    auto model = m_loader->m_model;
    if (model->middle_tensors()) {
        for (unsigned int i = 0; i < model->middle_tensors()->size(); i++) {
            m_middle_tensors.push_back(model->middle_tensors()->Get(i));
        }
    }
}

GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    // read the size-prefixed fbs::v2::Model
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_model_buf = m_file->read_shared(size);

    mgb_throw_if(
            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
            "invalid fbs model");
    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_model = fbs::v2::GetModel(m_model_buf.data());
    m_mgb_version = m_model->mge_version();
    m_model_version = m_model->model_version();
    if (m_model->mge_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from a newer runtime: current version=%u, "
                "model dumped by version=%u",
                MGB_VERSION, m_model->mge_version());
    }
    if (m_model_version > CURRENT_VERSION) {
        mgb_log_warn(
                "the model was dumped with the newer format version %d; "
                "loading it with current version %d may fail.",
                m_model_version, CURRENT_VERSION);
    }

    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, m_model->mge_version()};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;

    if (m_model->output_alias() && m_model->output_alias()->size() > 0) {
        auto nr_alias = m_model->output_alias()->size();
        result.output_var_list.resize(nr_alias);
        for (size_t i = 0; i < nr_alias; i++) {
            auto output_alias = m_model->output_alias()->Get(i);
            std::string name = output_alias->name()->str();
            size_t id = output_alias->id();
            result.output_var_map[name] = result.output_var_map_id[id];
            result.output_var_list[i] = result.output_var_map_id[id];
        }
    }

    m_model_loaded = true;
    result.graph_compile_ahead();
    return result;
}
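
// Illustrative load-side usage, a minimal sketch only: it assumes
// InputFile::make_fs() from the serialization file API, and "model.mge" is a
// placeholder.
//
//     auto loader = make_fbs_v2_loader(InputFile::make_fs("model.mge"));
//     GraphLoader::LoadConfig load_config;
//     auto rst = loader->load(load_config, true);
//     // rst.output_var_map maps dumped names to output VarNodes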

std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
        std::unique_ptr<OutputFile> file, int version) {
    return std::make_unique<GraphDumperOSSV2>(std::move(file), version);
}

std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSSV2>(std::move(file));
}

bool is_fbs_v2_file(InputFile& file) {
    constexpr size_t identifier_length = 25;
    char identifier[identifier_length];
    file.read(identifier, identifier_length);
    file.skip(-identifier_length);
    //! skip the size prefix of the file before checking the identifier
    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}
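
// A caller can probe the format before committing to a loader, roughly as
// below (the older-format fallback is an assumption, not defined in this
// file):
//
//     if (is_fbs_v2_file(*file)) {
//         loader = make_fbs_v2_loader(std::move(file));
//     }  // otherwise fall back to an older-format loader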

}  // namespace serialization
}  // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}