
serializer_oss.cpp

/**
 * \file src/serialization/impl/serializer_oss.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
/*
 * Dump file layout:
 * [uint32_t fourcc]
 * [00 00 00 00]
 * [uint64_t offset to graph from tensor start]
 * [Tensor 1]
 * [Tensor 2]
 * [...]
 * [Tensor N]
 * [SizePrefixed FlatBuffers Graph]
 */
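// For example (derived from GraphDumperOSS::dump() below), the first 16
// bytes of a dump written on a little-endian machine are:
//
//   4D 47 42 53               fourcc: uint32_t 0x5342474D, i.e. "MGBS"
//   00 00 00 00               reserved padding
//   xx xx xx xx xx xx xx xx   offset_to_fbs: number of tensor-data bytes
//                             between the end of this field and the
//                             size-prefixed fbs::Graph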
#if MGB_ENABLE_FBS_SERIALIZATION

#include "batched_device_value_loader.h"

#include "megbrain/graph/exc_extra_info.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"

#include <flatbuffers/flatbuffers.h>

#include <algorithm>  // std::sort, std::copy
#include <cerrno>
#include <cinttypes>
#include <cmath>      // std::isfinite
#include <cstdio>
#include <cstring>    // memcpy

using namespace mgb;
using namespace mgb::serialization;

namespace {

// e.g. version 8.10.2 is encoded as (8 * 1000 + 10) * 100 + 2 == 801002
constexpr uint32_t MGB_VERSION = (MGB_MAJOR * 1000 + MGB_MINOR) * 100 + MGB_PATCH;

// "MGBS" when written as a little-endian uint32_t
constexpr uint32_t MGB_MAGIC = 0x5342474D;

template <typename T>
bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
    for (const auto& x : list) {
        if (set.count(x)) {
            return true;
        }
    }
    return false;
}

void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
    bool cond_normal = tensor.layout().format.is_default() &&
                       tensor.layout().is_physical_contiguous();
    bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
                       tensor.layout().format.is_lowbit_aligned() &&
                       tensor.layout().is_contiguous();
    mgb_assert(cond_normal || cond_lowbit,
               "non-contiguous tensor: name=%s layout=%s", name.c_str(),
               tensor.layout().to_string().c_str());
    if (tensor.dtype() == dtype::Float32()) {
        auto ptr = tensor.ptr<float>();
        for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
            if (!std::isfinite(ptr[i])) {
                mgb_log_warn("invalid tensor value in %s: %g", name.c_str(),
                             ptr[i]);
                break;
            }
        }
    }
}

}  // namespace

namespace mgb {
namespace serialization {

class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
    const std::unique_ptr<OutputFile> m_file;
    flatbuffers::FlatBufferBuilder m_builder;
    DumpConfig m_config;
    DumpResult m_cur_rst;
    size_t m_nr_shared_tensor;
    std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistry*>> m_oprs_to_dump;
    ThinHashMap<VarNode*, size_t> m_var2id;
    //! set of output vars specified by user
    ThinHashSet<VarNode*> m_output_vars;
    std::unordered_set<std::string> m_used_input_names, m_used_param_names;
    //! current opr to be dumped
    cg::OperatorNodeBase* m_cur_opr = nullptr;
    // will be filled in dump_tensor
    std::vector<flatbuffers::Offset<fbs::Tensor>> m_cur_opr_tensor;
    std::vector<flatbuffers::Offset<fbs::Blob>> m_blobs;
    std::vector<fbs::OperatorParam> m_cur_opr_param_type;
    std::vector<flatbuffers::Offset<void>> m_cur_opr_param;

    void init_oprs_to_dump(const SymbolVarArray& endpoints);
    flatbuffers::Offset<fbs::Metadata> build_metadata(const Metadata& metadata);
    flatbuffers::Offset<fbs::Operator> build_single_opr(
            cg::OperatorNodeBase* opr, const OprRegistry* registry);
    flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);

public:
    GraphDumperOSS(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
    DumpResult dump(const SymbolVarArray& output_vars, const DumpConfig& config = {},
                    const Metadata& metadata = {}) override;
    const GraphDumpConfig& config() const override { return m_config; }
    void dump_tensor(const std::string& name, const HostTensorND& tensor,
                     TensorWriteMethod method) override;
    flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
    void append_param(uint32_t type, uint32_t value) override {
        static_assert(std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
                      "append_param depends on uoffset_t being uint32_t");
        static_assert(std::is_standard_layout<flatbuffers::Offset<void>>::value,
                      "append_param depends on flatbuffers::Offset having "
                      "standard memory layout");
        mgb_assert(type != fbs::OperatorParam_NONE);
        m_cur_opr_param_type.emplace_back(static_cast<fbs::OperatorParam>(type));
        m_cur_opr_param.emplace_back(value);
    }
    void dump_buf_with_len(const void* data, uint32_t size) override;
    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS; }
};

flatbuffers::Offset<fbs::DType> GraphDumperOSS::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}

void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();
    m_var2id.clear();
    // iterate oprs to init m_var2id
    size_t next_id = 0;
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // copy input ID to output
            auto id = m_var2id.at(opr->input(0));
            for (auto i : opr->output())
                m_var2id[i] = id;
        } else {
            auto registry = OprRegistry::find_by_type(opr->dyn_typeinfo());
            if (!registry || !registry->dumper) {
                mgb_throw(cg::OperatorNodeExcExtraInfo::ExcMaker{opr}
                                  .make<MegBrainError>,
                          "serialization as FlatBuffers is not supported for "
                          "operator %s",
                          opr->dyn_typeinfo()->name);
            }
            m_oprs_to_dump.emplace_back(opr, registry);
            for (auto i : opr->output()) {
                if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    m_var2id[i] = next_id++;
                }
            }
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}

flatbuffers::Offset<fbs::Metadata> GraphDumperOSS::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_user_info(user_info);
    builder.add_optimize_options(metadata.optimize_options);
    return builder.Finish();
}

flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistry* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;
    Offset<Vector<Offset<fbs::CompNode>>> comp_node;
    auto& config = opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::CreateCompNode(
                    m_builder,
                    m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }

    Offset<Vector<uint32_t>> inputs;
    if (opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(opr->input().size());
        for (auto inp : opr->input()) {
            v.emplace_back(m_var2id.at(inp));
        }
        inputs = m_builder.CreateVector(v);
    }

    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(opr->name());
    }

    Offset<Vector<Offset<String>>> output_names;
    if (m_config.keep_var_name >= 2 ||
        (m_config.keep_var_name == 1 &&
         contains_any_in_set(opr->output(), m_output_vars))) {
        std::vector<std::string> onames;
        for (auto i : opr->output()) {
            if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                onames.emplace_back(i->name());
            }
        }
        output_names = m_builder.CreateVectorOfStrings(onames);
    }

    auto output_dtype = build_dtype(config.output_dtype());

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *opr);

    Offset<Vector<Offset<fbs::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);

    Offset<Vector<Offset<fbs::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);

    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }

    fbs::OperatorBuilder builder(m_builder);
    builder.add_type_id(registry->unversioned_type_id);
    builder.add_inputs(inputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_output_name(output_names);
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_blobs(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}

GraphDumper::DumpResult GraphDumperOSS::dump(const SymbolVarArray& output_vars,
                                             const DumpConfig& config,
                                             const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError,
                 "Can't dump empty graph");

    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : output_vars) {
        mgb_assert(!i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                   "can not dump var with VOLATILE_CONTENT flag: %s",
                   cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s",
                       i.node()->cname());
        }
    }

    // write magic
    uint32_t magic = MGB_MAGIC;
    m_file->write(&magic, sizeof(magic));

    // padding
    uint32_t reserved = 0;
    m_file->write(&reserved, sizeof(reserved));

    // write placeholder for offset_to_fbs
    auto offset_pos = m_file->tell();
    uint64_t offset_to_fbs = 0;
    m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));

    // dump metadata
    auto fbmeta = build_metadata(metadata);

    // dump operators
    init_oprs_to_dump(output_vars);
    std::vector<flatbuffers::Offset<fbs::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // dump output vars
    std::vector<fbs::OutputVar> output_vars_idx;
    output_vars_idx.reserve(output_vars.size());
    for (auto i : output_vars) {
        output_vars_idx.emplace_back(m_var2id.at(i.node()), i.node()->id());
    }
    auto fb_output_vars = m_builder.CreateVectorOfStructs(output_vars_idx);

    XXHash content_hash;
    content_hash.update(m_builder.GetCurrentBufferPointer(), m_builder.GetSize());
    auto graph_hash = content_hash.digest();

    fbs::GraphBuilder graph(m_builder);
    graph.add_mgb_version(MGB_VERSION);
    graph.add_hash(graph_hash);
    graph.add_oprs(fb_oprs);
    graph.add_output_vars_idx(fb_output_vars);
    graph.add_nr_shared_tensor(m_nr_shared_tensor);
    graph.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(graph.Finish(), fbs::GraphIdentifier());

    // write the actual offset_to_fbs
    auto cur = m_file->tell();
    mgb_assert(cur >= offset_pos && cur - offset_pos >= sizeof(offset_to_fbs));
    offset_to_fbs = cur - offset_pos - sizeof(offset_to_fbs);
    m_file->seek(offset_pos);
    m_file->write(&offset_to_fbs, sizeof(offset_to_fbs));
    m_file->seek(cur);

    // write the serialized fbs::Graph
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < output_vars.size(); i++) {
        ret.outputs.emplace_back(keep_output_var_name
                                         ? output_vars[i].node()->cname()
                                         : ssprintf("unnamed%zu", i));
    }
    ret.content_hash = graph_hash;
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}

void GraphDumperOSS::dump_tensor(const std::string& name,
                                 const HostTensorND& tensor,
                                 TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert((method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
               "name must be non-empty for non Meth::VALUE_ANONYMOUS tensors");

    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(m_used_param_names.insert(name).second,
                           "duplicated VALUE_SHARED tensor name: %s",
                           name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(m_used_input_names.insert(name).second,
                       "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }

    size_t value_size = 0;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto begin = m_file->tell();
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            dumper(*m_file, *m_cur_opr, tensor);
        } else {
            m_file->write(tensor.raw_ptr(), tensor.layout().span().high_byte);
        }
        value_size = m_file->tell() - begin;
        m_cur_rst.tensor_value_bytes += value_size;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto shape = m_builder.CreateVectorScalarCast<uint32_t>(tensor.shape().shape,
                                                            tensor.shape().ndim);
    auto comp_node = fbs::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto dtype = build_dtype(tensor.dtype());
    auto serialized_tensor = fbs::CreateTensor(m_builder, fbname, shape,
                                               comp_node, dtype, value_size);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}

void GraphDumperOSS::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::CreateBlob(
            m_builder,
            m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}

// ----------------------------- Loader --------------------------------------

class GraphLoaderOSS final : public GraphLoader {
    const LoadConfig* m_cur_load_config = nullptr;
    std::unique_ptr<InputFile> m_file;
    SharedBuffer m_graph_buf{{}, 0};
    const fbs::Graph* m_graph;
    SharedTensorIDMap m_shared_tensor_map;
    uint32_t m_mgb_version = 0;
    uint64_t m_graph_hash = 0;

    class OprLoadContextImpl;
    friend class OprLoadContextImpl;

    void verify();

public:
    GraphLoaderOSS(std::unique_ptr<InputFile> input_file)
            : m_file{std::move(input_file)} {}

    std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
        file.swap(m_file);
        return file;
    }

    LoadResult load(const LoadConfig& config, bool rewind) override;

    const SharedTensorIDMap& shared_tensor_id_map() const override {
        mgb_assert(m_graph_hash, "graph not loaded yet");
        return m_shared_tensor_map;
    }

    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS; }
};

class GraphLoaderOSS::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
    GraphLoaderOSS* const m_loader;
    size_t m_cur_shared_tensor_idx = 0;
    std::shared_ptr<ComputingGraph> m_graph;
    LoadResult::TensorMap m_tensor_map;
    VarNodeArray m_id2varnode;
    BatchedDeviceValueLoader m_device_value_loader;
    const fbs::Operator* m_current_opr;
    size_t m_cur_opr_tensor_cnt;
    size_t m_cur_opr_blob_cnt;
    size_t m_cur_opr_param_cnt;

    ComputingGraph& graph() override { return *m_graph; }

    const GraphLoadConfig& config() const override {
        return *m_loader->m_cur_load_config;
    }

    void load_tensor_value(HostTensorND* dest, const TensorLayout& layout,
                           const fbs::Tensor* tensor);

    std::shared_ptr<HostTensorND> load_tensor() override;

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;

    void load_single_opr(const fbs::Operator* opr);

public:
    OprLoadContextImpl(GraphLoaderOSS* loader, uint32_t version)
            : OprLoadContextFlatBuffers(version), m_loader{loader} {
        m_graph = loader->m_cur_load_config->comp_graph;
        if (!m_graph) {
            m_graph = ComputingGraph::make();
        }
        auto maker = [this]() {
            return std::shared_ptr<OprLoadContext>{
                    std::shared_ptr<OprLoadContext>{}, this};
        };
        auto got = m_graph->options()
                           .user_data.get_user_data_or_create<OprLoadContext>(maker);
        mgb_assert(got == this);
    }

    ~OprLoadContextImpl() noexcept {
        auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
        mgb_assert(nr == 1);
    }

    Metadata load_metadata();
    LoadResult load_oprs();
    CompNode load_comp_node(const fbs::CompNode* comp_node);

    const void* get_next_param(uint32_t enumv) override {
        auto type = static_cast<fbs::OperatorParam>(enumv);
        if (m_cur_opr_param_cnt == 0) {
            m_cur_opr_param_cnt++;
            if (m_current_opr->param_type() == type) {
                return m_current_opr->param();
            }
        } else {
            mgb_assert(m_current_opr->additional_params() &&
                       m_cur_opr_param_cnt - 1 <
                               m_current_opr->additional_params()->size());
            auto i = m_cur_opr_param_cnt++ - 1;
            if (m_current_opr->additional_params_type()->Get(i) == type) {
                return m_current_opr->additional_params()->Get(i);
            }
        }
        return nullptr;
    }

    std::string load_buf_with_len() override {
        mgb_assert(m_current_opr->blobs() &&
                   m_cur_opr_blob_cnt < m_current_opr->blobs()->size());
        auto blob = m_current_opr->blobs()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto data = blob->data()->data();
        return {reinterpret_cast<const char*>(data), blob->data()->size()};
    }

    SharedBuffer load_shared_buf_with_len() override {
        mgb_assert(m_current_opr->blobs() &&
                   m_cur_opr_blob_cnt < m_current_opr->blobs()->size());
        auto blob = m_current_opr->blobs()->Get(m_cur_opr_blob_cnt++);
        mgb_assert(blob && blob->data());
        auto size = blob->data()->size();
        std::shared_ptr<uint8_t> shptr{new uint8_t[size],
                                       [](uint8_t* p) { delete[] p; }};
        memcpy(shptr.get(), blob->data()->data(), size);
        return {std::move(shptr), size};
    }
};

CompNode GraphLoaderOSS::OprLoadContextImpl::load_comp_node(
        const fbs::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}

TensorLayout load_tensor_layout(const fbs::Tensor* tensor) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    layout.init_contiguous_stride();
    return layout;
}

void GraphLoaderOSS::OprLoadContextImpl::load_tensor_value(
        HostTensorND* dest, const TensorLayout& layout, const fbs::Tensor* tensor) {
    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    auto&& file = m_loader->m_file;
    auto begin_pos = file->tell();
    file->skip(tensor->offset());
    if (loader) {
        // call custom loader
        void* dest_ptr = nullptr;
        if (dest) {
            dest->dtype(layout.dtype).resize(layout);
            dest_ptr = dest->raw_ptr();
        }
        loader(dest_ptr, layout, *file);
    } else {
        if (dest) {
            file->read_into_tensor(*dest, layout);
        } else {
            file->skip(layout.span().high_byte);
        }
    }
    mgb_throw_if(file->tell() < begin_pos, SerializationError,
                 "custom tensor value loader accessed out-of-range data before "
                 "the start of the data blob");
    auto data_size = tensor->data_size();
    auto consumed_size = file->tell() - begin_pos;
    mgb_throw_if(consumed_size > data_size, SerializationError,
                 "custom tensor value loader consumed more data than "
                 "available: consumed %lu, has %u",
                 consumed_size, data_size);
    if (consumed_size < data_size) {
        mgb_log_warn(
                "tensor value loader consumed less data than available: "
                "consumed %lu bytes, has %u bytes",
                consumed_size, data_size);
        file->skip(data_size - consumed_size);
    }
}

std::shared_ptr<HostTensorND> GraphLoaderOSS::OprLoadContextImpl::load_tensor() {
    mgb_assert(m_current_opr->tensors() &&
               m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);
    if (tensor->data_size()) {
        load_tensor_value(ret.get(), layout, tensor);
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        mod(tensor->name() ? tensor->name()->str() : "",
            tensor->data_size() != 0, *ret);
    }
    return ret;
}

std::shared_ptr<DeviceTensorND>
GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
    mgb_assert(m_current_opr->tensors() &&
               m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor);
    mgb_assert(tensor->data_size());
    auto&& sh_reg = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& sh_ptr_ref = sh_reg.second[comp_node.mem_node()];
    if (sh_ptr_ref) {
        // cached tensor value is valid, so we can reuse it
        load_tensor_value(nullptr, layout, tensor);
        if (sh_ptr_ref->comp_node() == comp_node)
            return sh_ptr_ref;
        // same mem node but different comp node: change comp node and share
        // the value
        auto ret = std::make_shared<DeviceTensorND>(*sh_ptr_ref);
        ret->comp_node(comp_node);
        return ret;
    }
    if (tensor->name()) {
        sh_reg.first = tensor->name()->str();
    }
    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        load_tensor_value(&hv, layout, tensor);
        sh_ptr_ref = std::make_shared<DeviceTensorND>();
        *sh_ptr_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        load_tensor_value(&hv, layout, tensor);
        sh_ptr_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return sh_ptr_ref;
}

Metadata GraphLoaderOSS::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_graph->metadata();
    Metadata ret;
    ret.is_valid = fbmeta->is_valid();
    ret.graph_modified = fbmeta->graph_modified();
    if (fbmeta->user_info()) {
        ret.user_info = fbmeta->user_info()->str();
        ret.has_user_info = true;
    }
    if (fbmeta->optimize_options()) {
        ret.optimize_options = fbmeta->optimize_options();
        ret.optimized_for_inference = true;
    }
    return ret;
}

void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
        const fbs::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }

    auto registry = OprRegistry::find_by_unversioned_id(fbopr->type_id());
    mgb_throw_if(!registry, SerializationError,
                 "failed to find opr with type %s; use "
                 "config.dump_registered_oprs() in the Python env to get a "
                 "dict that maps from opr id to opr name",
                 std::to_string(fbopr->type_id()).c_str());

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto opr = registry->loader(*this, inputs, config);

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(opr && (opr->dyn_typeinfo() == registry->type ||
                       !registry->type ||
                       opr->same_type<opr::ImmutableTensor>()),
               "got_type=%s expected_type=%s",
               opr ? opr->dyn_typeinfo()->name : nullptr, registry->type->name);

    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : opr->output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->output_name()) {
                ovar->name(fbopr->output_name()->Get(i++)->str());
            }
        }
    }

    opr->node_prop().attribute().priority = fbopr->priority();
}

GraphLoader::LoadResult GraphLoaderOSS::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_graph->oprs();
    {
        // inplace arith graph optimization is disabled during opr load,
        // so the loader restores the same graph as was dumped;
        // see the test TestSerializer2.LOGEXP for an example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading of device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_graph->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}

GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
                                             bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }

    uint32_t magic;
    m_file->read(&magic, sizeof(magic));
    mgb_throw_if(magic != MGB_MAGIC, SerializationError,
                 "wrong magic: wanted %#08x, actual %#08x (not a valid fbs "
                 "model?)",
                 MGB_MAGIC, magic);
    m_file->skip(4);

    uint64_t offset_to_fbs;
    m_file->read(&offset_to_fbs, sizeof(offset_to_fbs));
    auto tensor_begin = m_file->tell();

    // skip tensor data
    m_file->skip(offset_to_fbs);

    // read fbs::Graph
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_graph_buf = m_file->read_shared(size);

    // rewind back to the tensor data
    m_file->rewind();
    m_file->skip(tensor_begin);

    mgb_throw_if(!fbs::GraphBufferHasIdentifier(m_graph_buf.data()),
                 SerializationError, "invalid fbs model");

    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_graph_buf.data()),
                m_graph_buf.size());
        mgb_throw_if(!fbs::VerifyGraphBuffer(verifier), SerializationError,
                     "model verification failed (invalid or corrupted model?)");
    }

    m_graph = fbs::GetGraph(m_graph_buf.data());
    m_mgb_version = m_graph->mgb_version();
    if (m_graph->mgb_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from a future runtime: version=%u "
                "model_version=%u",
                MGB_VERSION, m_graph->mgb_version());
    }

    if (!m_graph_hash) {
        m_graph_hash = m_graph->hash();
        mgb_assert(m_graph_hash,
                   "invalid graph hash; maybe an error occurred during "
                   "graph dump");
    } else {
        mgb_assert(m_graph_hash == m_graph->hash(),
                   "a GraphLoader instance can be used to load only one graph, "
                   "since the tensor values are shared; previous graph hash "
                   "is 0x%llx, current graph hash is 0x%llx",
                   static_cast<unsigned long long>(m_graph_hash),
                   static_cast<unsigned long long>(m_graph->hash()));
    }

    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_graph->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_graph->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, m_graph->mgb_version()};
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;

    // skip to the end of the fbs::Graph
    auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
    auto cur = m_file->tell();
    mgb_assert(fbs_end > cur);
    m_file->skip(fbs_end - cur);
    return result;
}

std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file) {
    return std::make_unique<GraphDumperOSS>(std::move(file));
}

std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSS>(std::move(file));
}

bool is_fbs_file(InputFile& file) {
    // the first 8 bytes of a dump are the magic fourcc followed by four
    // reserved zero bytes, so they compare equal to MGB_MAGIC as a uint64_t
    uint64_t magic_with_reserved = 0;
    file.read(&magic_with_reserved, sizeof(magic_with_reserved));
    file.skip(-sizeof(magic_with_reserved));
    return magic_with_reserved == MGB_MAGIC;
}

}  // namespace serialization
}  // namespace mgb

#endif  // MGB_ENABLE_FBS_SERIALIZATION
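
For context, the sketch below shows one way to exercise this serializer through the public `GraphDumper::make` / `GraphLoader::make` entry points with `GraphDumpFormat::FLATBUFFERS`. It is a minimal, illustrative sketch rather than official usage, assuming the `OutputFile::make_fs` / `InputFile::make_fs` helpers from `megbrain/serialization/file.h`; the helper `round_trip_example` and the file name `model.mge` are hypothetical.

#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/serializer.h"

#include <cstdio>

using namespace mgb;
using namespace mgb::serialization;

// hypothetical helper: dump a tiny graph and load it back
void round_trip_example() {
    // build y = x + x on cpu0
    auto graph = ComputingGraph::make();
    auto host_x = std::make_shared<HostTensorND>(
            CompNode::load("cpu0"), TensorShape{4}, dtype::Float32());
    for (size_t i = 0; i < 4; ++i)
        host_x->ptr<float>()[i] = float(i);  // fill so all values are finite
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto y = opr::add(x, x);

    // dump: writes the magic, padding, offset, tensor values and the
    // size-prefixed FlatBuffers graph in the layout documented at the top
    auto dumper = GraphDumper::make(OutputFile::make_fs("model.mge"),
                                    GraphDumpFormat::FLATBUFFERS);
    dumper->dump({y});

    // load it back; output_var_list is filled by load_oprs() above
    auto loader = GraphLoader::make(InputFile::make_fs("model.mge"),
                                    GraphDumpFormat::FLATBUFFERS);
    auto rst = loader->load();
    printf("loaded %zu output var(s)\n", rst.output_var_list.size());
}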

The MegEngine installation packages bundle the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.