|
|
@@ -400,6 +400,18 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( |
|
|
|
output_vars_idx.push_back(foutput_vars_idx); |
|
|
|
} |
|
|
|
auto fb_output_vars = m_builder.CreateVector(output_vars_idx); |
|
|
|
std::vector<flatbuffers::Offset<fbs::v2::OutputAlias>> output_vars_alias; |
|
|
|
if (m_config.alias_name_map.size() > 0) { |
|
|
|
for (auto&& pair : m_config.alias_name_map) { |
|
|
|
std::string name; |
|
|
|
SymbolVar var; |
|
|
|
std::tie(name, var) = pair; |
|
|
|
auto fbs_name = m_builder.CreateSharedString(name); |
|
|
|
output_vars_alias.push_back( |
|
|
|
fbs::v2::CreateOutputAlias(m_builder, var.node()->id(), fbs_name)); |
|
|
|
} |
|
|
|
} |
|
|
|
auto fbs_output_alias = m_builder.CreateVector(output_vars_alias); |
|
|
|
auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors); |
|
|
|
|
|
|
|
fbs::v2::ModelBuilder model(m_builder); |
|
|
@@ -407,6 +419,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( |
|
|
|
model.add_oprs(fb_oprs); |
|
|
|
model.add_middle_tensors(fb_mid_tensor); |
|
|
|
model.add_output_vars_idx(fb_output_vars); |
|
|
|
model.add_output_alias(fbs_output_alias); |
|
|
|
model.add_nr_shared_tensor(m_nr_shared_tensor); |
|
|
|
model.add_metadata(fbmeta); |
|
|
|
m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier()); |
|
|
@@ -469,7 +482,7 @@ void GraphDumperOSSV2::dump_tensor( |
|
|
|
if (dumper) { |
|
|
|
mgb_log_warn( |
|
|
|
"serialization v2 format is pure flatbuffer format, not support " |
|
|
|
"user tensor value dumper"); |
|
|
|
"user tensor value dumper callback."); |
|
|
|
} |
|
|
|
data = m_builder.CreateVector( |
|
|
|
reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte); |
|
|
@@ -568,7 +581,7 @@ std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor( |
|
|
|
if (loader) { |
|
|
|
mgb_log_warn( |
|
|
|
"serialization v2 format is pure flatbuffer format, not support " |
|
|
|
"user tensor value loader"); |
|
|
|
"user tensor value loader callback."); |
|
|
|
} |
|
|
|
memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size()); |
|
|
|
} |
|
|
@@ -677,15 +690,14 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr( |
|
|
|
//! opr version must be exist |
|
|
|
uint8_t opr_version = fbopr->opr_version(); |
|
|
|
auto type_id = fbopr->type_id(); |
|
|
|
auto opr_type = fbopr->type()->str(); |
|
|
|
const OprRegistryV2* registry = |
|
|
|
OprRegistryV2::versioned_find_by_id(type_id, opr_version); |
|
|
|
mgb_throw_if( |
|
|
|
!registry, SerializationError, |
|
|
|
"failed to find opr with type %s id is %zu, use python env " |
|
|
|
"failed to find opr with type %s , use python env " |
|
|
|
"config.dump_registered_oprs() to get a dict that maps from " |
|
|
|
"opr id to opr name", |
|
|
|
fbopr->type()->str().c_str(), type_id); |
|
|
|
fbopr->type()->str().c_str()); |
|
|
|
|
|
|
|
// load inputs |
|
|
|
VarNodeArray inputs; |
|
|
@@ -817,6 +829,17 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re |
|
|
|
auto metadata = ctx.load_metadata(); |
|
|
|
auto result = ctx.load_oprs(); |
|
|
|
result.metadata = metadata; |
|
|
|
if (m_model->output_alias() && m_model->output_alias()->size() > 0) { |
|
|
|
auto nr_alias = m_model->output_alias()->size(); |
|
|
|
result.output_var_list.resize(nr_alias); |
|
|
|
for (size_t i = 0; i < nr_alias; i++) { |
|
|
|
auto output_alias = m_model->output_alias()->Get(i); |
|
|
|
std::string name = output_alias->name()->str(); |
|
|
|
size_t id = output_alias->id(); |
|
|
|
result.output_var_map[name] = result.output_var_map_id[id]; |
|
|
|
result.output_var_list[i] = result.output_var_map_id[id]; |
|
|
|
} |
|
|
|
} |
|
|
|
m_model_loaded = true; |
|
|
|
result.graph_compile_ahead(); |
|
|
|
return result; |
|
|
|