GitOrigin-RevId: 65064452c9
release-1.10
@@ -367,10 +367,12 @@ def dump_graph( | |||
keep_opr_name: bool = False, | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
no_change_graph: bool = False, | |||
strip_info_file=None, | |||
append_json=False, | |||
metadata=None, | |||
dump_format=None | |||
dump_format=None, | |||
model_version: int = 2 | |||
) -> Tuple[bytes, CompGraphDumpResult]: | |||
r"""serialize the computing graph of `output_vars` and get byte result. | |||
@@ -386,12 +388,22 @@ def dump_graph( | |||
keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
keep_opr_priority: whether to keep priority setting for operators | |||
no_change_graph: whether to keep the compute graph unchanged when dumping. For | |||
model compatibility, some operators may be converted to a compatible | |||
format in this version. | |||
* if set False, some operators may be converted to other operators for | |||
compatibility, so that all dumped operators are compatible. | |||
* if set True, no operator in the graph will be changed when dumping. | |||
strip_info_file: a string for path or a file handler. if is not None, | |||
then the dump information for code strip would be written to ``strip_info_file`` | |||
append_json: only checked when `strip_info_file` is not None. if set | |||
true, the information for code strip will be appended to strip_info_file. | |||
if set false, strip_info_file will be overwritten | |||
dump_format: using different dump formats. | |||
model_version: the model version of "FBS_V2", beginning with version 2. This | |||
only takes effect when the dump format is "FBS_V2". | |||
Note: | |||
The underlying C++ API only accepts a var list. If a dict is given, | |||
@@ -441,8 +453,10 @@ def dump_graph( | |||
keep_opr_name, | |||
keep_param_name, | |||
keep_opr_priority, | |||
no_change_graph, | |||
metadata, | |||
dump_format, | |||
model_version, | |||
stat, | |||
inputs, | |||
outputs, | |||
@@ -549,6 +549,7 @@ class trace: | |||
keep_opr_name: bool = False, | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
no_change_graph: bool = False, | |||
strip_info_file=None, | |||
append_json=False, | |||
optimize_for_inference=True, | |||
@@ -562,6 +563,7 @@ class trace: | |||
resize_input=False, | |||
input_transform=None, | |||
dump_format: str = None, | |||
model_version: int = 2, | |||
**kwargs | |||
): | |||
r"""Serializes trace to file system. | |||
@@ -583,6 +585,14 @@ class trace: | |||
keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
keep_opr_priority: whether to keep priority setting for operators | |||
no_change_graph: whether to keep the compute graph unchanged when dumping. For | |||
model compatibility, some operators may be converted to a compatible | |||
format in this version. | |||
* if set False, some operators may be converted to other operators for | |||
compatibility, so that all dumped operators are compatible. | |||
* if set True, no operator in the graph will be changed when dumping. | |||
strip_info_file: a string for path or a file handler. if is not None, | |||
then the dump information for code strip would be written to ``strip_info_file`` | |||
append_json: will be check when `strip_info_file` is not None. if set | |||
@@ -616,6 +626,9 @@ class trace: | |||
dump_format: using different dump formats. the open source MegEngine | |||
defaults to the FBS_V2 format; there are two formats to choose from, FBS_V2 and FBS. | |||
internal MegEngine has another choice of internal proprietary formats | |||
model_version: the model version of FBS_V2, beginning with version 2. This | |||
only takes effect when the dump format is FBS_V2. | |||
Keyword Arguments: | |||
@@ -762,10 +775,12 @@ class trace: | |||
keep_opr_name=keep_opr_name, | |||
keep_param_name=keep_param_name, | |||
keep_opr_priority=keep_opr_priority, | |||
no_change_graph=no_change_graph, | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
metadata=metadata, | |||
dump_format=dump_format, | |||
model_version=model_version, | |||
) | |||
file.write(dump_content) | |||
@@ -381,20 +381,26 @@ void init_graph_rt(py::module m) { | |||
m.def("dump_graph", | |||
[](const std::vector<VarNode*>& dest_vars, int keep_var_name, | |||
bool keep_opr_name, bool keep_param_name, bool keep_opr_priority, | |||
std::optional<_SerializationMetadata> metadata, | |||
std::optional<_SerializationFormat> dump_format, py::list& stat, | |||
py::list& inputs, py::list& outputs, py::list& params) { | |||
bool no_change_graph, std::optional<_SerializationMetadata> metadata, | |||
std::optional<_SerializationFormat> dump_format, | |||
std::optional<int> model_version, py::list& stat, py::list& inputs, | |||
py::list& outputs, py::list& params) { | |||
std::vector<uint8_t> buf; | |||
ser::GraphDumpFormat format = ser::GraphDumpFormat::FLATBUFFERS_V2; | |||
int version = 2; | |||
if (dump_format.has_value()) { | |||
format = dump_format.value(); | |||
} | |||
if (model_version.has_value()) { | |||
version = model_version.value(); | |||
} | |||
auto dumper = ser::GraphDumper::make( | |||
ser::OutputFile::make_vector_proxy(&buf), format); | |||
ser::OutputFile::make_vector_proxy(&buf), format, version); | |||
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
ser::GraphDumper::DumpConfig config{ | |||
keep_var_name, keep_param_name, keep_opr_priority, keep_opr_name}; | |||
config.no_change_graph = no_change_graph; | |||
ser::GraphDumper::DumpResult rst; | |||
if (metadata) | |||
@@ -21,6 +21,13 @@ struct OprLoadDumpImplV2<opr::Softmax, 1> { | |||
ctx.write_param<PersisParam>(opr.cast_final_safe<Opr>().param()); | |||
} | |||
/** This converter is just an example of Operator serialization compatibility. | |||
* It applies in this situation: when the softmax Operator is optimized by | |||
* fusing the elemwise and reduce Operators into one big Operator, the whole softmax | |||
* Operator can't be recognized by old versions, so for model | |||
* compatibility the softmax Operator should be converted to elemwise and | |||
* reduce Operators when dumping the model | |||
*/ | |||
static cg::OperatorNodeBase* replace_opr( | |||
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { | |||
int32_t axis = opr->cast_final_safe<Opr>().param().axis; | |||
@@ -196,9 +203,11 @@ namespace opr { | |||
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \ | |||
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION); | |||
SERGE_OPR_V2_CONVERTER( | |||
//! this is just an example of Operator compatibility | |||
/*SERGE_OPR_V2_CONVERTER( | |||
Softmax, 1, | |||
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr)); | |||
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));*/ | |||
SERGE_OPR_V2_NO_CONVERTER(Softmax, 1) | |||
SERGE_OPR_V2_NO_CONVERTER(ConvBiasForward, 0) | |||
SERGE_OPR_V2_NO_CONVERTER(BatchConvBiasForward, 0); | |||
@@ -59,7 +59,8 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file); | |||
std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file); | |||
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file); | |||
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file); | |||
std::unique_ptr<GraphDumper> make_fbs_v2_dumper( | |||
std::unique_ptr<OutputFile> file, int version); | |||
bool is_fbs_file(InputFile& file); | |||
bool is_fbs_v2_file(InputFile& file); | |||
@@ -72,7 +73,7 @@ bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) { | |||
} | |||
std::unique_ptr<GraphDumper> GraphDumper::make( | |||
std::unique_ptr<OutputFile> file, GraphDumpFormat format) { | |||
std::unique_ptr<OutputFile> file, GraphDumpFormat format, int version) { | |||
switch (format) { | |||
case GraphDumpFormat::FLATBUFFERS: | |||
#if MGB_ENABLE_FBS_SERIALIZATION | |||
@@ -81,7 +82,7 @@ std::unique_ptr<GraphDumper> GraphDumper::make( | |||
MGB_FALLTHRU | |||
case GraphDumpFormat::FLATBUFFERS_V2: | |||
#if MGB_ENABLE_FBS_SERIALIZATION | |||
return make_fbs_v2_dumper(std::move(file)); | |||
return make_fbs_v2_dumper(std::move(file), version); | |||
#endif | |||
MGB_FALLTHRU | |||
default: | |||
@@ -194,7 +194,7 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) { | |||
} | |||
} else { | |||
auto registry = OprRegistryV2::versioned_find_by_typeinfo( | |||
opr->dyn_typeinfo(), CURRENT_VERSION); | |||
opr->dyn_typeinfo(), m_version); | |||
if (!registry || !registry->dumper) { | |||
mgb_throw( | |||
cg::OperatorNodeExcExtraInfo::ExcMaker{opr}.make<MegBrainError>, | |||
@@ -202,6 +202,9 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) { | |||
"operator %s", | |||
opr->dyn_typeinfo()->name); | |||
} | |||
mgb_assert( | |||
registry->version <= m_version, | |||
"The Operator version should less than model version"); | |||
m_oprs_to_dump.emplace_back(opr, registry); | |||
} | |||
}; | |||
@@ -352,7 +355,10 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||
const Metadata& metadata) { | |||
mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph"); | |||
auto&& new_output_vars = converter_all_opr_to_compatiable(output_vars); | |||
auto new_output_vars = output_vars; | |||
if (!config.no_change_graph) { | |||
new_output_vars = converter_all_opr_to_compatiable(output_vars); | |||
} | |||
auto begin_pos = m_file->tell(); | |||
m_config = config; | |||
@@ -416,6 +422,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump( | |||
fbs::v2::ModelBuilder model(m_builder); | |||
model.add_mge_version(MGB_VERSION); | |||
model.add_model_version(m_version); | |||
model.add_oprs(fb_oprs); | |||
model.add_middle_tensors(fb_mid_tensor); | |||
model.add_output_vars_idx(fb_output_vars); | |||
@@ -694,10 +701,8 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr( | |||
OprRegistryV2::versioned_find_by_id(type_id, opr_version); | |||
mgb_throw_if( | |||
!registry, SerializationError, | |||
"failed to find opr with type %s , use python env " | |||
"config.dump_registered_oprs() to get a dict that maps from " | |||
"opr id to opr name", | |||
fbopr->type()->str().c_str()); | |||
"failed to find opr with type %s and version %d.", | |||
fbopr->type()->str().c_str(), opr_version); | |||
// load inputs | |||
VarNodeArray inputs; | |||
@@ -811,12 +816,19 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re | |||
m_model = fbs::v2::GetModel(m_model_buf.data()); | |||
m_mgb_version = m_model->mge_version(); | |||
m_model_version = m_model->model_version(); | |||
if (m_model->mge_version() > MGB_VERSION) { | |||
mgb_log_warn( | |||
"loading model from future runtime: version=%u " | |||
"model_version=%u", | |||
MGB_VERSION, m_model->mge_version()); | |||
} | |||
if (m_model_version > CURRENT_VERSION) { | |||
mgb_log_warn( | |||
"The model dump in the future version %d, try to load it, maybe case " | |||
"load error in %d version.", | |||
m_model_version, CURRENT_VERSION); | |||
} | |||
if (m_shared_tensor_map.empty()) { | |||
m_shared_tensor_map.resize(m_model->nr_shared_tensor()); | |||
@@ -845,8 +857,9 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re | |||
return result; | |||
} | |||
std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) { | |||
return std::make_unique<GraphDumperOSSV2>(std::move(file)); | |||
std::unique_ptr<GraphDumper> make_fbs_v2_dumper( | |||
std::unique_ptr<OutputFile> file, int version) { | |||
return std::make_unique<GraphDumperOSSV2>(std::move(file), version); | |||
} | |||
std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) { | |||
@@ -58,18 +58,25 @@ struct GraphDumpConfig { | |||
//! names. this list record the mapping between output node and it's name | |||
std::vector<std::pair<std::string, SymbolVar>> alias_name_map; | |||
//! whether to dump all the oprs without changing the graph. Some | |||
//! oprs may not be compatible; if false, some oprs will be converted to a compatible | |||
//! format and then dumped | |||
bool no_change_graph; | |||
GraphDumpConfig( | |||
int keep_var_name_ = 1, bool keep_param_name_ = false, | |||
bool keep_opr_priority_ = false, bool keep_op_name_ = true, | |||
const std::shared_ptr<UserDataContainer>& user_data_ = | |||
std::make_shared<UserDataContainer>(), | |||
const TensorValueDumper& tensor_value_dumper_ = {}) | |||
const TensorValueDumper& tensor_value_dumper_ = {}, | |||
bool no_change_graph_ = false) | |||
: keep_var_name{keep_var_name_}, | |||
keep_param_name{keep_param_name_}, | |||
keep_opr_priority{keep_opr_priority_}, | |||
keep_op_name{keep_op_name_}, | |||
user_data{user_data_}, | |||
tensor_value_dumper{tensor_value_dumper_} {} | |||
tensor_value_dumper{tensor_value_dumper_}, | |||
no_change_graph{no_change_graph_} {} | |||
}; | |||
//! config for loading a whole graph; setup in GraphLoader | |||
@@ -15,6 +15,7 @@ namespace serialization { | |||
class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { | |||
const std::unique_ptr<OutputFile> m_file; | |||
int m_version; | |||
flatbuffers::FlatBufferBuilder m_builder; | |||
DumpConfig m_config; | |||
@@ -51,7 +52,8 @@ class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers { | |||
flatbuffers::Offset<fbs::DType> build_dtype(DType dtype); | |||
public: | |||
GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {} | |||
GraphDumperOSSV2(std::unique_ptr<OutputFile> file, int version) | |||
: m_file{std::move(file)}, m_version{version} {} | |||
DumpResult dump( | |||
const SymbolVarArray& output_vars, const DumpConfig& config = {}, | |||
@@ -95,6 +97,7 @@ class GraphLoaderOSSV2 final : public GraphLoader { | |||
const fbs::v2::Model* m_model; | |||
SharedTensorIDMap m_shared_tensor_map; | |||
uint32_t m_mgb_version = 0; | |||
uint32_t m_model_version = CURRENT_VERSION; | |||
bool m_model_loaded = false; | |||
void verify(); | |||
@@ -5,6 +5,7 @@ | |||
#include "megbrain/serialization/file.h" | |||
#include "megbrain/serialization/load_dump_config.h" | |||
#include "megbrain/serialization/metadata.h" | |||
#include "megbrain/serialization/opr_load_dump.h" | |||
namespace mgb { | |||
namespace serialization { | |||
@@ -160,7 +161,8 @@ public: | |||
}; | |||
MGE_WIN_DECLSPEC_FUC static std::unique_ptr<GraphDumper> make( | |||
std::unique_ptr<OutputFile> file, GraphDumpFormat format = {}); | |||
std::unique_ptr<OutputFile> file, GraphDumpFormat format = {}, | |||
int version = VERSION_2); | |||
virtual ~GraphDumper() = default; | |||
@@ -987,7 +987,9 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) { | |||
OutputFile::make_fs(fname.c_str()), GraphDumpFormat::FLATBUFFERS_V2); | |||
auto rst = dumper->dump({x}); | |||
func->execute().wait(); | |||
ASSERT_EQ(rst.nr_opr, 6); | |||
//! if softmax is converted to reduce and elemwise, nr_opr is 6 | |||
// ASSERT_EQ(rst.nr_opr, 6); | |||
ASSERT_EQ(rst.nr_opr, 2); | |||
ASSERT_EQ(rst.inputs.size(), 1); | |||
ASSERT_EQ(rst.outputs.size(), 1); | |||
ASSERT_EQ(rst.params.size(), 0); | |||