Browse Source

fix(serialization): when the dump fbsv2 model is used, the middle_tensor becomes optional

GitOrigin-RevId: 3d0bbfd441
release-1.11.1
Megvii Engine Team 2 years ago
parent
commit
dee0228982
1 changed files with 23 additions and 5 deletions
  1. +23
    -5
      src/serialization/impl/serializer_oss_v2.cpp

+ 23
- 5
src/serialization/impl/serializer_oss_v2.cpp View File

@@ -161,8 +161,16 @@ flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor
auto fformat = build_tensor_format(layout.format);
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
} else if (var.node()->shape().ndim > 0) {
auto shape = var.node()->shape();
auto fshape =
m_builder.CreateVectorScalarCast<uint32_t>(shape.shape, shape.ndim);
serialized_middle_tensor =
fbs::v2::CreateMiddleTensor(m_builder, fbname, fshape);

} else {
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
}
serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
return serialized_middle_tensor;
}

@@ -278,8 +286,12 @@ flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
v.reserve(m_cur_opr->output().size());
for (auto out : m_cur_opr->output()) {
if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
auto fbs_out = build_middle_tensor(out);
m_model_middle_tensors.push_back(fbs_out);
if (m_config.keep_var_name >= 1) {
auto fbs_out = build_middle_tensor(out);
m_model_middle_tensors.push_back(fbs_out);
} else {
m_model_middle_tensors.push_back(0);
}
m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
v.emplace_back(m_var2midtensor_id.at(out));
}
@@ -425,13 +437,19 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
}
}
auto fbs_output_alias = m_builder.CreateVector(output_vars_alias);
auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
flatbuffers::Offset<flatbuffers::Vector<
flatbuffers::Offset<mgb::serialization::fbs::v2::MiddleTensor>>>
fb_mid_tensor;
if (m_config.keep_var_name >= 1)
fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

fbs::v2::ModelBuilder model(m_builder);
model.add_mge_version(MGB_VERSION);
model.add_model_version(m_version);
model.add_oprs(fb_oprs);
model.add_middle_tensors(fb_mid_tensor);
if (m_config.keep_var_name >= 1) {
model.add_middle_tensors(fb_mid_tensor);
}
model.add_output_vars_idx(fb_output_vars);
model.add_output_alias(fbs_output_alias);
model.add_nr_shared_tensor(m_nr_shared_tensor);


Loading…
Cancel
Save