Browse Source

fix(mgb/serialization): fix multiple graph load error

GitOrigin-RevId: 89414b014b
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
129fa70cba
2 changed files with 31 additions and 3 deletions
  1. +2
    -2
      src/serialization/impl/serializer_oss.cpp
  2. +29
    -1
      src/serialization/test/serializer_oss.cpp

+ 2
- 2
src/serialization/impl/serializer_oss.cpp View File

@@ -846,7 +846,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
OprLoadContextImpl ctx{this, m_graph->mgb_version()};
auto result = ctx.load_oprs();

auto fbs_end = tensor_begin + offset_to_fbs + size;
auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
auto cur = m_file->tell();
mgb_assert(fbs_end > cur);
// Skip to Graph end
@@ -872,4 +872,4 @@ bool is_fbs_file(InputFile& file) {
} // namespace serialization
} // namespace mgb

#endif
#endif

+ 29
- 1
src/serialization/test/serializer_oss.cpp View File

@@ -64,6 +64,34 @@ TEST(TestSerializer2, GraphDumpLoad) {
load();
}

TEST(TestSerializer2, MultiGraphDumpLoad) {
auto fname = GET_OUTPUT_FILE();

auto dump = [&]() {
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
auto x = opr::ImmutableTensor::make(*graph, 1926.0817f, {cn});
x.rename("varz");
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
// dump twice
dumper->dump({x});
dumper->dump({x});
};
auto load = [&]() {
GraphLoader::LoadConfig load_config = {};
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
// load twice
loader->load(load_config, false);
loader = GraphLoader::make(loader->reset_file(), loader->format());
loader->load(load_config, false);
};

dump();
load();
}

TEST(TestSerializer2, APlusB) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
@@ -733,4 +761,4 @@ TEST(TestSerializer2, HasOutputDtype) {
load();
}

#endif
#endif

Loading…
Cancel
Save