#if MGB_ENABLE_FBS_SERIALIZATION

#include "megbrain/comp_node_env.h" |
|
|
|
#include "megbrain/opr/io.h" |
|
|
|
#include "megbrain/serialization/helper.h" |
|
|
|
#include "megbrain/serialization/internal/flatbuffers_helper.h" |
|
|
|
#include "megbrain/serialization/internal/schema_v2_generated.h" |
|
|
|
#include "megbrain/serialization/metadata.h" |
|
|
|
#include "megbrain/serialization/opr_load_dump.h" |
|
|
|
#include "megbrain/serialization/oss_opr_load_dump.h" |
|
|
|
#include "megbrain/utils/hash_ct.h" |
|
|
|
#include "megdnn/tensor_format.h" |
|
|
|
#include "serializer_oss_common.h" |
|
|
|
|
|
|
|
#include "megbrain/gopt/framework.h" |
|
|
|
|
|
|
|
namespace mgb {
namespace serialization {

/*!
 * \brief replace the oprs that have a replace_opr method in OprLoadDumpImplV2
 */
class PassConvertToCompatible : public gopt::Pass {
    ThinHashMap<
            Typeinfo*, thin_function<cg::OperatorNodeBase*(
                               cg::OperatorNodeBase*, const VarNodeArray&)>>
            m_opr_replace_func;
    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
            gopt::VarReplaceCheckFlag::CHECK_ALL;

public:
    const char* name() const override { return "PassConvertToCompatible"; }

    PassConvertToCompatible& set_var_replace_check_flag(
            gopt::VarReplaceCheckFlag flag) {
        m_var_replace_check_flag = flag;
        return *this;
    }

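    //! apply the pass: for each operator with a registered converter, rebuild
    //! it on the rewritten inputs and replace the original output vars; all
    //! other operators are forwarded unchanged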
    void apply(gopt::OptState& state) const override {
        state.set_var_replace_check_flag(m_var_replace_check_flag);
        auto rewriter = state.graph().make_rewriter();

        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
            if (it != m_opr_replace_func.end()) {
                VarNodeArray new_inp;
                new_inp.reserve(opr->input().size());
                for (auto i : opr->input()) {
                    new_inp.push_back(rewriter.get_var(i));
                }
                auto new_opr = (it->second)(opr, new_inp);

                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
                     i++) {
                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
                }
            } else {
                rewriter.auto_replace_outputs(opr);
            }
        };
        state.graph().iter(on_opr);
        rewriter.apply_inplace();
    }

    static std::unique_ptr<PassConvertToCompatible> make(
            const SymbolVarArray& output_vars) {
        auto ret = std::make_unique<PassConvertToCompatible>();
        // iterate oprs to init
        auto on_opr = [&](cg::OperatorNodeBase* opr) {
            if (!GraphDumper::should_remove_in_dump(opr)) {
                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                        opr->dyn_typeinfo(), CURRENT_VERSION);
                mgb_throw_if(
                        !registry,
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s, typeinfo %p",
                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
                if (registry->converter) {
                    ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
                }
            }
        };
        cg::DepOprIter dep_opr_iter{on_opr};
        for (auto i : output_vars) {
            dep_opr_iter.add(i.node()->owner_opr());
        }
        return ret;
    }
};

namespace {
fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
        case Type::IMAGE2D_PACK4:
            return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
        case Type::LOWBITS_ALIGNED_TO_BYTE:
            return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}
}  // namespace

flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
    return fbs::intl::build_dtype(m_builder, dtype);
}

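//! serialize a TensorLayout::Format into the matching FlatBuffers table;
//! format-specific fields (align axis, lowbits sizes) are taken from the
//! concrete format implementation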
flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
        const TensorLayout::Format& format) {
    using Type = megdnn::TensorFormat::Type;
    switch (format.type()) {
        case Type::DEFAULT:
            return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
        case Type::IMAGE2D_PACK4:
            return fbs::v2::CreateImage2DPackedTensorFormat(
                           m_builder, format.as_impl<megdnn::Image2DPack4TensorFormat>()
                                              .align_axis())
                    .Union();
        case Type::LOWBITS_ALIGNED_TO_BYTE: {
            auto size_nbits = format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                                      .size_nbits();
            auto align_size_in_bits =
                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
                            .align_size_in_bits();
            return fbs::v2::CreateLowbitsAlignedTensorFormat(
                           m_builder, size_nbits, align_size_in_bits)
                    .Union();
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}

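//! build the fbs::v2::MiddleTensor describing an intermediate var: name,
//! shape, comp node, dtype and format when a device tensor is available,
//! otherwise the name only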
flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
        const SymbolVar var) {
    mgb_assert(var.node());
    auto fbname = m_builder.CreateSharedString(var.node()->name());
    flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
    if (var.node()->dev_tensor_valid()) {
        auto layout = var.node()->layout();
        auto fshape =
                m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);

        auto fcomp_node = fbs::v2::CreateCompNode(
                m_builder, m_builder.CreateSharedString(
                                   var.node()->comp_node().to_string_logical()));

        auto fdtype = build_dtype(layout.dtype);
        auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
        auto fformat = build_tensor_format(layout.format);
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
                m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
    } else {
        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
    }
    return serialized_middle_tensor;
}

flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
        const SymbolVar var) {
    auto out_node = var.node();
    if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
        mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
        out_node = m_var_remove_in_dump[var.node()];
    }
    return fbs::v2::CreateOutputVar(
            m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
}

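//! collect the operators to dump in dependency order; operators removed in
//! dump only record a mapping from their outputs to the surviving input var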
void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
    m_oprs_to_dump.clear();

    // iterate oprs to init
    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        if (should_remove_in_dump(opr)) {
            mgb_assert(opr->input().size() == 1);
            // Copy input ID to output
            for (auto i : opr->output()) {
                if (m_var_remove_in_dump.find(opr->input(0)) !=
                    m_var_remove_in_dump.end()) {
                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
                } else {
                    m_var_remove_in_dump[i] = opr->input(0);
                }
            }
        } else {
            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
                    opr->dyn_typeinfo(), CURRENT_VERSION);
            if (!registry || !registry->dumper) {
                mgb_throw(
                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>,
                        "serialization as FlatBuffers is not supported for "
                        "operator %s",
                        opr->dyn_typeinfo()->name);
            }
            m_oprs_to_dump.emplace_back(opr, registry);
        }
    };
    cg::DepOprIter dep_opr_iter{on_opr};
    for (auto i : endpoints) {
        dep_opr_iter.add(i.node()->owner_opr());
    }
}

flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
        const Metadata& metadata) {
    auto user_info = m_builder.CreateSharedString(metadata.user_info);
    fbs::v2::MetadataBuilder builder(m_builder);
    builder.add_is_valid(metadata.is_valid);
    builder.add_graph_modified(metadata.graph_modified);
    builder.add_optimize_options(metadata.optimize_options);
    builder.add_user_info(user_info);
    return builder.Finish();
}

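//! serialize one operator: resolve its input vars to middle-tensor ids, run
//! the registered dumper to collect params/tensors/blobs, then emit the
//! fbs::v2::Operator table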
flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
    m_cur_opr = opr;
    ++m_cur_rst.nr_opr;

    using namespace flatbuffers;
    Offset<Vector<uint32_t>> inputs;
    if (m_cur_opr->input().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->input().size());
        for (auto inp : m_cur_opr->input()) {
            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
                v.emplace_back(m_var2midtensor_id.at(inp));
            } else {
                mgb_assert(
                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
                        "broken var dependency when dumping the model.");
                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
            }
        }
        inputs = m_builder.CreateVector(v);
    }

    m_cur_opr_tensor.clear();
    m_blobs.clear();
    m_cur_opr_param.clear();
    m_cur_opr_param_type.clear();
    registry->dumper(*this, *m_cur_opr);

    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
    auto& config = m_cur_opr->config();
    if (config.has_comp_node_set()) {
        std::vector<flatbuffers::Offset<fbs::v2::CompNode>> cns;
        for (const auto& cn : config.comp_node()) {
            cns.emplace_back(fbs::v2::CreateCompNode(
                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
        }
        comp_node = m_builder.CreateVector(cns);
    }
    Offset<String> operator_name;
    if (m_config.keep_op_name) {
        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
    }

    auto output_dtype = build_dtype(config.output_dtype());

    Offset<Vector<uint32_t>> outputs;
    if (m_cur_opr->output().size()) {
        std::vector<uint32_t> v;
        v.reserve(m_cur_opr->output().size());
        for (auto out : m_cur_opr->output()) {
            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                auto fbs_out = build_middle_tensor(out);
                m_model_middle_tensors.push_back(fbs_out);
                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
                v.emplace_back(m_var2midtensor_id.at(out));
            }
        }
        outputs = m_builder.CreateVector(v);
    }

    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
    if (m_cur_opr_tensor.size())
        tensors = m_builder.CreateVector(m_cur_opr_tensor);

    //! the blob data is used as custom data;
    //! m_blobs is filled by the operator's dumper function
    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
    if (m_blobs.size())
        blobs = m_builder.CreateVector(m_blobs);

    Offset<Vector<uint8_t>> additional_params_type;
    Offset<Vector<Offset<void>>> additional_params;
    auto param_cnt = m_cur_opr_param_type.size();
    if (param_cnt > 1) {
        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
                m_cur_opr_param_type.data() + 1, param_cnt - 1);
        additional_params =
                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
    }
    auto opr_type = m_builder.CreateSharedString(registry->name);

    fbs::v2::OperatorBuilder builder(m_builder);
    builder.add_type(opr_type);
    builder.add_type_id(registry->type_id);
    builder.add_inputs(inputs);
    builder.add_outputs(outputs);
    if (m_config.keep_opr_priority) {
        builder.add_priority(opr->node_prop().attribute().priority);
    }
    builder.add_comp_node(comp_node);
    builder.add_opr_version(registry->get_version());
    builder.add_name(operator_name);
    builder.add_output_dtype(output_dtype);
    if (param_cnt > 0) {
        builder.add_param_type(m_cur_opr_param_type[0]);
        builder.add_param(m_cur_opr_param[0]);
    }
    if (param_cnt > 1) {
        builder.add_additional_params_type(additional_params_type);
        builder.add_additional_params(additional_params);
    }
    builder.add_tensors(tensors);
    builder.add_custom_data(blobs);
    m_cur_opr = nullptr;
    return builder.Finish();
}

SymbolVarArray GraphDumperOSSV2::converter_all_opr_to_compatiable(
        const SymbolVarArray& output_vars) {
    gopt::GraphOptimizer optimizer;
    VarNodeArray rets_var;
    for (auto& symbolvar : output_vars) {
        rets_var.push_back(symbolvar.node());
    }
    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
    optimizer.apply_inplace(rets_var);

    SymbolVarArray dst_vars;
    for (auto& var : rets_var) {
        dst_vars.push_back({var});
    }
    return dst_vars;
}

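//! entry point of the dumper: convert the graph to a compatible form,
//! serialize metadata, operators, middle tensors and output vars, and write
//! the size-prefixed FlatBuffers model to the output file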
GraphDumper::DumpResult GraphDumperOSSV2::dump(
        const SymbolVarArray& output_vars, const DumpConfig& config,
        const Metadata& metadata) {
    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");

    auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars);

    auto begin_pos = m_file->tell();
    m_config = config;
    m_builder.Reset();

    m_output_vars.clear();
    m_cur_rst = {};
    m_used_input_names.clear();
    m_used_param_names.clear();
    m_var_remove_in_dump.clear();
    m_model_middle_tensors.clear();
    m_var2midtensor_id.clear();
    m_nr_shared_tensor = 0;

    // process output vars
    bool keep_output_var_name = m_config.keep_var_name >= 1;
    std::unordered_set<std::string> output_var_names;
    for (auto i : new_output_vars) {
        mgb_assert(
                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
                "can not dump var with VOLATILE_CONTENT flag: %s",
                cg::dump_var_info({i.node()}).c_str());
        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
            auto name_ins = output_var_names.insert(i.node()->name()).second;
            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
        }
    }

    // Dump metadata
    auto fbmeta = build_metadata(metadata);

    // Dump operators
    init_oprs_to_dump(new_output_vars);
    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
    for (auto&& i : m_oprs_to_dump) {
        oprs.emplace_back(build_single_opr(i.first, i.second));
    }
    auto fb_oprs = m_builder.CreateVector(oprs);

    // Dump output vars
    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
    output_vars_idx.reserve(new_output_vars.size());
    for (auto i : new_output_vars) {
        auto foutput_vars_idx = build_output_var(i);
        output_vars_idx.push_back(foutput_vars_idx);
    }
    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);

    fbs::v2::ModelBuilder model(m_builder);
    model.add_mge_version(MGB_VERSION);
    model.add_oprs(fb_oprs);
    model.add_middle_tensors(fb_mid_tensor);
    model.add_output_vars_idx(fb_output_vars);
    model.add_nr_shared_tensor(m_nr_shared_tensor);
    model.add_metadata(fbmeta);
    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());

    // Write serialized fbs::v2::Model
    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());

    // Finalize DumpResult
    auto&& ret = m_cur_rst;
    for (size_t i = 0; i < new_output_vars.size(); i++) {
        ret.outputs.emplace_back(
                keep_output_var_name ? new_output_vars[i].node()->cname()
                                     : ssprintf("unnamed%zu", i));
    }
    std::sort(ret.inputs.begin(), ret.inputs.end());
    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
    ret.tot_bytes = m_file->tell() - begin_pos;
    return ret;
}

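//! serialize one tensor on behalf of an operator dumper; the write method
//! decides whether the value is stored and whether the name is kept
//! (anonymous value, shared parameter or graph input)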
void GraphDumperOSSV2::dump_tensor(
        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
    using namespace flatbuffers;
    using Meth = TensorWriteMethod;
    mgb_assert(
            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
            "name must be non-empty for non Meth::VALUE_ANONYMOUS tensors");

    bool has_value = method != Meth::META_INPUT;
    bool should_keep_name = true;
    switch (method) {
        case Meth::VALUE_ANONYMOUS:
            should_keep_name = false;
            break;
        case Meth::VALUE_SHARED:
            should_keep_name = m_config.keep_param_name;
            ++m_nr_shared_tensor;
            if (m_config.keep_param_name) {
                mgb_assert(
                        m_used_param_names.insert(name).second,
                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
                m_cur_rst.params.emplace_back(name);
            }
            break;
        case Meth::META_INPUT:
        case Meth::VALUE_INPUT:
            mgb_assert(!name.empty(), "empty input tensor name");
            mgb_assert(
                    m_used_input_names.insert(name).second,
                    "duplicated input tensor name: %s", name.c_str());
            m_cur_rst.inputs.emplace_back(name);
            break;
    }

    auto& layout = tensor.layout();
    flatbuffers::Offset<flatbuffers::Vector<uint8_t>> data;
    if (has_value) {
        check_tensor_value_valid(name, tensor);
        auto&& dumper = m_config.tensor_value_dumper;
        if (dumper) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does "
                    "not support a user tensor value dumper");
        }
        data = m_builder.CreateVector(
                reinterpret_cast<uint8_t*>(tensor.raw_ptr()), layout.span().high_byte);
        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
    }

    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
    auto fcomp_node = fbs::v2::CreateCompNode(
            m_builder,
            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
    auto fdtype = build_dtype(layout.dtype);

    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
    auto fformat = build_tensor_format(layout.format);
    auto serialized_tensor = fbs::v2::CreateTensor(
            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
    m_cur_opr_tensor.emplace_back(serialized_tensor);
}

void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
    auto blob = fbs::v2::CreateBlob(
            m_builder, m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
    m_blobs.emplace_back(blob);
}

// ----------------------------- Loader --------------------------------------
CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
        const fbs::v2::CompNode* comp_node) {
    mgb_assert(comp_node);
    if (!comp_node->logical_locator())
        return {};
    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
    m_loader->m_cur_load_config->comp_node_mapper(loc);
    return CompNode::load(loc);
}

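//! rebuild a megdnn TensorFormat from its serialized form; the Image2D
//! format needs the comp node to obtain the matching megdnn handle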
TensorFormat load_tensor_format(
        const fbs::v2::TensorFormat fformat_type, const void* fformat,
        const CompNode& comp_node) {
    switch (fformat_type) {
        case fbs::v2::TensorFormat_DefaultTensorFormat:
            return megdnn::DefaultTensorFormat::make();
        case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
            auto image_format =
                    static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
            auto handle =
                    MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
            return megdnn::Image2DPack4TensorFormat::make(
                    image_format->align_axis(), handle);
        }
        case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
            auto lowbit_format =
                    static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
            return megdnn::LowbitsAlignedToBytesTensorFormat::make(
                    lowbit_format->size_nbits());
        }
        default:
            mgb_throw(
                    SerializationError, "invalid tensor format type in serialization.");
    }
}

TensorLayout load_tensor_layout(
        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
    TensorLayout layout;
    if (tensor->shape()) {
        layout.ndim = tensor->shape()->size();
        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
    }
    if (tensor->dtype()) {
        // modify data type inplace for TensorLayout
        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
    }
    if (tensor->format() && tensor->format_type()) {
        layout.format =
                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
    }
    layout.init_contiguous_stride();
    return layout;
}

//! the opr loader should ensure that the tensors exist and that their
//! number is correct; here we just assert it.
std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    auto ret = std::make_shared<HostTensorND>(comp_node, layout);

    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
    if (tensor->data() && tensor->data()->size() > 0) {
        if (loader) {
            mgb_log_warn(
                    "the v2 serialization format is pure FlatBuffers and does "
                    "not support a user tensor value loader");
        }
        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
    }
    if (tensor->name()) {
        m_tensor_map[tensor->name()->str()] = ret;
    }
    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
        bool has_value = false;
        if (tensor && tensor->data()) {
            has_value = tensor->data()->size() != 0;
        }
        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
    }
    return ret;
}

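//! load a tensor whose value is shared: results are cached per memory node
//! in m_shared_tensor_map; CPU values are forwarded directly, while device
//! values are loaded lazily in one batch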
std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
        load_tensor_shared() {
    mgb_assert(
            m_current_opr->tensors() &&
            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
    auto comp_node = load_comp_node(tensor->comp_node());
    auto layout = load_tensor_layout(tensor, comp_node);
    mgb_assert(tensor->data());
    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
    if (shared_tensor_ref) {
        if (shared_tensor_ref->comp_node() == comp_node)
            return shared_tensor_ref;
        // same mem node but different comp node, change comp node and share
        // value
        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
        ret->comp_node(comp_node);
        return ret;
    }
    if (tensor->name()) {
        shared_pair.first = tensor->name()->str();
    }

    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
        // directly forward CPU memory
        HostTensorND hv{comp_node};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = std::make_shared<DeviceTensorND>();
        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
    } else {
        // use lazy load for non-CPU devices
        HostTensorND hv{CompNode::default_cpu()};
        if (tensor->data() && tensor->data()->size() > 0) {
            hv.dtype(layout.dtype).resize(layout);
            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
        }
        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
    }
    return shared_tensor_ref;
}

Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
    const auto* fbmeta = m_loader->m_model->metadata();
    Metadata ret;
    if (fbmeta) {
        ret.is_valid = fbmeta->is_valid();
        ret.graph_modified = fbmeta->graph_modified();
        if (fbmeta->user_info()) {
            ret.user_info = fbmeta->user_info()->str();
            ret.has_user_info = true;
        }
        if (fbmeta->optimize_options()) {
            ret.optimize_options = fbmeta->optimize_options();
            ret.optimized_for_inference = true;
        }
    }
    return ret;
}

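//! deserialize one operator: rebuild its config (output dtype, name, comp
//! nodes), look up the versioned loader in OprRegistryV2, resolve inputs by
//! middle-tensor id and record the produced output vars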
void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
        const fbs::v2::Operator* fbopr) {
    m_cur_opr_tensor_cnt = 0;
    m_cur_opr_blob_cnt = 0;
    m_cur_opr_param_cnt = 0;

    OperatorNodeConfig config;
    if (fbopr->output_dtype()) {
        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
    }
    if (fbopr->name()) {
        config.name(fbopr->name()->str());
    }
    if (fbopr->comp_node()) {
        auto cnt = fbopr->comp_node()->size();
        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
        for (size_t i = 0; i < cnt; i++) {
            CompNode cn{};
            auto node = fbopr->comp_node()->Get(i);
            if (node) {
                cn = load_comp_node(node);
            }
            comp_node_arr[i] = cn;
        }
        config.comp_node_arr(comp_node_arr);
    }
    //! the opr version must exist
    uint8_t opr_version = fbopr->opr_version();
    auto type_id = fbopr->type_id();
    const OprRegistryV2* registry =
            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
    mgb_throw_if(
            !registry, SerializationError,
            "failed to find opr with type %s, id %zu; use python env "
            "config.dump_registered_oprs() to get a dict that maps from "
            "opr id to opr name",
            fbopr->type()->str().c_str(), type_id);

    // load inputs
    VarNodeArray inputs;
    if (fbopr->inputs()) {
        inputs.resize(fbopr->inputs()->size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
        }
    }

    // call loader
    auto accessor = registry->loader(*this, inputs, config);
    auto opr = accessor.opr();

    // check opr type; note that:
    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
    mgb_assert(
            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
                    opr->same_type<opr::ImmutableTensor>()),
            "got_type=%s expected_type=%s", opr ? opr->dyn_typeinfo()->name : nullptr,
            registry->type->name);
    // record output vars; read output names
    size_t i = 0;
    for (auto ovar : accessor.output()) {
        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            m_id2varnode.push_back(ovar);
            if (fbopr->outputs()) {
                auto id = fbopr->outputs()->Get(i);
                mgb_assert(
                        m_id2varnode.size() - 1 == id,
                        "id2var is %zu, fbs get id is %d\n", m_id2varnode.size() - 1,
                        id);
                if (m_middle_tensors.size() > id) {
                    auto name = m_middle_tensors[id]->name()->str();
                    ovar->name(name);
                }
            }
            i++;
        }
    }

    opr->node_prop().attribute().priority = fbopr->priority();
}

GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
    // load oprs
    const auto* oprs = m_loader->m_model->oprs();
    {
        // inplace arith graph optimization is disabled during opr load:
        // it tries to restore the same graph as it was dumped;
        // see test TestSerializer2.LOGEXP for example
        GraphLoader::ScopedGraphOptDisabler _(m_graph);
        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
            m_current_opr = oprs->Get(i);
            load_single_opr(m_current_opr);
        }
    }

    // batched loading of device values
    m_device_value_loader.apply();

    LoadResult ret;
    ret.graph = m_graph;
    ret.tensor_map = m_tensor_map;

    const auto* outputs = m_loader->m_model->output_vars_idx();
    ret.output_var_list.resize(outputs->size());
    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
        auto out = outputs->Get(i);
        auto var = m_id2varnode.at(out->compact_id());
        ret.output_var_map[var->name()] = var;
        ret.output_var_map_id[out->original_id()] = var;
        ret.output_var_list[i] = var;
    }
    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
    return ret;
}

void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
    auto model = m_loader->m_model;
    if (model->middle_tensors()) {
        for (unsigned int i = 0; i < model->middle_tensors()->size(); i++) {
            m_middle_tensors.push_back(model->middle_tensors()->Get(i));
        }
    }
}

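//! load a model from the current file position: read the size prefix, check
//! the FlatBuffers identifier, verify the buffer, then deserialize metadata,
//! middle tensors and operators into a LoadResult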
GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
    mgb_assert(m_file);
    m_cur_load_config = &config;
    if (rewind) {
        m_file->rewind();
    }
    // Read fbs::v2::Model
    uint32_t size;
    m_file->read(&size, sizeof(size));
    m_model_buf = m_file->read_shared(size);

    mgb_throw_if(
            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
            "invalid fbs model");

    {
        flatbuffers::Verifier verifier(
                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
        mgb_throw_if(
                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
                "model verification failed (invalid or corrupted model?)");
    }

    m_model = fbs::v2::GetModel(m_model_buf.data());
    m_mgb_version = m_model->mge_version();
    if (m_model->mge_version() > MGB_VERSION) {
        mgb_log_warn(
                "loading model from future runtime: version=%u "
                "model_version=%u",
                MGB_VERSION, m_model->mge_version());
    }

    if (m_shared_tensor_map.empty()) {
        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
    } else {
        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
    }

    OprLoadContextImpl ctx{this, m_model->mge_version()};
    ctx.load_middle_tensor();
    auto metadata = ctx.load_metadata();
    auto result = ctx.load_oprs();
    result.metadata = metadata;
    m_model_loaded = true;
    result.graph_compile_ahead();
    return result;
}

std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) {
    return std::make_unique<GraphDumperOSSV2>(std::move(file));
}

std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
    return std::make_unique<GraphLoaderOSSV2>(std::move(file));
}

bool is_fbs_v2_file(InputFile& file) {
    constexpr size_t identifier_length = 25;
    char identifier[identifier_length];
    file.read(identifier, identifier_length);
    file.skip(-identifier_length);
    //! skip the size prefix at the beginning of the buffer
    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
}

} // namespace serialization
} // namespace mgb

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}