diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp index e04257bf..fbab8b2a 100644 --- a/src/serialization/impl/serializer_oss_v2.cpp +++ b/src/serialization/impl/serializer_oss_v2.cpp @@ -161,8 +161,16 @@ flatbuffers::Offset GraphDumperOSSV2::build_middle_tensor auto fformat = build_tensor_format(layout.format); serialized_middle_tensor = fbs::v2::CreateMiddleTensor( m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat); + } else if (var.node()->shape().ndim > 0) { + auto shape = var.node()->shape(); + auto fshape = + m_builder.CreateVectorScalarCast(shape.shape, shape.ndim); + serialized_middle_tensor = + fbs::v2::CreateMiddleTensor(m_builder, fbname, fshape); + + } else { + serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname); } - serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname); return serialized_middle_tensor; } @@ -278,8 +286,12 @@ flatbuffers::Offset GraphDumperOSSV2::build_single_opr( v.reserve(m_cur_opr->output().size()); for (auto out : m_cur_opr->output()) { if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { - auto fbs_out = build_middle_tensor(out); - m_model_middle_tensors.push_back(fbs_out); + if (m_config.keep_var_name >= 1) { + auto fbs_out = build_middle_tensor(out); + m_model_middle_tensors.push_back(fbs_out); + } else { + m_model_middle_tensors.push_back(0); + } m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1; v.emplace_back(m_var2midtensor_id.at(out)); } @@ -425,13 +437,19 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( } } auto fbs_output_alias = m_builder.CreateVector(output_vars_alias); - auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors); + flatbuffers::Offset>> + fb_mid_tensor; + if (m_config.keep_var_name >= 1) + fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors); fbs::v2::ModelBuilder model(m_builder); model.add_mge_version(MGB_VERSION); model.add_model_version(m_version); model.add_oprs(fb_oprs); - model.add_middle_tensors(fb_mid_tensor); + if (m_config.keep_var_name >= 1) { + model.add_middle_tensors(fb_mid_tensor); + } model.add_output_vars_idx(fb_output_vars); model.add_output_alias(fbs_output_alias); model.add_nr_shared_tensor(m_nr_shared_tensor);