Browse Source

feat(mgb): upgrade flatbuffer

GitOrigin-RevId: 7b1a04934e
release-1.10
Megvii Engine Team 3 years ago
parent
commit
f0088335bb
5 changed files with 14 additions and 20 deletions
  1. +3
    -3
      src/serialization/impl/file.cpp
  2. +4
    -7
      src/serialization/impl/serializer_oss.cpp
  3. +4
    -7
      src/serialization/impl/serializer_oss_v2.cpp
  4. +1
    -1
      src/serialization/include/megbrain/serialization/file.h
  5. +2
    -2
      toolchains/ios.toolchain.cmake

+ 3
- 3
src/serialization/impl/file.cpp View File

@@ -33,7 +33,7 @@ public:


void rewind() override { std::rewind(m_fptr); } void rewind() override { std::rewind(m_fptr); }


void skip(size_t bytes) override {
void skip(int64_t bytes) override {
auto err = fseek(m_fptr, bytes, SEEK_CUR); auto err = fseek(m_fptr, bytes, SEEK_CUR);
mgb_assert(!err); mgb_assert(!err);
} }
@@ -104,7 +104,7 @@ public:


void rewind() override { m_offset = 0; } void rewind() override { m_offset = 0; }


void skip(size_t bytes) override {
void skip(int64_t bytes) override {
m_offset += bytes; m_offset += bytes;
mgb_assert(m_offset <= m_size); mgb_assert(m_offset <= m_size);
} }
@@ -146,7 +146,7 @@ public:
m_offset = 0; m_offset = 0;
} }


void skip(size_t bytes) override {
void skip(int64_t bytes) override {
m_offset += bytes; m_offset += bytes;
mgb_assert(m_offset <= m_size); mgb_assert(m_offset <= m_size);
} }


+ 4
- 7
src/serialization/impl/serializer_oss.cpp View File

@@ -838,25 +838,22 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, bool rewi
// Read fbs::Graph // Read fbs::Graph
uint32_t size; uint32_t size;
m_file->read(&size, sizeof(size)); m_file->read(&size, sizeof(size));
m_graph_buf = m_file->read_shared(size);
m_file->skip(-sizeof(size));
m_graph_buf = m_file->read_shared(size + sizeof(size));


// Rewind back to tensor data // Rewind back to tensor data
m_file->rewind(); m_file->rewind();
m_file->skip(tensor_begin); m_file->skip(tensor_begin);


mgb_throw_if(
!fbs::GraphBufferHasIdentifier(m_graph_buf.data()), SerializationError,
"invalid fbs model");

{ {
flatbuffers::Verifier verifier( flatbuffers::Verifier verifier(
static_cast<const uint8_t*>(m_graph_buf.data()), m_graph_buf.size()); static_cast<const uint8_t*>(m_graph_buf.data()), m_graph_buf.size());
mgb_throw_if( mgb_throw_if(
!fbs::VerifyGraphBuffer(verifier), SerializationError,
!fbs::VerifySizePrefixedGraphBuffer(verifier), SerializationError,
"model verification failed (invalid or corrupted model?)"); "model verification failed (invalid or corrupted model?)");
} }


m_graph = fbs::GetGraph(m_graph_buf.data());
m_graph = fbs::GetSizePrefixedGraph(m_graph_buf.data());
m_mgb_version = m_graph->mgb_version(); m_mgb_version = m_graph->mgb_version();
if (m_graph->mgb_version() > MGB_VERSION) { if (m_graph->mgb_version() > MGB_VERSION) {
mgb_log_warn( mgb_log_warn(


+ 4
- 7
src/serialization/impl/serializer_oss_v2.cpp View File

@@ -801,21 +801,18 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
// Read fbs::Graph // Read fbs::Graph
uint32_t size; uint32_t size;
m_file->read(&size, sizeof(size)); m_file->read(&size, sizeof(size));
m_model_buf = m_file->read_shared(size);

mgb_throw_if(
!fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
"invalid fbs model");
m_file->skip(-sizeof(size));
m_model_buf = m_file->read_shared(size + sizeof(size));


{ {
flatbuffers::Verifier verifier( flatbuffers::Verifier verifier(
static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size()); static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
mgb_throw_if( mgb_throw_if(
!fbs::v2::VerifyModelBuffer(verifier), SerializationError,
!fbs::v2::VerifySizePrefixedModelBuffer(verifier), SerializationError,
"model verification failed (invalid or corrupted model?)"); "model verification failed (invalid or corrupted model?)");
} }


m_model = fbs::v2::GetModel(m_model_buf.data());
m_model = fbs::v2::GetSizePrefixedModel(m_model_buf.data());
m_mgb_version = m_model->mge_version(); m_mgb_version = m_model->mge_version();
m_model_version = m_model->model_version(); m_model_version = m_model->model_version();
if (m_model->mge_version() > MGB_VERSION) { if (m_model->mge_version() > MGB_VERSION) {


+ 1
- 1
src/serialization/include/megbrain/serialization/file.h View File

@@ -33,7 +33,7 @@ public:
virtual void rewind() = 0; virtual void rewind() = 0;


//! skip given number of bytes //! skip given number of bytes
virtual void skip(size_t bytes) = 0;
virtual void skip(int64_t bytes) = 0;


//! read data into buffer //! read data into buffer
virtual void read(void* dst, size_t size) = 0; virtual void read(void* dst, size_t size) = 0;


+ 2
- 2
toolchains/ios.toolchain.cmake View File

@@ -217,10 +217,10 @@ if(NOT DEFINED IOS_DEPLOYMENT_TARGET)
"2.0" "2.0"
CACHE STRING "Minimum iOS version to build for.") CACHE STRING "Minimum iOS version to build for.")
else() else()
# Unless specified, SDK version 10.0 is used by default as minimum target version
# Unless specified, SDK version 11.0 is used by default as minimum target version
# (iOS, tvOS). # (iOS, tvOS).
set(IOS_DEPLOYMENT_TARGET set(IOS_DEPLOYMENT_TARGET
"10.0"
"11.0"
CACHE STRING "Minimum iOS version to build for.") CACHE STRING "Minimum iOS version to build for.")
endif() endif()
message( message(


Loading…
Cancel
Save