GitOrigin-RevId: d295abb5da
@@ -305,6 +305,7 @@ def dump_graph(
     output_vars: Union[Dict[str, VarNode], List[VarNode]],
     *,
     keep_var_name: int = 1,
+    keep_op_name: bool = True,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
@@ -325,6 +326,7 @@ def dump_graph(
         * 0: none of the names are kept
         * 1: (default)keep names of output vars
         * 2: keep names of all (output and internal) vars
+    :param keep_op_name: whether to keep operator names.
     :param keep_param_name: whether to keep param names, so param values can be
         easily manipulated after loading model
     :param keep_opr_priority: whether to keep priority setting for operators
@@ -368,6 +370,7 @@ def dump_graph(
     dump_content = _imperative_rt.dump_graph(
         ov,
         keep_var_name,
+        keep_op_name,
         keep_param_name,
         keep_opr_priority,
         stat,
@@ -294,6 +294,7 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph", [](
         const std::vector<VarNode*>& dest_vars,
         int keep_var_name,
+        bool keep_op_name,
         bool keep_param_name,
         bool keep_opr_priority,
         py::list& stat,
@@ -306,7 +307,7 @@ void init_graph_rt(py::module m) {
         SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
-                                            keep_opr_priority};
+                                            keep_opr_priority, keep_op_name};
         auto rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
@@ -124,6 +124,7 @@ table Operator {
     blobs:[Blob];
     /// Operator may want to save more than one OperatorParam
     additional_params:[OperatorParam];
+    name:string;
 }
 
 struct OutputVar {
@@ -208,6 +208,11 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
         inputs = m_builder.CreateVector(v);
     }
 
+    Offset<String> operator_name;
+    if (m_config.keep_op_name) {
+        operator_name = m_builder.CreateSharedString(opr->name());
+    }
+
     Offset<Vector<Offset<String>>> output_names;
     if (m_config.keep_var_name >= 2 ||
         (m_config.keep_var_name == 1 &&
@@ -255,6 +260,7 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr( | |||||
} | } | ||||
builder.add_comp_node(comp_node); | builder.add_comp_node(comp_node); | ||||
builder.add_output_name(output_names); | builder.add_output_name(output_names); | ||||
builder.add_name(operator_name); | |||||
builder.add_output_dtype(output_dtype); | builder.add_output_dtype(output_dtype); | ||||
if (param_cnt > 0) { | if (param_cnt > 0) { | ||||
builder.add_param_type(m_cur_opr_param_type[0]); | builder.add_param_type(m_cur_opr_param_type[0]); | ||||
@@ -698,6 +704,9 @@ void GraphLoaderOSS::OprLoadContextImpl::load_single_opr(
     if (fbopr->output_dtype()) {
         config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
     }
+    if (fbopr->name()) {
+        config.name(fbopr->name()->str());
+    }
     if (fbopr->comp_node()) {
         auto cnt = fbopr->comp_node()->size();
         cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
@@ -43,6 +43,9 @@ struct GraphDumpConfig {
     //! whether to keep operator priorities
     bool keep_opr_priority;
 
+    //! whether to keep operator names
+    bool keep_op_name;
+
     //! extra user data to be passed by dump caller into opr dump
     //! implementations; useful for implementing nested opr dump
     std::shared_ptr<UserDataContainer> user_data;
@@ -57,12 +60,14 @@ struct GraphDumpConfig {
     GraphDumpConfig(int keep_var_name_ = 1, bool keep_param_name_ = false,
                     bool keep_opr_priority_ = false,
+                    bool keep_op_name_ = true,
                     const std::shared_ptr<UserDataContainer>& user_data_ =
                             std::make_shared<UserDataContainer>(),
                     const TensorValueDumper& tensor_value_dumper_ = {})
             : keep_var_name{keep_var_name_},
               keep_param_name{keep_param_name_},
               keep_opr_priority{keep_opr_priority_},
+              keep_op_name{keep_op_name_},
               user_data{user_data_},
               tensor_value_dumper{tensor_value_dumper_} {}
 };
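For context, here is a minimal caller-side sketch of how the new flag would be driven from C++. It only assumes the GraphDumper/OutputFile calls exercised by the test below; the helper name dump_without_op_names, the output path handling, and the include are illustrative, not part of this patch.

// Sketch: dump a graph while stripping operator names via the new flag.
#include <string>
#include "megbrain/serialization/serializer.h"

namespace ser = mgb::serialization;

void dump_without_op_names(mgb::SymbolVar z, const std::string& fname) {
    auto dumper = ser::GraphDumper::make(
            ser::OutputFile::make_fs(fname.c_str()),
            ser::GraphDumpFormat::FLATBUFFERS);
    ser::GraphDumper::DumpConfig config;  // keep_op_name defaults to true
    config.keep_op_name = false;          // drop operator names from the dump
    dumper->dump({z}, config);
}

With keep_op_name left at its default, loaded operators carry the names recorded at dump time, which is what the new OperatorName test checks.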
@@ -711,6 +711,39 @@ TEST(TestSerializer2, ParamerizedDType) {
     load();
 }
 
+TEST(TestSerializer2, OperatorName) {
+    auto fname = GET_OUTPUT_FILE();
+    TensorShape shape{2, 3};
+    auto dump = [&]() {
+        auto cn = CompNode::load("xpu0");
+        auto host_x = std::make_shared<HostTensorND>(cn, shape),
+             host_y = std::make_shared<HostTensorND>(cn, shape);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x, {"x"}),
+             y = opr::Host2DeviceCopy::make(*graph, host_y, {"y"});
+        using Mode = opr::Elemwise::Mode;
+        auto z = opr::Elemwise::make({x, y}, Mode::ADD, {"add(x, y)"});
+        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = dumper->dump({z.rename("z")});
+    };
+    auto load = [&]() {
+        HostTensorGenerator<> gen;
+        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
+                                        GraphDumpFormat::FLATBUFFERS);
+        auto rst = loader->load();
+        auto z = rst.output_var_map.at("z");
+        auto op_name = z.node()->owner_opr()->cname();
+        int cmp = strcmp(op_name, "add(x, y)");
+        EXPECT_EQ(cmp, 0);
+    };
+    dump();
+    load();
+}
+
 TEST(TestSerializer2, HasOutputDtype) {
     auto fname = GET_OUTPUT_FILE();