
serializer_oss_v2.cpp 40 kB

#if MGB_ENABLE_FBS_SERIALIZATION

#include <map>

#include "megbrain/comp_node_env.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_v2_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/utils/hash_ct.h"
#include "megdnn/tensor_format.h"

#include "serializer_oss_common.h"

#include "megbrain/gopt/framework.h"

namespace mgb {
namespace serialization {
/*!
 * \brief replace every opr that has a replace_opr method registered in
 *        OprLoadDumpImplV2
 */
class PassConvertToCompatible : public gopt::Pass {
    ThinHashMap<
            Typeinfo*, thin_function<cg::OperatorNodeBase*(
                               cg::OperatorNodeBase*, const VarNodeArray&)>>
            m_opr_replace_func;
    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
            gopt::VarReplaceCheckFlag::CHECK_ALL;

public:
    const char* name() const override { return "PassConvertToCompatible"; }

    PassConvertToCompatible& set_var_replace_check_flag(
            gopt::VarReplaceCheckFlag flag) {
        m_var_replace_check_flag = flag;
        return *this;
    }

    void apply(gopt::OptState& state) const override {
        state.set_var_replace_check_flag(m_var_replace_check_flag);
        auto rewriter = state.graph().make_rewriter();
        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
            if (it != m_opr_replace_func.end()) {
                VarNodeArray new_inp;
                new_inp.reserve(opr->input().size());
                for (auto i : opr->input()) {
                    new_inp.push_back(rewriter.get_var(i));
                }
                auto new_opr = (it->second)(opr, new_inp);
                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
                     i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            } else {
                rewriter.auto_replace_outputs(opr);
            }
        };
        state.graph().iter(on_opr);
        rewriter.apply_inplace();
    }

    static std::unique_ptr<PassConvertToCompatible> make(
            const SymbolVarArray& output_vars) {
        auto ret = std::make_unique<PassConvertToCompatible>();
        // iterate oprs to init
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (!GraphDumper::should_remove_in_dump(opr)) {
                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                        opr->dyn_typeinfo(), CURRENT_VERSION);
                mgb_throw_if(
                        !registry,
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s, typeinfo %p",
                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
                if (registry->converter) {
                    ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
                }
            }
        };
        cg::DepOprIter dep_opr_iter{on_opr};
        for (auto i : output_vars) {
            dep_opr_iter.add(i.node()->owner_opr());
        }
        return ret;
    }
};

namespace {
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
        case Type::IMAGE2D_PACK4:
            return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
        case Type::LOWBITS_ALIGNED_TO_BYTE:
            return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}
}  // namespace

flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}

flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
        case Type::IMAGE2D_PACK4:
            return fbs::v2::CreateImage2DPackedTensorFormat(
                           m_builder, format.as_impl<megdnn::Image2DPack4TensorFormat>()
                                              .align_axis())
                    .Union();
        case Type::LOWBITS_ALIGNED_TO_BYTE: {
            auto size_nbits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .size_nbits();
            auto align_size_in_bits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .align_size_in_bits();
            return fbs::v2::CreateLowbitsAlignedTensorFormat(
                           m_builder, size_nbits, align_size_in_bits)
                    .Union();
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}

flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
        const SymbolVar var) {
    mgb_assert(var.node());
    auto fbname = m_builder.CreateSharedString(var.node()->name());
    flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
    if (var.node()->dev_tensor_valid()) {
        auto layout = var.node()->layout();
        auto fshape =
                m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
        auto fcomp_node = fbs::v2::CreateCompNode(
                m_builder, m_builder.CreateSharedString(
                                   var.node()->comp_node().to_string_logical()));
        auto fdtype = build_dtype(layout.dtype);
        auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
        auto fformat = build_tensor_format(layout.format);
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
                m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
    } else if (var.node()->shape().ndim > 0) {
        auto shape = var.node()->shape();
        auto fshape =
                m_builder.CreateVectorScalarCast<uint32_t>(shape.shape, shape.ndim);
        serialized_middle_tensor =
                fbs::v2::CreateMiddleTensor(m_builder, fbname, fshape);
    } else {
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
    }
    return serialized_middle_tensor;
}

flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
        const SymbolVar var) {
    auto out_node = var.node();
    if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
        mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
        out_node = m_var_remove_in_dump[var.node()];
    }
    return fbs::v2::CreateOutputVar(
            m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
}

void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // Copy input ID to output
            for (auto i : opr->output()) {
                if (m_var_remove_in_dump.find(opr->input(0)) !=
                    m_var_remove_in_dump.end()) {
                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
                } else {
                    m_var_remove_in_dump[i] = opr->input(0);
                }
            }
        } else {
            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                    opr->dyn_typeinfo(), m_version);
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            mgb_assert(
                    registry->version <= m_version,
                    "the operator version must not exceed the model version");
            m_oprs_to_dump.emplace_back(opr, registry);
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}

flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::v2::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_optimize_options(metadata.optimize_options);
    builder.add_user_info(user_info);
    return builder.Finish();
}

flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;
    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
                v.emplace_back(m_var2midtensor_id.at(inp));
            } else {
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "invalid var dependency while dumping the model.");
                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
            }
        }
        inputs = m_builder.CreateVector(v);
    }

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);

    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::v2::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }

    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }

    auto output_dtype = build_dtype(config.output_dtype());
    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->output().size());
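        // every non-volatile output takes one slot in m_model_middle_tensors;
        // when var names are not kept, a null offset is stored as a placeholder
        // so the slot indices recorded in m_var2midtensor_id stay stable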
        for (auto out : m_cur_opr->output()) {
            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                if (m_config.keep_var_name >= 1) {
                    auto fbs_out = build_middle_tensor(out);
                    m_model_middle_tensors.push_back(fbs_out);
                } else {
                    m_model_middle_tensors.push_back(0);
                }
                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
                v.emplace_back(m_var2midtensor_id.at(out));
            }
        }
        outputs = m_builder.CreateVector(v);
    }

    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);
    //! the blob data carries the operator's custom data;
    //! m_blobs is filled by the operator's dumper function
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);

    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
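    // the first param goes into the dedicated param/param_type fields below;
    // any extra params emitted by the opr dumper go into the additional_params
    // arrays, hence the data() + 1 offsets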
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }

    auto opr_type = m_builder.CreateSharedString(registry->name);
    fbs::v2::OperatorBuilder builder(m_builder);
    builder.add_type(opr_type);
    builder.add_type_id(registry->type_id);
    builder.add_inputs(inputs);
    builder.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_opr_version(registry->get_version());
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_custom_data(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}
SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
        const SymbolVarArray& output_vars) {
    gopt::GraphOptimizer optimizer;
    VarNodeArray rets_var;
    for (auto& symbolvar : output_vars) {
        rets_var.push_back(symbolvar.node());
    }
    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
    optimizer.apply_inplace(rets_var);

    SymbolVarArray dst_vars;
    for (auto& var : rets_var) {
        dst_vars.push_back({var});
    }
    return dst_vars;
}
GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");

    auto new_output_vars = output_vars;
    if (!config.no_change_graph) {
        new_output_vars = converter_all_opr_to_compatiable(output_vars);
        mgb_assert(output_vars.size() == new_output_vars.size());
        for (size_t id = 0; id < output_vars.size(); id++) {
            auto& new_var = new_output_vars[id];
            new_var.rename(output_vars[id].node()->name());
        }
    }

    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // Dump metadata
    auto fbmeta = build_metadata(metadata);

    // Dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        record_opr_dumped(i.second->type_id, i.second->name, i.second->version);
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // Dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);

    std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias;
    if (m_config.alias_name_map.size() > 0) {
        for (auto&& pair : m_config.alias_name_map) {
            std::string name;
            SymbolVar var;
            std::tie(name, var) = pair;
            auto fbs_name = m_builder.CreateSharedString(name);
            output_vars_alias.push_back(
                    fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name));
        }
    }
    auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);

    flatbuffers::Offset<flatbuffers::Vector<
            flatbuffers::Offset<mgb::serialization::fbs::v2::MiddleTensor>>>
            fb_mid_tensor;
    if (m_config.keep_var_name >= 1)
        fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_model_version(m_version);
    model.add_oprs(fb_oprs);
    if (m_config.keep_var_name >= 1) {
        model.add_middle_tensors(fb_mid_tensor);
    }
    model.add_output_vars_idx(fb_output_vars);
    model.add_output_alias(fbs_output_alias);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
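    // the buffer is size-prefixed: the loader first reads a uint32 length and
    // then reads exactly that many bytes for the model (see GraphLoaderOSSV2::load)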
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());

    // Write serialized fbs::Graph
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}
void GraphDumperOSSV2::dump_tensor(
        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method,
        TensorFormat format) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty iff method is not Meth::VALUE_ANONYMOUS");
    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }
    auto& layout = tensor.layout();
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does not "
                    "support a user tensor value dumper callback.");
        }
        data = m_builder.CreateVector(
                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto fdtype = build_dtype(layout.dtype);
    auto fformat_type = get_flatbuffer_tensor_format_type(format);
    auto fformat = build_tensor_format(format);
    auto serialized_tensor = fbs::v2::CreateTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}

void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::v2::CreateBlob(
            m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}

// ----------------------------- Loader --------------------------------------
/**
 * SharedTensorAlignMent records all shared device tensors. At the beginning
 * the tensors are not aligned; after all shared device tensors have been
 * loaded, the user-provided memory is rewritten and every tensor is moved to
 * an aligned address.
 */
class GraphLoaderOSSV2::SharedTensorAlignMent {
public:
    SharedTensorAlignMent(SharedBuffer buffer, InputFile* file, bool is_enabled)
            : m_enabled(is_enabled), m_file(file), m_model_buffer(buffer) {}

    bool add_device_tensor(std::shared_ptr<DeviceTensorND> tensor) {
        if (!m_enabled)
            return false;
        if (tensor) {
            m_device_tensors[reinterpret_cast<intptr_t>(tensor->raw_ptr())] = tensor;
            return true;
        }
        return false;
    }
    /**
     * Reorder the tensors that share storage with m_model_buffer and copy each
     * one to an aligned address. This modifies the model buffer in place, so
     * the file can not be reused afterwards.
     */
    bool reorder_and_align_tensor() {
        if (!m_enabled)
            return false;
        bool modified = false;
        intptr_t buffer_start = reinterpret_cast<intptr_t>(m_model_buffer.data());
        intptr_t write_end = buffer_start;
        for (auto& iter : m_device_tensors) {
            auto& tensor = iter.second;
            size_t tensor_size = tensor->layout().span().dist_byte();
            size_t alignment = tensor->comp_node().get_mem_addr_alignment();
            intptr_t tensor_start = reinterpret_cast<intptr_t>(tensor->raw_ptr());
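            // round the tensor start address down to the previous alignment
            // boundary (alignment is assumed to be a power of two)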
            intptr_t align_start = static_cast<intptr_t>(
                    reinterpret_cast<uintptr_t>(tensor->raw_ptr()) & ~(alignment - 1));
            if (align_start > write_end) {
                if (tensor_start != align_start) {
                    memmove(reinterpret_cast<void*>(align_start),
                            reinterpret_cast<void*>(tensor_start), tensor_size);
                    modified = true;
                }
                write_end = align_start + tensor_size;
                DeviceTensorStorage storage;
                auto raw_storage = std::shared_ptr<mgb::dt_byte>(
                        reinterpret_cast<mgb::dt_byte*>(align_start), [](void*) {});
                storage.reset(tensor->comp_node(), tensor_size, raw_storage);
                tensor->reset(storage, tensor->layout());
            } else {
                DeviceTensorND new_tensor(tensor->comp_node());
                new_tensor.copy_from(*tensor).sync();
                *tensor = new_tensor;
            }
            if (modified) {
                m_file->have_modified();
            }
        }
        return true;
    }
private:
    bool m_enabled = false;
    InputFile* m_file;
    SharedBuffer m_model_buffer;
    std::map<intptr_t, std::shared_ptr<DeviceTensorND>> m_device_tensors;
};
CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
        const fbs::v2::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}

TensorFormat get_tensor_format(
        const fbs::v2::TensorFormat fformat_type, const void* fformat,
        const CompNode& comp_node) {
    switch (fformat_type) {
        case fbs::v2::TensorFormat_DefaultTensorFormat:
            return megdnn::DefaultTensorFormat::make();
        case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
            auto image_format =
                    static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
            auto handle =
                    MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
            return megdnn::Image2DPack4TensorFormat::make(
                    image_format->align_axis(), handle);
        }
        case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
            auto lowbit_format =
                    static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
            return megdnn::LowbitsAlignedToBytesTensorFormat::make(
                    lowbit_format->size_nbits());
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}
TensorLayout load_tensor_layout_without_format(const fbs::v2::Tensor* tensor) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    layout.init_contiguous_stride();
    return layout;
}

TensorFormat GraphLoaderOSSV2::OprLoadContextImpl::load_tensor_format(size_t id) {
    mgb_assert(m_current_opr->tensors() && id < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(id);
    auto comp_node = load_comp_node(tensor->comp_node());
    TensorFormat format;
    if (tensor->format() && tensor->format_type()) {
        format = get_tensor_format(tensor->format_type(), tensor->format(), comp_node);
    }
    return format;
}
//! the opr loader should ensure that the tensors exist and that their count is
//! correct; here we only assert it
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout_without_format(tensor);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);
    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    if (tensor->data() && tensor->data()->size() > 0) {
        if (loader) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does not "
                    "support a user tensor value loader callback.");
        }
        fill_tensor_memory(
                *ret, tensor->data()->data(), tensor->data()->size(),
                m_loader->m_file->is_shared_memory());
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        bool has_value = false;
        if (tensor && tensor->data()) {
            has_value = tensor->data()->size() != 0;
        }
        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
    }
    return ret;
}
std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
        load_tensor_shared(bool copy_immediatly) {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout_without_format(tensor);
    mgb_assert(tensor->data());
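    // shared tensors are visited in dump order; grow the map on demand so the
    // index below always stays in range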
    if (m_loader->m_shared_tensor_map.size() <= m_cur_shared_tensor_idx) {
        m_loader->m_shared_tensor_map.resize(m_cur_shared_tensor_idx + 5);
    }
    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
    if (shared_tensor_ref) {
        if (shared_tensor_ref->comp_node() == comp_node)
            return shared_tensor_ref;
        // same mem node but different comp node, change comp node and share
        // value
        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
        ret->comp_node(comp_node);
        return ret;
    }

    if (tensor->name()) {
        shared_pair.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node() || copy_immediatly) {
        // directly forward CPU memory
        shared_tensor_ref = std::make_shared<DeviceTensorND>();
        HostTensorND hv{comp_node};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            fill_tensor_memory(
                    hv, tensor->data()->data(), tensor->data()->size(),
                    m_loader->m_file->is_shared_memory());
        }
        if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
            *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
            m_tensor_alignment->add_device_tensor(shared_tensor_ref);
        } else {
            mgb_assert(copy_immediatly);
            shared_tensor_ref->comp_node(comp_node).copy_from(hv).sync();
        }
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            fill_tensor_memory(
                    hv, tensor->data()->data(), tensor->data()->size(),
                    m_loader->m_file->is_shared_memory());
        }
        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return shared_tensor_ref;
}
Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_model->metadata();
    Metadata ret;
    if (fbmeta) {
        ret.is_valid = fbmeta->is_valid();
        ret.graph_modified = fbmeta->graph_modified();
        if (fbmeta->user_info()) {
            ret.user_info = fbmeta->user_info()->str();
            ret.has_user_info = true;
        }
        if (fbmeta->optimize_options()) {
            ret.optimize_options = fbmeta->optimize_options();
            ret.optimized_for_inference = true;
        }
    }
    return ret;
}
void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }
    //! the opr version must be present
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
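    // the registry entry is resolved by (type_id, opr_version): each dumped opr
    // records the version it was serialized with so the matching loader is used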
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s and version %d.",
            fbopr->type()->str().c_str(), opr_version);

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();
    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s",
            opr ? opr->dyn_typeinfo()->name : "(null)",
            registry->type ? registry->type->name : "(null)");
    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->outputs()) {
                auto id = fbopr->outputs()->Get(i);
                mgb_assert(
                        m_id2varnode.size() - 1 == id,
                        "id2var is %zu, fbs id is %u\n", m_id2varnode.size() - 1, id);
                if (m_middle_tensors.size() > id) {
                    auto name = m_middle_tensors[id]->name()->str();
                    ovar->name(name);
                }
            }
            i++;
        }
    }
    opr->node_prop().attribute().priority = fbopr->priority();
}
GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load
        // it tries to restore the same graph as it was dumped
        // see test TestSerializer2.LOGEXP for example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_model->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx <= m_loader->m_shared_tensor_map.size());
    return ret;
}
void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
    auto model = m_loader->m_model;
    if (model->middle_tensors()) {
        for (unsigned int i = 0; i < m_loader->m_model->middle_tensors()->size(); i++) {
            m_middle_tensors.push_back(model->middle_tensors()->Get(i));
        }
    }
}
GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    // Read fbs::Graph
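    // peek the uint32 size prefix, then rewind so the whole size-prefixed
    // buffer (prefix + model) can be read in one shot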
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_file->skip(-sizeof(size));
    m_model_buf = m_file->read_shared(size + sizeof(size));
    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifySizePrefixedModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_model = fbs::v2::GetSizePrefixedModel(m_model_buf.data());
    m_mgb_version = m_model->mge_version();
    m_model_version = m_model->model_version();
    if (m_model->mge_version() > MGB_VERSION) {
        mgb_log_warn(
                "the model was dumped by a newer MegEngine runtime: current "
                "version=%u, model mge_version=%u",
                MGB_VERSION, m_model->mge_version());
    }
    if (m_model_version > CURRENT_VERSION) {
        mgb_log_warn(
                "the model was dumped with a newer format version %d; loading it "
                "with current version %d may fail.",
                m_model_version, CURRENT_VERSION);
    }
    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    }

    SharedTensorAlignMent tensor_alignment(
            m_model_buf, m_file.get(),
            m_file->writable() && m_file->is_shared_memory());
    OprLoadContextImpl ctx{this, &tensor_alignment, m_model->mge_version()};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;

    if (m_model->output_alias() && m_model->output_alias()->size() > 0) {
        auto nr_alias = m_model->output_alias()->size();
        result.output_var_list.resize(nr_alias);
        for (size_t i = 0; i < nr_alias; i++) {
            auto output_alias = m_model->output_alias()->Get(i);
            std::string name = output_alias->name()->str();
            size_t id = output_alias->id();
            result.output_var_map[name] = result.output_var_map_id[id];
            result.output_var_list[i] = result.output_var_map_id[id];
        }
    }

    m_model_loaded = true;
    tensor_alignment.reorder_and_align_tensor();
    result.graph_compile_ahead();
    return result;
}
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(
        std::unique_ptr<OutputFile> file, int version) {
    return std::make_unique<GraphDumperOSSV2>(std::move(file), version);
}

std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSSV2>(std::move(file));
}
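//! Typical round trip (a sketch; OutputFile::make_fs / InputFile::make_fs are
//! assumed to be the usual file helpers):
//!     auto dumper = make_fbs_v2_dumper(
//!             OutputFile::make_fs("model.mge"), CURRENT_VERSION);
//!     dumper->dump({output_var}, {});
//!     auto loader = make_fbs_v2_loader(InputFile::make_fs("model.mge"));
//!     auto rst = loader->load();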
bool is_fbs_v2_file(InputFile& file) {
    constexpr size_t identifier_length = 25;
    char identifier[identifier_length];
    file.read(identifier, identifier_length);
    file.skip(-identifier_length);
    //! skip the uint32 size prefix at the start of the buffer
    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}
}  // namespace serialization
}  // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}