From 09de5a07254582b439546c1879d65431a851f382 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 28 Jan 2021 22:09:14 +0800 Subject: [PATCH] feat(mgb/serialization): be able to serialize operator names GitOrigin-RevId: d295abb5da0b70d4675e62e6632dc1c7bd77d58c --- .../python/megengine/core/tensor/megbrain_graph.py | 3 ++ imperative/python/src/graph_rt.cpp | 3 +- src/serialization/impl/schema.fbs | 1 + src/serialization/impl/serializer_oss.cpp | 9 ++++++ .../megbrain/serialization/load_dump_config.h | 5 ++++ src/serialization/test/serializer_oss.cpp | 33 ++++++++++++++++++++++ 6 files changed, 53 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 5004cfb1..73eb4b98 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -305,6 +305,7 @@ def dump_graph( output_vars: Union[Dict[str, VarNode], List[VarNode]], *, keep_var_name: int = 1, + keep_op_name: bool = True, keep_param_name: bool = False, keep_opr_priority: bool = False, strip_info_file=None, @@ -325,6 +326,7 @@ def dump_graph( * 0: none of the names are kept * 1: (default)keep names of output vars * 2: keep names of all (output and internal) vars + :param keep_op_name: whether to keep operator names. :param keep_param_name: whether to keep param names, so param values can be easily manipulated after loading model :param keep_opr_priority: whether to keep priority setting for operators @@ -368,6 +370,7 @@ def dump_graph( dump_content = _imperative_rt.dump_graph( ov, keep_var_name, + keep_op_name, keep_param_name, keep_opr_priority, stat, diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 81fc7970..4a913449 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -294,6 +294,7 @@ void init_graph_rt(py::module m) { m.def("dump_graph", []( const std::vector& dest_vars, int keep_var_name, + bool keep_op_name, bool keep_param_name, bool keep_opr_priority, py::list& stat, @@ -306,7 +307,7 @@ void init_graph_rt(py::module m) { SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, - keep_opr_priority}; + keep_opr_priority, keep_op_name}; auto rst = dumper->dump(symvars, config); for (auto i : rst.inputs) { diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 403d9224..e12ef167 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -124,6 +124,7 @@ table Operator { blobs:[Blob]; /// Operator may want to save more than one OperatorParam additional_params:[OperatorParam]; + name:string; } struct OutputVar { diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index 9935089b..22223f17 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -208,6 +208,11 @@ flatbuffers::Offset GraphDumperOSS::build_single_opr( inputs = m_builder.CreateVector(v); } + Offset operator_name; + if (m_config.keep_op_name) { + operator_name = m_builder.CreateSharedString(opr->name()); + } + Offset>> output_names; if (m_config.keep_var_name >= 2 || (m_config.keep_var_name == 1 && @@ -255,6 +260,7 @@ flatbuffers::Offset GraphDumperOSS::build_single_opr( } builder.add_comp_node(comp_node); builder.add_output_name(output_names); + builder.add_name(operator_name); builder.add_output_dtype(output_dtype); if (param_cnt > 0) { builder.add_param_type(m_cur_opr_param_type[0]); @@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr( if (fbopr->output_dtype()) { config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype())); } + if (fbopr->name()) { + config.name(fbopr->name()->str()); + } if (fbopr->comp_node()) { auto cnt = fbopr->comp_node()->size(); cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt); diff --git a/src/serialization/include/megbrain/serialization/load_dump_config.h b/src/serialization/include/megbrain/serialization/load_dump_config.h index 857c59d4..92fb7e2c 100644 --- a/src/serialization/include/megbrain/serialization/load_dump_config.h +++ b/src/serialization/include/megbrain/serialization/load_dump_config.h @@ -43,6 +43,9 @@ struct GraphDumpConfig { //! whether to keep operator priorities bool keep_opr_priority; + //! whether to keep operator names + bool keep_op_name; + //! extra user data to be passed by dump caller into opr dump //! implementations; useful for implementing nested opr dump std::shared_ptr user_data; @@ -57,12 +60,14 @@ struct GraphDumpConfig { GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false, bool keep_opr_priority_ = false, + bool keep_op_name_ = true, const std::shared_ptr& user_data_ = std::make_shared(), const TensorValueDumper& tensor_value_dumper_ = {}) : keep_var_name{keep_var_name_}, keep_param_name{keep_param_name_}, keep_opr_priority{keep_opr_priority_}, + keep_op_name{keep_op_name_}, user_data{user_data_}, tensor_value_dumper{tensor_value_dumper_} {} }; diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp index cc6a7dff..6df3bb4a 100644 --- a/src/serialization/test/serializer_oss.cpp +++ b/src/serialization/test/serializer_oss.cpp @@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) { load(); } +TEST(TestSerializer2, OperatorName) { + auto fname = GET_OUTPUT_FILE(); + TensorShape shape{2, 3}; + + auto dump = [&]() { + auto cn = CompNode::load("xpu0"); + auto host_x = std::make_shared(cn, shape), + host_y = std::make_shared(cn, shape); + auto graph = ComputingGraph::make(); + auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}), + y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"}); + using Mode = opr::Elemwise::Mode; + auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"}); + + auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + auto rst = dumper->dump({z.rename("z")}); + }; + + auto load = [&]() { + HostTensorGenerator<> gen; + auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + auto rst = loader->load(); + auto z = rst.output_var_map.at("z"); + auto op_name = z.node()->owner_opr()->cname(); + int cmp = strcmp(op_name, "add(x, y)"); + EXPECT_EQ(cmp, 0); + }; + + dump(); + load(); +} TEST(TestSerializer2, HasOutputDtype) { auto fname = GET_OUTPUT_FILE();