
feat(mgb/serialization): be able to serialize operator names

GitOrigin-RevId: d295abb5da
tags/v1.3.0
Megvii Engine Team, 4 years ago
commit 09de5a0725
6 changed files with 53 additions and 1 deletion:

  1. imperative/python/megengine/core/tensor/megbrain_graph.py (+3, -0)
  2. imperative/python/src/graph_rt.cpp (+2, -1)
  3. src/serialization/impl/schema.fbs (+1, -0)
  4. src/serialization/impl/serializer_oss.cpp (+9, -0)
  5. src/serialization/include/megbrain/serialization/load_dump_config.h (+5, -0)
  6. src/serialization/test/serializer_oss.cpp (+33, -0)

imperative/python/megengine/core/tensor/megbrain_graph.py (+3, -0)

@@ -305,6 +305,7 @@ def dump_graph(
     output_vars: Union[Dict[str, VarNode], List[VarNode]],
     *,
     keep_var_name: int = 1,
+    keep_op_name: bool = True,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
@@ -325,6 +326,7 @@ def dump_graph(
         * 0: none of the names are kept
         * 1: (default)keep names of output vars
         * 2: keep names of all (output and internal) vars
+    :param keep_op_name: whether to keep operator names.
     :param keep_param_name: whether to keep param names, so param values can be
         easily manipulated after loading model
     :param keep_opr_priority: whether to keep priority setting for operators
@@ -368,6 +370,7 @@ def dump_graph(
     dump_content = _imperative_rt.dump_graph(
         ov,
         keep_var_name,
+        keep_op_name,
         keep_param_name,
         keep_opr_priority,
         stat,


imperative/python/src/graph_rt.cpp (+2, -1)

@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph", [](
         const std::vector<VarNode*>& dest_vars,
         int keep_var_name,
+        bool keep_op_name,
         bool keep_param_name,
         bool keep_opr_priority,
         py::list& stat,
@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
         SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());

         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
-                                            keep_opr_priority};
+                                            keep_opr_priority, keep_op_name};

         auto rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
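The braced initializer above depends on where the new flag sits in GraphDumpConfig's constructor (shown in load_dump_config.h below): keep_op_name is appended after keep_opr_priority, so existing callers that pass only the first three values keep compiling and silently get the default of true. A minimal annotated sketch of that mapping, assuming DumpConfig is an alias for GraphDumpConfig:

    // Sketch only: how the binding's positional values line up with the
    // GraphDumpConfig constructor parameters and their defaults.
    ser::GraphDumper::DumpConfig config{
            keep_var_name,      // int  keep_var_name_     (default 1)
            keep_param_name,    // bool keep_param_name_   (default false)
            keep_opr_priority,  // bool keep_opr_priority_ (default false)
            keep_op_name};      // bool keep_op_name_      (default true, new in this commit)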


src/serialization/impl/schema.fbs (+1, -0)

@@ -124,6 +124,7 @@ table Operator {
     blobs:[Blob];
     /// Operator may want to save more than one OperatorParam
     additional_params:[OperatorParam];
+    name:string;
 }

 struct OutputVar {


src/serialization/impl/serializer_oss.cpp (+9, -0)

@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         inputs = m_builder.CreateVector(v);
     }

+    Offset<String> operator_name;
+    if (m_config.keep_op_name) {
+        operator_name = m_builder.CreateSharedString(opr->name());
+    }
+
     Offset<Vector<Offset<String>>> output_names;
     if (m_config.keep_var_name >= 2 ||
         (m_config.keep_var_name == 1 &&
@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
     }
     builder.add_comp_node(comp_node);
     builder.add_output_name(output_names);
+    builder.add_name(operator_name);
     builder.add_output_dtype(output_dtype);
     if (param_cnt > 0) {
         builder.add_param_type(m_cur_opr_param_type[0]);
@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
     if (fbopr->output_dtype()) {
         config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
     }
+    if (fbopr->name()) {
+        config.name(fbopr->name()->str());
+    }
     if (fbopr->comp_node()) {
         auto cnt = fbopr->comp_node()->size();
         cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
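The dump path above always calls builder.add_name(operator_name), even when keep_op_name is false; it relies on two FlatBuffers behaviors that also keep old, name-less models loadable. A default-constructed Offset is null and the generated table builders skip null offsets, so the name field is simply not written; on the load side, accessors of absent fields return nullptr, which is what the `if (fbopr->name())` guard checks. A standalone illustration of those library semantics (not part of the patch):

    #include <flatbuffers/flatbuffers.h>
    #include <cassert>

    int main() {
        // A default-constructed Offset is "null" (its stored value is 0), so the
        // dumper's operator_name stays null when keep_op_name is false and the
        // table builder writes nothing for that field.
        flatbuffers::Offset<flatbuffers::String> operator_name;
        assert(operator_name.IsNull());
        // A reader of a table without the field then sees name() == nullptr.
        return 0;
    }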


src/serialization/include/megbrain/serialization/load_dump_config.h (+5, -0)

@@ -43,6 +43,9 @@ struct GraphDumpConfig {
     //! whether to keep operator priorities
     bool keep_opr_priority;

+    //! whether to keep operator names
+    bool keep_op_name;
+
     //! extra user data to be passed by dump caller into opr dump
     //! implementations; useful for implementing nested opr dump
     std::shared_ptr<UserDataContainer> user_data;
@@ -57,12 +60,14 @@ struct GraphDumpConfig {

     GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false,
                     bool keep_opr_priority_ = false,
+                    bool keep_op_name_ = true,
                     const std::shared_ptr<UserDataContainer>& user_data_ =
                             std::make_shared<UserDataContainer>(),
                     const TensorValueDumper& tensor_value_dumper_ = {})
             : keep_var_name{keep_var_name_},
               keep_param_name{keep_param_name_},
               keep_opr_priority{keep_opr_priority_},
+              keep_op_name{keep_op_name_},
               user_data{user_data_},
               tensor_value_dumper{tensor_value_dumper_} {}
 };
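Since keep_op_name_ defaults to true, every existing construction of GraphDumpConfig keeps operator names unless the caller opts out explicitly. A minimal usage sketch under that assumption; the namespace, file name, and output_vars variable are placeholders, while the GraphDumper/OutputFile calls mirror the test below:

    using namespace mgb::serialization;   // assumed namespace of GraphDumpConfig

    GraphDumpConfig config;               // keep_op_name defaults to true
    config.keep_op_name = false;          // opt out: strip operator names from the file

    auto dumper = GraphDumper::make(OutputFile::make_fs("model.mgb"),
                                    GraphDumpFormat::FLATBUFFERS);
    auto rst = dumper->dump(output_vars, config);   // output_vars: SymbolVarArray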


src/serialization/test/serializer_oss.cpp (+33, -0)

@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
     load();
 }

+TEST(TestSerializer2, OperatorName) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")});
+    };
+
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto z = rst.output_var_map.at("z");
+        auto op_name = z.node()->owner_opr()->cname();
+        int cmp = strcmp(op_name, "add(x, y)");
+        EXPECT_EQ(cmp, 0);
+    };
+
+    dump();
+    load();
+}

 TEST(TestSerializer2, HasOutputDtype) {
     auto fname = GET_OUTPUT_FILE();
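The new test only exercises the default keep_op_name=true path. A hypothetical complement (not part of this commit) would dump the same graph with the flag disabled and check that the user-supplied name does not survive the round trip, assuming the loader then falls back to the auto-generated operator name; the exact fallback string is deliberately not asserted:

    // Dump side, replacing the dumper->dump(...) call in the lambda above:
    GraphDumpConfig config;
    config.keep_op_name = false;
    dumper->dump({z.rename("z")}, config);

    // Load side, after loader->load():
    auto z2 = rst.output_var_map.at("z");
    EXPECT_STRNE(z2.node()->owner_opr()->cname(), "add(x, y)");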

