diff --git a/inc/external/parser/caffe_parser.h b/inc/external/parser/caffe_parser.h new file mode 100644 index 0000000..2a687d0 --- /dev/null +++ b/inc/external/parser/caffe_parser.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ +#define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ + +#include +#include +#include + +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/graph.h" + +namespace ge { +graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ diff --git a/inc/external/parser/tensorflow_parser.h b/inc/external/parser/tensorflow_parser.h new file mode 100644 index 0000000..b7c1c8c --- /dev/null +++ b/inc/external/parser/tensorflow_parser.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ +#define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ + +#include +#include +#include +#include + +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/graph.h" + +namespace ge { +graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ \ No newline at end of file diff --git a/parser/CMakeLists.txt b/parser/CMakeLists.txt new file mode 100644 index 0000000..069d109 --- /dev/null +++ b/parser/CMakeLists.txt @@ -0,0 +1,150 @@ +set(PROTO_LIST + "${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto" +) + +set(SRC_LIST + "tensorflow/tensorflow_arg_parser.cc" + "tensorflow/tensorflow_auto_mapping_parser_adapter.cc" + "tensorflow/tensorflow_constant_parser.cc" + "tensorflow/tensorflow_data_parser.cc" + "tensorflow/tensorflow_enter_parser.cc" + "tensorflow/tensorflow_fill_parser.cc" + "tensorflow/tensorflow_frameworkop_parser.cc" + "tensorflow/tensorflow_fusionop_util.cc" + "tensorflow/tensorflow_identity_parser.cc" + "tensorflow/tensorflow_merge_parser.cc" + "tensorflow/tensorflow_no_op_parser.cc" + "tensorflow/tensorflow_parser.cc" + "tensorflow/tensorflow_ref_switch_parser.cc" + "tensorflow/tensorflow_reshape_parser.cc" + "tensorflow/tensorflow_shape_n_parser.cc" + "tensorflow/tensorflow_squeeze_parser.cc" + "tensorflow/tensorflow_var_is_initialized_op_parser.cc" + "tensorflow/tensorflow_variable_v2_parser.cc" + "caffe/caffe_parser.cc" + "caffe/caffe_data_parser.cc" + "caffe/caffe_reshape_parser.cc" + "caffe/caffe_custom_parser_adapter.cc" + "caffe/caffe_op_parser.cc" + "tensorflow/scope/scope_pass_manager.cc" + "tensorflow/graph_functiondef.cc" + "tensorflow/graph_optimizer.cc" + "tensorflow/iterator_fusion_pass.cc" + "common/op_def/arg_op.cc" + "common/op_def/constant_op.cc" + "common/op_def/fill_op.cc" + "common/op_def/frameworkop_op.cc" + "common/op_def/no_op_op.cc" + "common/op_def/ref_switch_op.cc" + "common/op_def/shape_n_op.cc" + "common/op_def/var_is_initialized_op_op.cc" + "common/op_def/variable_op.cc" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +############ libfmk_parser.so ############ +add_library(fmk_parser SHARED ${SRC_LIST} ${PROTO_SRCS}) + +target_compile_options(fmk_parser PRIVATE + -Werror +) + +target_compile_definitions(fmk_parser PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 +) + +target_include_directories(fmk_parser PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${TOP_DIR}/framework/domi + ${TOP_DIR}/framework/domi/common + ${TOP_DIR}/framework/domi/parser + ${TOP_DIR}/inc + ${TOP_DIR}/inc/external + ${TOP_DIR}/inc/external/parser + ${TOP_DIR}/inc/external/graph + ${TOP_DIR}/inc/framework + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge +) + +target_link_libraries(fmk_parser + $ + -Wl,--no-as-needed + protobuf + error_manager + parser_common + graph + register + _caffe_parser + c_sec + slog + mmpa + -Wl,--as-needed + json + -lrt +) + +################################################################## +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc + COMMAND echo "Generating stub files." + && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} + && mv tensorflow_parser.cc stub_tensorflow_parser.cc + && mv caffe_parser.cc stub_caffe_parser.cc + && echo "Generating stub files end." 
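+    # gen_stubapi.py generates tensorflow_parser.cc and caffe_parser.cc from the public headers under
+    # ${TOP_DIR}/inc/external; the mv steps rename them to stub_*.cc, which are compiled into the
+    # stub library (stub/libfmk_parser.so) defined below.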
+ WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} +) + +################################################################## + +############ stub/libfmk_parser.so ############ +add_library(fmk_parser_stub SHARED + ${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc + ${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc +) + +target_compile_options(fmk_parser_stub PRIVATE + -O2 +) + +target_compile_definitions(fmk_parser_stub PRIVATE + $<$:FMK_SUPPORT_DUMP> + PROTOBUF_INLINE_NOT_IN_HEADERS=0 + REUSE_MEMORY=1 + FMK_HOST_INFER +) + +target_include_directories(fmk_parser_stub PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${TOP_DIR}/inc + ${TOP_DIR}/inc/external + ${TOP_DIR}/inc/external/parser + ${TOP_DIR}/inc/external/graph + ${TOP_DIR}/inc/framework + ${CMAKE_BINARY_DIR} + ${CMAKE_CURRENT_BINARY_DIR} +) + +target_link_libraries(fmk_parser_stub PRIVATE + $ +) + +set_target_properties(fmk_parser_stub PROPERTIES + OUTPUT_NAME fmk_parser + LIBRARY_OUTPUT_DIRECTORY stub +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS fmk_parser OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) + +install(TARGETS fmk_parser_stub OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub +) diff --git a/parser/caffe/caffe_custom_parser_adapter.cc b/parser/caffe/caffe_custom_parser_adapter.cc new file mode 100644 index 0000000..74f7b04 --- /dev/null +++ b/parser/caffe/caffe_custom_parser_adapter.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/caffe/caffe_custom_parser_adapter.h" +#include +#include +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_types.h" +#include "graph/utils/graph_utils.h" +#include "parser/common/op_parser_factory.h" +#include "register/op_registry.h" + +using domi::ParseParamByOpFunc; +using domi::ParseParamFunc; +using std::vector; + +namespace ge { +namespace { +const char *const kConvolution = "Convolution"; +const char *const kInnerProduct = "InnerProduct"; +const int64_t kDimDedaultValue = 1; +const int kBlobIndexOne = 1; +} // namespace + +Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + GE_CHECK_NOTNULL(op_src); + const LayerParameter *layer = reinterpret_cast(op_src); + GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); + GE_CHECK_NOTNULL(op_dest); + + ParseParamFunc customOpParser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest->GetType(), layer->type()); + GE_CHECK_NOTNULL(customOpParser); + + op_dest->SetName(layer->name()); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_BOOL_RET_STATUS(customOpParser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); + return SUCCESS; +} + +Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { + GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(), + op_src.GetOpType().c_str()); + GE_CHECK_NOTNULL(op_dest); + + ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); + GE_CHECK_NOTNULL(custom_op_parser); + + op_dest->SetName(op_src.GetName()); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + + GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); + return SUCCESS; +} + +Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op); + const LayerParameter *layer = reinterpret_cast(op_src); + + GE_CHK_BOOL_RET_STATUS(nullptr != layer, FAILED, "Dynamic cast op_src to LayerParameter failed"); + GELOGI("layer: %s blobs_size: %d bottom_size: %d", layer->name().c_str(), layer->blobs_size(), layer->bottom_size()); + if (layer->blobs_size() == 0) { + return SUCCESS; + } + + bool bias_en = false; + int start_pos = layer->bottom_size(); + for (int i = 0; i < layer->blobs_size(); ++i) { + ge::GeTensorPtr weight = ge::MakeShared(); + GE_CHECK_NOTNULL(weight); + GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i, + layer->name().c_str()); + GE_IF_BOOL_EXEC(layer->type() == kConvolution && i == kBlobIndexOne, + const ConvolutionParameter &conv_params_src = layer->convolution_param(); + bias_en = conv_params_src.bias_term();); + GE_IF_BOOL_EXEC(layer->type() == kInnerProduct && i == kBlobIndexOne, + const InnerProductParameter &fc_params_src = layer->inner_product_param(); + bias_en = fc_params_src.bias_term();); + auto bias_shape = weight->MutableTensorDesc().GetShape(); + // The num 0, 1, 2, 3 represet the dim index. 
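+    // Bias blobs of Convolution/InnerProduct arrive as 4-D (1, 1, 1, N) and are squeezed to 1-D {N} below;
+    // the InnerProduct weight blob (blob index 0) arrives as (1, 1, K, N) and is squeezed to 2-D {K, N}.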
+ bool matched = bias_en && bias_shape.GetDimNum() == static_cast(ge::parser::DIM_DEFAULT_SIZE) && + bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1; + if (matched) { + weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)})); + } + matched = layer->type() == kInnerProduct && i == 0 && + bias_shape.GetDimNum() == static_cast(ge::parser::DIM_DEFAULT_SIZE) && + bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1; + if (matched) { + weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)})); + } + + // construct const node + auto const_opdesc = ge::OpDescUtils::CreateConstOp(weight); // use org weight before SetWeights Overwrite + GE_CHECK_NOTNULL(const_opdesc); + auto owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + + // add edge from const to current node + auto const_node = owner_graph->AddNodeFront(const_opdesc); + GE_CHECK_NOTNULL(const_node); + auto index = start_pos + i; + auto valid_input_name = op->GetValidInputNameByIndex(static_cast(index)); + if (valid_input_name.empty()) { + if (node->AddLinkFrom(static_cast(index), const_node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %d", const_node->GetName().c_str(), + node->GetName().c_str(), index); + } + } else { + if (node->AddLinkFrom(valid_input_name, const_node) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %s", const_node->GetName().c_str(), + node->GetName().c_str(), valid_input_name.c_str()); + } + } + + std::vector original_nodes; + ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); + } + GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, "tvm_origin_input_num", layer->bottom_size())), + GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return + + return SUCCESS; +} +REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(CAFFE, CaffeCustomParserAdapter); +} // namespace ge diff --git a/parser/caffe/caffe_custom_parser_adapter.h b/parser/caffe/caffe_custom_parser_adapter.h new file mode 100644 index 0000000..da09087 --- /dev/null +++ b/parser/caffe/caffe_custom_parser_adapter.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ +#define PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ + +#include "parser/caffe/caffe_op_parser.h" + +namespace ge { +class CaffeCustomParserAdapter : public CaffeOpParser { + public: + /** + * @ingroup domi_omg + * @brief parse params of the operation + * @param [in] op_src params to be parsed + * @param [out] op_dest params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + /** + * @ingroup domi_omg + * @brief parse params of the operation + * @param [in] op_src params to be parsed + * @param [out] op_dest params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest); + + /** + * @ingroup domi_omg + * @brief parse weight of the operation + * @param [in] op_src params to be parsed + * @param [out] node params after parsing + * @return SUCCESS parse successfullyparse failed + * @return FAILED + * @author + */ + Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; +}; +} // namespace ge + +#endif // PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_ diff --git a/parser/caffe/caffe_data_parser.cc b/parser/caffe/caffe_data_parser.cc new file mode 100644 index 0000000..a155e7a --- /dev/null +++ b/parser/caffe/caffe_data_parser.cc @@ -0,0 +1,160 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/caffe/caffe_data_parser.h" +#include +#include +#include "common/debug/log.h" +#include "framework/omg/parser/parser_types.h" +#include "common/util.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "parser/common/op_parser_factory.h" + +using namespace ge::parser; + +namespace ge { +Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector &input_dims, + ge::OpDescPtr &op) { + GE_CHECK_NOTNULL(op); + GELOGI("The input dim size is %zu in layer %s.", input_dims.size(), name.c_str()); + + // Caffe default data type is float32 + GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, DATA_ATTR_NAME_DATA_TYPE, ge::DT_FLOAT)), + GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return + + // Initialize input and output description of OP according to input_dims information + GE_RETURN_WITH_LOG_IF_ERROR(ParseShape(input_dims, op), "data layer %s ParseShape failed", name.c_str()); + + return SUCCESS; +} + +Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op); + const domi::caffe::LayerParameter *layer = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(layer); + GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); + + if (layer->type() == ge::parser::INPUT_TYPE) { + GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", + layer->name().c_str(), layer->type().c_str()); + } else if(layer->type() == ge::parser::DUMMY_DATA) { + GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", + layer->name().c_str(), layer->type().c_str()); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E11030"); + GELOGE(PARAM_INVALID, "Caffe prototxt has no optype [Input]"); + return FAILED; + } + return SUCCESS; +} + +Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { + if (layer->has_input_param()) { + const domi::caffe::InputParameter &input_param = layer->input_param(); + if (input_param.shape_size() == 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); + GELOGE(PARAM_INVALID, + "input_param shape size is zero, caffe layer name [%s], layer type [%s].", + layer->name().c_str(), layer->type().c_str()); + return FAILED; + } + for (int i = 0; i < input_param.shape_size(); i++) { + const domi::caffe::BlobShape &blob_shape = input_param.shape(i); + vector shape; + unordered_map> &shape_map = GetParserContext().input_dims; + std::vector model_dims; + for (auto &blob_shape_dim_temp : blob_shape.dim()) { + model_dims.push_back(blob_shape_dim_temp); + } + string name = layer->name(); + GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); + GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", + name.c_str()); + } + } else { + // Get from external input + const ge::ParserContext &ctx = GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + string name = layer->name(); + auto search = input_dims.find(name); + if (search == input_dims.end()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11028", {"layername", "layertype"}, {layer->name(), layer->type()}); + 
GELOGE(PARAM_INVALID, + "Caffe prototxt has no input_param or user should set --input_shape in atc parameter, " + "caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str()); + return FAILED; + } + std::vector dims = search->second; + GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.", + name.c_str()); + } + return SUCCESS; +} + +Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { + if (layer->has_dummy_data_param()) { + const domi::caffe::DummyDataParameter &dummy_data_param = layer->dummy_data_param(); + if (dummy_data_param.shape_size() == 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); + GELOGE(PARAM_INVALID, + "input_param shape size is zero, caffe layer name [%s], layer type [%s].", + layer->name().c_str(), layer->type().c_str()); + return FAILED; + } + for (int i = 0; i < dummy_data_param.shape_size(); i++) { + const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i); + + vector shape; + unordered_map> &shape_map = GetParserContext().input_dims; + std::vector model_dims; + for (auto &blob_shape_dim_temp : blob_shape.dim()) { + model_dims.push_back(blob_shape_dim_temp); + } + + string name = layer->name(); + GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); + GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", + name.c_str()); + } + } else { + // Get from external input + const ge::ParserContext &ctx = GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + string name = layer->name(); + auto search = input_dims.find(name); + if (search == input_dims.end()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11028", {"layername", "layertype"}, {layer->name(), layer->type()}); + GELOGE(PARAM_INVALID, + "Caffe prototxt has no input_param or user should set --input_shape in atc parameter, " + "caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str()); + return FAILED; + } + std::vector dims = search->second; + GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.", + name.c_str()); + } + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(CAFFE, DATA, CaffeDataParser); +} // namespace ge diff --git a/parser/caffe/caffe_data_parser.h b/parser/caffe/caffe_data_parser.h new file mode 100644 index 0000000..ee5f7ad --- /dev/null +++ b/parser/caffe/caffe_data_parser.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_CAFFE_CAFFE_DATA_PARSER_H_ +#define PARSER_CAFFE_CAFFE_DATA_PARSER_H_ + +#include +#include +#include "parser/caffe/caffe_op_parser.h" +#include "parser/common/data_op_parser.h" + +namespace ge { +class CaffeDataParser : public CaffeOpParser, public DataOpParser { + public: + /** + * @ingroup domi_omg + * @brief parse params of the operation + * @param [in] op_src params to be parsed + * @param [out] graph params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; + + private: + /** + * @ingroup domi_omg + * @brief Get the output dimension according to the input dimension + * @param [in] name the name of the input layer + * @param [in] input_dims the dimension of the input layer + * @param [out] op_def op after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + */ + Status GetOutputDesc(const std::string &name, int dim_size, + const std::vector &input_dims, ge::OpDescPtr &op); + + // caffe data layer type could be type of `Input` or `DummyData` + Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); + Status ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); +}; +} // namespace ge + +#endif // PARSER_CAFFE_CAFFE_DATA_PARSER_H_ diff --git a/parser/caffe/caffe_op_parser.cc b/parser/caffe/caffe_op_parser.cc new file mode 100644 index 0000000..63d0b6e --- /dev/null +++ b/parser/caffe/caffe_op_parser.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/caffe/caffe_op_parser.h" +#include +#include "parser/common/op_parser_factory.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/omg/parser/parser_types.h" + +using namespace ge::parser; + +using domi::CAFFE; + +namespace ge { +Status CaffeOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { return SUCCESS; } + +Status CaffeOpParser::ParseWeights(const Message *op_src, ge::NodePtr &node) { return SUCCESS; } + +Status CaffeOpParser::AddConstInput(ge::NodePtr &node) { return SUCCESS; } + +void CaffeOpParser::ConvertShape(const BlobProto &proto, std::vector &shape) { + shape.clear(); + + if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) { + // Compatible with old formats, shape description: (num, channels, height, width) + shape.push_back(proto.num()); + shape.push_back(proto.channels()); + shape.push_back(proto.height()); + shape.push_back(proto.width()); + } else { + // The shape of the new format is described with "repeated Int64 dim" + for (int i = 0; i < proto.shape().dim_size(); ++i) { + shape.push_back(proto.shape().dim(i)); + } + } +} + +Status CaffeOpParser::ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight) { + GE_CHECK_NOTNULL(weight); + std::vector shape_vec; + ConvertShape(proto, shape_vec); + ge::GeShape shape(shape_vec); + // Calculate the number of data in weight + int count = 1; + for (size_t i = 0; i < shape.GetDimNum(); ++i) { + int dim = shape.GetDim(i); + if (dim <= 0) { + GELOGE(FAILED, "Convert weight fail, Blob size invalid"); + return FAILED; + } + + if (dim >= INT64_MAX / count) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(dim) + "*" + std::to_string(count), + "it exceeds INT64_MAX[" + std::to_string(INT64_MAX) + "]"}); + GELOGE(FAILED, "Convert weight fail, Blob size exceeds INT64_MAX, dim:%d, count:%d", dim, count); + return FAILED; + } + + count *= dim; + } + return ParseWeightType(proto, shape, count, lay_name, weight); +} + +Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, int size, + const string &lay_name, ge::GeTensorPtr &weight) { + // Extract weight data and store it in weightdef by float type + GE_CHECK_NOTNULL(weight); + ge::DataType dtype = ge::DT_FLOAT; + if (proto.double_data_size() > 0) { + // Convert by double type + if (size != proto.double_data_size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(proto.double_data_size()), + "it does not match shape size[" + std::to_string(size) + "]"}); + GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob size:%d", size, + proto.double_data_size()); + return FAILED; + } + std::unique_ptr buf(new (std::nothrow) float[size]()); + GE_CHECK_NOTNULL(buf); + for (int i = 0; i < size; ++i) { + buf[i] = proto.double_data(i); + } + GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, + GELOGW("SetData failed for GeTensor.");); // no need to return + } else if (proto.int8_data().length() > 0) { + if (size != static_cast(proto.int8_data().length())) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(proto.int8_data().length()), + "it does not match shape size[" + std::to_string(size) + "]"}); + GELOGE(FAILED, 
"Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%ld", size, + proto.int8_data().length()); + return FAILED; + } + const char *data_ptr = proto.int8_data().data(); + GE_CHECK_NOTNULL(data_ptr); + GE_IF_BOOL_EXEC( + weight->SetData(reinterpret_cast(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, + GELOGW("SetData failed for GeTensor.");); // no need to return + dtype = ge::DT_INT8; + } else if (proto.int32_data_size() > 0) { + if (size != proto.int32_data_size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(proto.int32_data_size()), + "it does not match shape size[" + std::to_string(size) + "]"}); + GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size, + proto.int32_data_size()); + return FAILED; + } + std::unique_ptr int32_weight_buf(new (std::nothrow) int32_t[size]()); + GE_CHECK_NOTNULL(int32_weight_buf); + for (int i = 0; i < size; ++i) { + int32_weight_buf[i] = proto.int32_data(i); + } + GE_IF_BOOL_EXEC( + weight->SetData(reinterpret_cast(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, + GELOGW("SetData failed for GeTensor.");); // no need to return + dtype = ge::DT_INT32; + } else if (proto.uint64_data_size() > 0) { + if (size != proto.uint64_data_size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(proto.uint64_data_size()), + "it does not match shape size[" + std::to_string(size) + "]"}); + GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size, + proto.uint64_data_size()); + return FAILED; + } + std::unique_ptr uint64_weight_buf(new (std::nothrow) uint64_t[size]()); + GE_CHECK_NOTNULL(uint64_weight_buf); + for (int i = 0; i < size; ++i) { + uint64_weight_buf[i] = proto.uint64_data(i); + } + GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast(uint64_weight_buf.get()), size * sizeof(uint64_t)) != + ge::GRAPH_SUCCESS, + GELOGW("SetData failed for GeTensor.");); // no need to return + dtype = ge::DT_UINT64; + } else { + // Convert by float type + if (size != proto.data_size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11033", {"opname", "blobsize", "reason"}, + {lay_name, std::to_string(proto.data_size()), + "it does not match shape size[" + std::to_string(size) + "]"}); + GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob.data_size:%d", size, + proto.data_size()); + return FAILED; + } + const float *data_ptr = proto.data().data(); + GE_CHECK_NOTNULL(data_ptr); + GE_IF_BOOL_EXEC( + weight->SetData(reinterpret_cast(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, + GELOGW("SetData failed for GeTensor.");); // no need to return + } + ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); + weight_desc.Update(shape, ge::FORMAT_NCHW, dtype); + weight->SetTensorDesc(weight_desc); + return SUCCESS; +} + +// Dropout's corresponding op_parser is registered as caffeopparser, optimized in optimization stage. 
+REGISTER_OP_PARSER_CREATOR(CAFFE, DROPOUT, CaffeOpParser); + +// A new operator added by framework in OM model is used to +// collect and arrange all outputs in the order of the original model's output +// Net output operator does not need special processing in the parse stage, +// and directly registers in the op_parser file +REGISTER_OP_PARSER_CREATOR(CAFFE, NETOUTPUT, CaffeOpParser); +} // namespace ge diff --git a/parser/caffe/caffe_op_parser.h b/parser/caffe/caffe_op_parser.h new file mode 100644 index 0000000..9fa5921 --- /dev/null +++ b/parser/caffe/caffe_op_parser.h @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_CAFFE_CAFFE_OP_PARSER_H_ +#define PARSER_CAFFE_CAFFE_OP_PARSER_H_ + +#include +#include "graph/debug/ge_attr_define.h" +#include "common/util.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/types.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/parser/op_parser.h" +#include "proto/caffe/caffe.pb.h" + +using domi::caffe::ArgMaxParameter; +using domi::caffe::BatchNormParameter; +using domi::caffe::BlobProto; +using domi::caffe::BlobShape; +using domi::caffe::ConcatParameter; +using domi::caffe::ConvolutionParameter; +using domi::caffe::DetectionOutputParameter; +using domi::caffe::EltwiseParameter; +using domi::caffe::FillerParameter; +using domi::caffe::InnerProductParameter; +using domi::caffe::LayerParameter; +using domi::caffe::PoolingParameter; +using domi::caffe::PReLUParameter; +using domi::caffe::ReshapeParameter; +using domi::caffe::ROIAlignParameter; +using domi::caffe::TanHParameter; +using domi::caffe::UpsampleParameter; + +namespace ge { +/** + * @ingroup ge_omg + * @brief Used to parse Caffe operator information + */ +class CaffeOpParser : public OpParser { + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + Status ParseParams(const Message *op_src, ge::Operator &op_dest) override { + return domi::SUCCESS; + } + + /** + * @ingroup ge_omg + * @brief parse weight information of the operation + * @param [in] op_src Weight data to be parsed + * @param [out] graph Weight data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; + + /** + * @ingroup ge_omg + * @brief add const input node + * @param [in] node to add const input + * @param [out] node after add const input + * @return SUCCESS add const input successfully + * @return FAILED add const input failed + * @author + */ + virtual Status AddConstInput(ge::NodePtr &node); + + protected: + /** + * @ingroup ge_omg + * @brief Convert blob proto to weight definition + * @param [in] proto Weight data to be parsed + * @param [out] weight Weight data after parsing + * @return SUCCESS parse 
successfully + * @return FAILED parse failed + */ + static Status ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight); + + /** + * @ingroup ge_omg + * @brief Convert blob proto to shape definition + * @param [in] proto Shape information before conversion + * @param [out] shape Save converted shape information + */ + static void ConvertShape(const BlobProto &proto, std::vector &shape); + + private: + /** + * @ingroup ge_omg + * @brief Convert blob proto to weight definition + * @param [in] proto Weight data to be parsed + * @param [out] weight Weight data after parsing + * @return SUCCESS parse weight type successfully + * @return FAILED parse failed + */ + static Status ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, + int size, const string &lay_name, ge::GeTensorPtr &weight); +}; +} // namespace ge + +#endif // PARSER_CAFFE_CAFFE_OP_PARSER_H_ diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc new file mode 100644 index 0000000..75f4107 --- /dev/null +++ b/parser/caffe/caffe_parser.cc @@ -0,0 +1,2486 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/caffe/caffe_parser.h" + +#include +#include +#include +#include +#include "parser/common/convert/pb2json.h" +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/op_map.h" +#include "common/util/error_manager/error_manager.h" +#include "common/ge_types.h" +#include "common/string_util.h" +#include "external/graph/operator_factory.h" +#include "external/parser/caffe_parser.h" +#include "external/ge/ge_api_types.h" +#include "framework/common/debug/ge_log.h" +#include "graph/optimize/common/params.h" +#include "graph/utils/graph_utils.h" +#include +#include +#include +#include +#include +#include +#include "omg/parser/op_parser.h" +#include "omg/parser/parser_factory.h" +#include "omg/parser/parser_inner_ctx.h" +#include "parser/caffe/caffe_custom_parser_adapter.h" +#include "parser/caffe/caffe_op_parser.h" +#include "parser/common/op_parser_factory.h" +#include "parser/common/pre_checker.h" +#include "framework/omg/parser/parser_types.h" +#include "parser/common/model_saver.h" +#include "parser/common/acl_graph_parser_util.h" +#include "parser/common/proto_file_parser.h" +#include "register/op_registry.h" + +using domi::caffe::LayerParameter; +using domi::caffe::NetParameter; +using domi::ParseParamByOpFunc; +using ge::caffe_op_map; +using ge::CaffeOpParser; +using ge::parser::ModelSaver; +using ge::OpParser; +using ge::OpParserFactory; +using ge::Pb2Json; +using ge::PreChecker; +using std::ifstream; + +#define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ + do { \ + if (val == nullptr) { \ + GELOGE(ge::PARAM_INVALID, errormsg); \ + ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {errormsg}); \ + return ge::PARAM_INVALID; \ + } \ + } while (0) + +namespace ge { +graphStatus aclgrphParseCaffe(const char *model_file, const char 
*weights_file, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + GetParserContext().type = domi::CAFFE; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + // Create an empty computegraph + ge::ComputeGraphPtr compute_graph = ge::MakeShared("tmpGraph"); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); + GE_CHECK_NOTNULL(model_parser); + + // parse caffe model_file and weights_file to GE graph + ge::graphStatus ret = model_parser->Parse(model_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + + auto weights_parser = domi::WeightsParserFactory::Instance()->CreateWeightsParser(domi::CAFFE); + ret = weights_parser->Parse(weights_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Weights parse failed. graph: %s", graph.GetName().c_str()); + return ret; + } + GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); + return ge::SUCCESS; +} +} // namespace ge + + +namespace ge { +namespace { +const int32_t kAnchorIndexOne = 1; +const int32_t kAnchorIndexTwo = 2; +const int32_t kAnchorIndexThree = 3; +const int32_t kNumOne = 1; +const size_t kTensorNum = 2; +const int kMaxParseDepth = 5; +const int32_t kMinLineWorldSize = 3; +const int32_t kMaxIdentifier = 536870911; // 2^29 - 1 +const int32_t kBase = 10; +const char *const kPython = "Python"; +const char *const kProposalLayer = "ProposalLayer"; +const char *const kDetectionOutput = "DetectionOutput"; +const char *const kProjectRoot = "project_root"; +const char *const kBeginningMessageType = "domi.caffe.NetParameter"; +const char *const kLayerMessageType = "domi.caffe.LayerParameter"; +const char *const kLayerName = "layer"; +const char *const kLayersName = "layers"; +const char *const kFieldName = "name"; +const char *const kFieldType = "type"; +const char *const kFieldBottom = "bottom"; +const char *const kFieldTop = "top"; +const char *const kFieldBlobs = "blobs"; +const char *const kFieldShape = "shape"; +const char *const kFieldConvParam = "convolution_param"; +const char *const kFieldInnerPro = "inner_product_param"; +const char *const kFieldDim = "dim"; +const char *const kFieldBiasTerm = "bias_term"; +const char *const kDevNull = "/dev/null"; +const char* const kMessage = "message"; +const char* const kLayerParameter = "LayerParameter"; +const char* const kCloseBrace = "}"; +const std::string kOptional = "optional"; +const std::string kRepeated = "repeated"; +const std::string kRequired = "required"; +const std::string kCustom = "custom"; +const std::string kBuiltin = "built-in"; +std::vector kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, + ge::parser::NETOUTPUT}; +const std::set kCustomProtoLayerCommonField = {"name", "type"}; +const std::set kCaffeProtoLayerCommonField = {"name", "type", "bottom", "top", "phase", + "loss_weight", "param", "blobs", "propagate_down", + "include", "exclude"}; +Status CheckPathValid(const char *model_path, const string &custom_proto, string &custom_proto_path, + string &custom_proto_name) { + string path_model = ge::parser::RealPath(model_path); + if (path_model.empty()) { + 
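+    // RealPath is expected to return an empty string when the model path cannot be resolved,
+    // in which case the failure is reported with the underlying errno text below.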
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {model_path, strerror(errno)}); + GELOGE(FAILED, "Invalid path of model: %s", model_path); + return FAILED; + } + + custom_proto_name = kProjectRoot; + auto pos = custom_proto.find_last_of("/\\"); + if (pos == string::npos) { + custom_proto_path = "./"; + custom_proto_name += '/' + custom_proto; + } else { + custom_proto_path = custom_proto.substr(0, pos); + custom_proto_name += '/' + custom_proto.substr(pos + 1); + } + GELOGI("Check validity of model file: %s and proto file: %s success.", model_path, custom_proto.c_str()); + + return SUCCESS; +} +} // namespace + /* + MultiLabelLMDB?The negligible layer for weight analysis in license plate recognition network of Safe city. + Python: Currently, python custom layer only supports proposal, + and there is no corresponding data in the proposal weight file, so Python layer is ignored. + */ +const set CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxWithLoss", "Accuracy", "Data", + "Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; + +Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) { + if (proto_message.input_size() > 0) { + GELOGI("This net exsit input."); + + if (proto_message.input_dim_size() > 0) { + if (proto_message.input_shape_size() > 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E11001"); + GELOGE(FAILED, "input_dim and input_shape can not both exist!"); + return FAILED; + } + int input_dim_size = proto_message.input_dim_size(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto_message.input_size() == 0), + ErrorManager::GetInstance().ATCReportErrMessage("E11002"); + return PARAM_INVALID, "Model has no input."); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((input_dim_size / proto_message.input_size() != ge::DIM_DEFAULT_SIZE || + input_dim_size % proto_message.input_size() != 0), + ErrorManager::GetInstance().ATCReportErrMessage( + "E11003", {"input_dim_size", "input_size"}, + {std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); + return FAILED, "Model input_dim size[%d] is not 4 times of input size[%d].", + input_dim_size, proto_message.input_size()) + + for (int i = 0; i < proto_message.input_size(); i++) { + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); + + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (int j = 0; j < ge::DIM_DEFAULT_SIZE; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_dim(j + i * ge::DIM_DEFAULT_SIZE))); + } + input_data_flag = true; + } + } else if (proto_message.input_shape_size() > 0) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + proto_message.input_shape_size() != proto_message.input_size(), + ErrorManager::GetInstance().ATCReportErrMessage( + "E11004", {"input_shape_size", "input_size"}, + {std::to_string(proto_message.input_shape_size()), std::to_string(proto_message.input_size())}); + return FAILED, "caffe net input_shape size(%d) is not equal input size(%d).", + proto_message.input_shape_size(), proto_message.input_size()); + + for (int i = 0; i < proto_message.input_size(); i++) { + int dim_size = proto_message.input_shape(i).dim_size(); + + domi::caffe::LayerParameter 
*layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(proto_message.input(i)); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); + + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (int j = 0; j < dim_size; j++) { + // Can guarantee that it will not cross the border + shape->add_dim(static_cast(proto_message.input_shape(i).dim(j))); + } + input_data_flag = true; + } + } else { + const ge::ParserContext &ctx = ge::GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + for (int i = 0; i < proto_message.input_size(); i++) { + string name = proto_message.input(i); + if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input + ErrorManager::GetInstance().ATCReportErrMessage("E11005"); + GELOGE(FAILED, "Model has no input shape."); + return FAILED; + } + std::vector dims = input_dims.at(name); + size_t dim_size = dims.size(); + + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(name); + layer->set_type(ge::parser::INPUT_TYPE); + layer->add_top(proto_message.input(i)); + + domi::caffe::InputParameter *input_param = layer->mutable_input_param(); + GE_CHECK_NOTNULL(input_param); + domi::caffe::BlobShape *shape = input_param->add_shape(); + GE_CHECK_NOTNULL(shape); + + for (size_t j = 0; j < dim_size; j++) { + shape->add_dim(dims.at(j)); + } + input_data_flag = true; + } + } + } + + return SUCCESS; +} + + +Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, + const string &custom_proto_name, vector &operators) { + google::protobuf::compiler::DiskSourceTree source_tree; + source_tree.MapPath(kProjectRoot, custom_proto_path); + google::protobuf::compiler::Importer importer(&source_tree, nullptr); + importer.Import(custom_proto_name.c_str()); + GELOGI("Import custom proto %s success.", custom_proto_path.c_str()); + + const google::protobuf::Descriptor *descriptor = importer.pool()->FindMessageTypeByName(kBeginningMessageType); + GE_CHECK_NOTNULL(descriptor); + google::protobuf::DynamicMessageFactory factory; + const google::protobuf::Message *proto = factory.GetPrototype(descriptor); + GE_CHECK_NOTNULL(proto); + google::protobuf::Message *message = proto->New(); + GE_CHECK_NOTNULL(message); + + if (ReadModelWithoutWarning(model_path, message) != SUCCESS) { + delete message; + GELOGE(FAILED, "ReadModelWithoutWarning %s failed.", model_path); + return FAILED; + } + + GELOGI("Start to parse model file: %s.", model_path); + const google::protobuf::Descriptor *layer_descriptor = importer.pool()->FindMessageTypeByName(kLayerMessageType); + if (layer_descriptor == nullptr) { + delete message; + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"Does not find domi.caffe.LayerParameter in google::protobuf::Descriptor"}); + GELOGE(FAILED, "Does not find domi.caffe.LayerParameter in google::protobuf::Descriptor"); + return FAILED; + } + + if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { + delete message; + GELOGE(FAILED, "ParseLayerParameter failed."); + return FAILED; + } + + delete message; + GELOGI("Parse model: %s by proto: %s success.", model_path, custom_proto_path.c_str()); + return SUCCESS; +} + +Status CaffeModelParser::CustomProtoParse(const char *model_path, const 
string &custom_proto, + const string &caffe_proto, vector &operators) { + string custom_proto_path = ge::parser::RealPath(custom_proto.c_str()); + if (custom_proto_path.empty()) { + GELOGW("Valid custom proto: %s does not exist, skip parsing custom proto", custom_proto.c_str()); + return SUCCESS; + } + + string custom_proto_name; + if (CheckPathValid(model_path, custom_proto, custom_proto_path, custom_proto_name) != SUCCESS) { + GELOGE(FAILED, "CheckPathValid of model and proto failed."); + return FAILED; + } + + GELOGI("Start to parse model: %s by custom proto: %s.", model_path, custom_proto.c_str()); + Status ret = ParseNetModelByCustomProto(model_path, custom_proto_path, custom_proto_name, operators); + if (ret != SUCCESS) { + GELOGE(FAILED, "parse net model by custom proto failed."); + } + + return ret; +} + +Status CaffeModelParser::GetIdentifier(const std::string &line, int32_t &identifier) { + size_t size = line.size(); + size_t pos = line.find("="); + if (pos == std::string::npos) { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {line.c_str(), "it must contain '='"}); + GELOGE(FAILED, "line: %s must contain char =.", line.c_str()); + return FAILED; + } + for (size_t i = pos + 1; i < size; i++) { + if (line[i] == ';') { + break; + } + if (identifier > kMaxIdentifier || identifier < 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11032", {"name", "reason"}, {to_string(identifier), "it can not exceed max value or less than 0"}); + GELOGE(FAILED, "Param identifier exceeded max value, identifier: %d.", identifier); + return FAILED; + } + if (line[i] >= '0' && line[i] <= '9') { + identifier = identifier * kBase + line[i] - '0'; + } + } + + if (identifier == 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11032", {"name", "reason"}, {to_string(identifier), "it must larger than 0"}); + GELOGE(FAILED, "Param identifier must larger than zero, identifier: %d.", identifier); + return FAILED; + } + return SUCCESS; +} + +Status CaffeModelParser::SaveIdentifierOpMapInfo(const string &line, std::map &identifier_op_map) { + std::vector op_param_info; + + // get op param info + std::istringstream string_stream(line); + std::string temp; + while (std::getline(string_stream, temp, ' ')) { + if (temp.empty()) { + continue; + } + op_param_info.emplace_back(std::move(temp)); + } + if (op_param_info.size() < kMinLineWorldSize) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E12025", {"size", "compare", "name"}, {to_string(op_param_info.size()), "larger", "min size"}); + GELOGE(FAILED, "Op param size(%zu) must larger than min size.", op_param_info.size()); + return FAILED; + } + if (op_param_info[0] != kOptional && op_param_info[0] != kRepeated && op_param_info[0] != kRequired) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11032", {"name", "reason"}, + {op_param_info[0].c_str(), "First value of op param is not in [optional, repeated, required]"}); + GELOGE(FAILED, "First value of op param is not in [optional, repeated, required], first value: %s", + op_param_info[0].c_str()); + return FAILED; + } + + // get identifier + int32_t identifier = 0; + if (GetIdentifier(line, identifier) != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {to_string(identifier), "Get identifier failed"}); + GELOGE(FAILED, "Get identifier failed, identifier: %d", identifier); + return FAILED; + } + identifier_op_map[identifier] = op_param_info[1]; + return SUCCESS; +} + +Status 
CaffeModelParser::ParseProtoFile(const string &proto_file, std::map &identifier_op_map) { + ifstream read_file; + read_file.open(proto_file, std::ios::in); + if (read_file.fail()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, + {proto_file.c_str(), "ifstream open failed"}); + GELOGE(FAILED, "Ifsream open caffe proto failed."); + return FAILED; + } + + std::string line; + bool save_flag = false; + while (std::getline(read_file, line)) { + // set save flag when find message LayerParameter + if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) { + save_flag = true; + continue; + } + // stop to save when message end find (}) + if (save_flag && line.find(kCloseBrace) != std::string::npos) { + break; + } + // save identifier and op info + if (save_flag) { + if (line.find(kRepeated) == std::string::npos && line.find(kOptional) == std::string::npos && + line.find(kRequired) == std::string::npos) { + continue; + } + if (SaveIdentifierOpMapInfo(line, identifier_op_map) != SUCCESS) { + read_file.close(); + GELOGE(FAILED, " Save Identifier op map Info failed."); + return FAILED; + } + } + } + read_file.close(); + return SUCCESS; +} + +Status CaffeModelParser::ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message) { + int32_t copy_fd = mmDup(STDERR_FILENO); + if (copy_fd < 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E19020", {"file"}, {"STDERR_FILENO"}); + GELOGE(FAILED, "Dup failed: %d.", copy_fd); + return FAILED; + } + + int32_t fd = mmOpen(kDevNull, O_RDWR); + if (fd < 0) { + (void)mmClose(copy_fd); + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {kDevNull, strerror(errno)}); + GELOGE(FAILED, "Open file %s failed.", kDevNull); + return FAILED; + } + + if (mmDup2(fd, STDERR_FILENO) < 0) { + (void)mmClose(fd); + (void)mmClose(copy_fd); + ErrorManager::GetInstance().ATCReportErrMessage("E19020", {"file"}, {"STDERR_FILENO"}); + GELOGE(FAILED, "Re-orient failed."); + return FAILED; + } + + if (ReadCaffeModelFromText(model_path, message) != SUCCESS) { + (void)mmClose(fd); + (void)mmClose(copy_fd); + GELOGE(FAILED, "ReadCaffeModelFromText %s failed.", model_path); + return FAILED; + } + + if (mmDup2(copy_fd, STDERR_FILENO) < 0) { + (void)mmClose(fd); + (void)mmClose(copy_fd); + ErrorManager::GetInstance().ATCReportErrMessage("E19020", {"file"}, {"STDERR_FILENO"}); + GELOGE(FAILED, "Re-orient failed."); + return FAILED; + } + (void)mmClose(fd); + (void)mmClose(copy_fd); + + return SUCCESS; +} + +Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google::protobuf::Message *message) { + GE_CHECK_NOTNULL(model_path); + GE_CHECK_NOTNULL(message); + GELOGI("Start to read model file: %s.", model_path); + std::ifstream fs(model_path, std::ifstream::in); + if (!fs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {model_path, "ifstream open failed"}); + GELOGE(FAILED, "Open file %s failed.", model_path); + return FAILED; + } + + google::protobuf::io::IstreamInputStream input(&fs); + google::protobuf::TextFormat::Parser model_parser; + model_parser.AllowUnknownField(true); + if (!model_parser.Parse(&input, message)) { + fs.close(); + ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {model_path}); + GELOGE(FAILED, "Parse model file %s failed.", model_path); + return FAILED; + } + fs.close(); + GELOGI("Read model file: %s success.", model_path); + + return SUCCESS; +} + +Status 
CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, + const google::protobuf::Message *message, + vector &operators) { + auto field_name = layer_descriptor->FindFieldByName(kFieldName); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); + auto field_type = layer_descriptor->FindFieldByName(kFieldType); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); + + const google::protobuf::Reflection *reflection = message->GetReflection(); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); + vector field_desc; + reflection->ListFields(*message, &field_desc); + for (auto &field : field_desc) { + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); + // Only care about layers + if (field->name() != kLayerName) { + continue; + } + if (!field->is_repeated()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "LayerParameter should be repeated"}); + GELOGE(FAILED, "LayerParameter should be repeated."); + return FAILED; + } + + int field_size = reflection->FieldSize(*message, field); + GELOGI("Total Layer num of model file is %d", field_size); + for (int i = 0; i < field_size; ++i) { + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); + GE_CHECK_NOTNULL(layer_reflection); + + string op_name = layer_reflection->GetString(layer_message, field_name); + string op_type = layer_reflection->GetString(layer_message, field_type); + if (domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type) == nullptr) { + continue; + } + if (CreateCustomOperator(op_name, op_type, &layer_message, i, operators) != SUCCESS) { + GELOGE(FAILED, "CreateCustomOperator failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + +Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, const google::protobuf::Message *message, + int index, vector &operators) { + if (op_name.empty() || op_type.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E12026", {"name", "type"}, {op_name.c_str(), op_type.c_str()}); + GELOGE(FAILED, "Name or type of layer is empty, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); + return FAILED; + } + + GELOGI("Start to create new operator, name: %s, type: %s, index: %d.", op_name.c_str(), op_type.c_str(), index); + ge::Operator ops(op_name, op_type); + if (ops.GetName() != op_name) { + ErrorManager::GetInstance().ATCReportErrMessage("E12027", {"name", "type"}, {op_name.c_str(), op_type.c_str()}); + GELOGE(FAILED, "Create operator failed, name: %s, type: %s, index: %d.", op_name.c_str(), op_type.c_str(), index); + return FAILED; + } + + if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) { + GELOGE(FAILED, "ParseOperatorAttrs of %s failed.", op_name.c_str()); + return FAILED; + } + + operators.emplace_back(ops); + GELOGI("Create new operator success, name: %s, type: %s, index: %d.", op_name.c_str(), op_type.c_str(), index); + + return SUCCESS; +} + +Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, 
ge::Operator &ops) { + if (depth > kMaxParseDepth) { + GELOGE(FAILED, "Message depth can not exceed %d.", kMaxParseDepth); + return FAILED; + } + + const google::protobuf::Reflection *reflection = message->GetReflection(); + GE_CHECK_NOTNULL(reflection); + vector field_desc; + reflection->ListFields(*message, &field_desc); + + for (auto &field : field_desc) { + GE_CHECK_NOTNULL(field); + if (field->is_repeated()) { + if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "Parse repeated field %s failed.", field->name().c_str()); + return FAILED; + } + } else { + if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "Parse field %s failed.", field->name().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + +Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, + int depth, ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + valuetype value = reflection->Get##method(*message, field); \ + GELOGD("Parse result(%s : %##logtype)", field->name().c_str(), value); \ + (void)ops.SetAttr(field->name(), value); \ + break; \ + } + CASE_FIELD_TYPE(INT32, Int32, int32_t, d); + CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); + CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); + CASE_FIELD_TYPE(FLOAT, Float, float, f); + CASE_FIELD_TYPE(BOOL, Bool, bool, d); +#undef CASE_FIELD_TYPE + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); + int value = reflection->GetEnum(*message, field)->number(); + GELOGD("Parse result(%s : %d)", field->name().c_str(), value); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetString(*message, field); + GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { + GELOGE(FAILED, "ParseOperatorAttrs of %s failed.", field->name().c_str()); + return FAILED; + } + break; + } + default: { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "Unsupported field type"}); + GELOGE(FAILED, "Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse field: %s success.", field->name().c_str()); + return SUCCESS; +} + +Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, + ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + int field_size = reflection->FieldSize(*message, field); + if (field_size <= 0) { + GELOGE(FAILED, "Size of repeated field %s must bigger than 0", field->name().c_str()); + return FAILED; + } + + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + vector attr_value; \ + for 
(int i = 0; i < field_size; i++) { \ + valuetype value = reflection->GetRepeated##method(*message, field, i); \ + attr_value.push_back(value); \ + } \ + (void)ops.SetAttr(field->name(), attr_value); \ + break; \ + } + CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); + CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); + CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); + CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); + CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); + CASE_FIELD_TYPE_REPEATED(STRING, String, string); +#undef CASE_FIELD_TYPE_REPEATED + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + for (int i = 0; i < field_size; ++i) { + const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); + if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { + GELOGE(FAILED, "ParseOperatorAttrs of field: %s failed.", field->name().c_str()); + return FAILED; + } + } + break; + } + default: { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "Unsupported field type"}); + GELOGE(FAILED, "Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse repeated field: %s success.", field->name().c_str()); + return SUCCESS; +} + +void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { + auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); + if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { + iter_node_name->second.emplace_back(top_index); + } else { + std::vector index_v; + index_v.emplace_back(top_index); + ge::GetParserContext().out_nodes_map.emplace(layer_name, index_v); + } + ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); +} + +Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { + if (ge::GetParserContext().user_out_nodes_top_vec.empty()) { + return SUCCESS; + } + + ge::GetParserContext().out_nodes_map.clear(); + ge::GetParserContext().user_out_nodes.clear(); + int32_t layer_count = proto_message.layer_size(); + const std::vector &user_out_nodes_top_vec = + ge::GetParserContext().user_out_nodes_top_vec; + + for (const auto &top_name : user_out_nodes_top_vec) { + bool find_node_falg = false; + string layer_name; + int32_t top_index = 0; + for (int32_t i = layer_count - 1; i >= 0; --i) { + domi::caffe::LayerParameter &layer = + const_cast(proto_message.layer(i)); + + for (int j = 0; j < layer.top_size(); ++j) { + string top_blob_name = layer.top(j); + if (top_blob_name != top_name) { + continue; + } + + find_node_falg = true; + layer_name.assign(layer.name()); + top_index = static_cast(j); + break; + } + if (find_node_falg) { + break; + } + } + if (!find_node_falg || layer_name.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11032", {"name", "reason"}, {top_name.c_str(), "Cannot find the top_name, which is invalid"}); + GELOGE(PARAM_INVALID, "Cannot find top_name[%s], which is invalid", top_name.c_str()); + return PARAM_INVALID; + } + GELOGD("Node[%s] find top_name[%s], top_index[%ld]", layer_name.c_str(), top_name.c_str(), top_index); + AddOutputInfoToContext(layer_name, top_index); + } + return SUCCESS; +} + +Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer, + std::map &inplace_blob_name_remapping) { + if (layer.type() == ge::parser::NETOUTPUT) { + return SUCCESS; + } + + if (layer.top_size() <= 0) { + 
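+    // Every layer except the synthetic NetOutput must expose at least one top blob; otherwise the
+    // bottom/top maps built below would have nothing to connect, so reject the model here.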
ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()}); + GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str()); + return FAILED; + } + + // Need to check if the input is the output of 'inplace' + for (int i = 0; i < layer.bottom_size(); ++i) { + std::string blob_name = layer.bottom(i); + auto iter = inplace_blob_name_remapping.find(blob_name); + if (iter != inplace_blob_name_remapping.end()) { + blob_name = iter->second; + } + bottom_blobs_map_[blob_name].emplace_back(std::make_pair(layer.name(), i)); + } + + // Handling 'inplace' scenarios + for (int j = 0; j < layer.top_size(); ++j) { + std::string top_blob_name = layer.top(j); + if (IsInplaceTopBlob(layer, top_blob_name)) { + std::string remapped_blob_name = RemapTopNameByLayer(layer, top_blob_name, j); + inplace_blob_name_remapping[top_blob_name] = remapped_blob_name; + top_blob_name = remapped_blob_name; + } + top_blobs_map_[top_blob_name].emplace_back(std::make_pair(layer.name(), j)); + } + + return SUCCESS; +} + +bool CaffeModelParser::IsOpAttrEmpty(const ge::Operator &op, const std::string &type) { + const std::map attrs = op.GetAllAttrNamesAndTypes(); + + if (type == kCustom) { + for (const auto &attr : attrs) { + if (kCustomProtoLayerCommonField.count(attr.first) == 0) { + GELOGI("Custom op[%s] attr name[%s] exists, not empty.", op.GetName().c_str(), attr.first.c_str()); + return false; + } + } + } else if (type == kBuiltin) { + for (const auto &attr : attrs) { + if (kCaffeProtoLayerCommonField.count(attr.first) == 0) { + GELOGI("Built-in op[%s] attr name[%s] exists, not empty.", op.GetName().c_str(), attr.first.c_str()); + return false; + } + } + } + + return true; +} + +Status CaffeModelParser::GetCustomOp(const domi::caffe::LayerParameter &layer, vector &operators) { + string op_type = layer.type(); + string op_name = layer.name(); + + bool is_search_built_in_layer = false; + for (ge::Operator &custom_op : custom_operator_) { + if (custom_op.GetName() == layer.name() && custom_op.GetOpType() == op_type) { + if (IsOpAttrEmpty(custom_op, kCustom)) { + GELOGW("Custom op attr is empty, should try to get op params from built-in layer."); + is_search_built_in_layer = true; + } else { + operators.emplace_back(custom_op); + GELOGI("Find custom op success."); + return SUCCESS; + } + break; + } + } + + if (custom_operator_.empty()) { + GELOGW("Custom operator is empty, should try to get op params from built-in layer."); + is_search_built_in_layer = true; + } + + if (is_search_built_in_layer) { + const google::protobuf::Message *layer_message = reinterpret_cast(&layer); + Status status = CreateCustomOperator(op_name, op_type, layer_message, 0, operators); + if (status != SUCCESS || operators.empty()) { + GELOGE(status, "CreateCustomOperator failed, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); + return FAILED; + } + if (IsOpAttrEmpty(operators[0], kBuiltin)) { + GELOGW("Custom and built-in op attr param is empty, name: %s, type: %s.", op_name.c_str(), op_type.c_str()); + } + GELOGI("Search built-in layer success."); + } + return SUCCESS; +} + +Status CaffeModelParser::ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, + std::shared_ptr &op_parser) { + GE_CHECK_NOTNULL(op); + GE_CHECK_NOTNULL(op_parser); + string op_type = layer.type(); + + Status status = FAILED; + ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); + if (parse_param_func == nullptr) { + // Parsing weight 
information through opparser + status = op_parser->ParseParams(&layer, op); + } else { + // The custom op defined by customer deals with parse params separately + std::shared_ptr caffe_custom_op_parser = + std::dynamic_pointer_cast(op_parser); + vector custom_operator; + status = GetCustomOp(layer, custom_operator); + if (status != SUCCESS || custom_operator.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {layer.name(), op_type}); + GELOGE(status, "Parse Params for custom op [%s] failed, optype [%s]", layer.name().c_str(), op_type.c_str()); + return status; + } + status = caffe_custom_op_parser->ParseParams(custom_operator[0], op); + } + + if (status != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {layer.name(), op_type}); + GELOGE(status, "Parse Params for op [%s] fail, optype [%s]", layer.name().c_str(), op_type.c_str()); + return status; + } + + return SUCCESS; +} + +Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + // Release in node destructor + string op_type; + + // Python type parsing is supported in the model file. Python layer is a user-defined layer, + // which can represent a variety of operator types. Currently, proposal operator is supported + if (layer.type() == kPython) { + // Judge whether there is Python_Param. If not, it is illegal + if (!layer.has_python_param()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11006", {"opname"}, {layer.name()}); + GELOGE(FAILED, "Op[%s] optype[Python] has no python_param.", layer.name().c_str()); + return FAILED; + } + + const domi::caffe::PythonParameter &python_param = layer.python_param(); + // Judge whether it is a Proposal operator + if (python_param.layer() == kProposalLayer) { + ErrorManager::GetInstance().ATCReportErrMessage("E11031", {"opname"}, {layer.name()}); + GELOGE(PARAM_INVALID, "Python Layer %s need to be rewritten according to product directions", + layer.name().c_str()); + return FAILED; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E11007", {"opname"}, {python_param.layer()}); + GELOGE(FAILED, + "If optype is [Python], opname must be [ProposalLayer], " + "but actual opname is [%s].", + python_param.layer().c_str()); + return FAILED; + } + } else { + op_type = layer.type(); + // User defined duplicate name operator processing + auto m_iter = ge::GetParserContext().op_conf_map.find(op_type); + // User specified configuration item found + if (m_iter != ge::GetParserContext().op_conf_map.end()) { + op_type = m_iter->second; + } + // General layer layer, search optype + auto iter = caffe_op_map.find(op_type); + if (iter == caffe_op_map.end()) { + if (op_type == kDetectionOutput) { + ErrorManager::GetInstance().ATCReportErrMessage("E11008"); + GELOGE(FAILED, + "Op type 'DetectionOutput' is confused. 
" + "Suggest you modify the model file and use a explicit type, such as 'FSRDetectionOutput' or " + "'SSDDetectionOutput'."); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E11009", {"opname", "optype"}, {layer.name(), op_type}); + GELOGE(FAILED, + "Unsupport op[%s] optype[%s], " + "you should customize the op at first.", + layer.name().c_str(), op_type.c_str()); + } + + return FAILED; + } + op_type = iter->second; + } + GELOGD("Caffe layer name:%s, layer type %s", layer.name().c_str(), op_type.c_str()); + // create OpParser + std::shared_ptr factory = OpParserFactory::Instance(domi::CAFFE); + GE_CHECK_NOTNULL(factory); + std::shared_ptr op_parser = factory->CreateOpParser(op_type); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_parser == nullptr, return FAILED, "op_parser is null, op_type: %s.", + op_type.c_str()); + + ge::OpDescPtr op; + // Process change of tensordesc initialization of opdesc, + // The previous process tensordesc was constructed according to the graph structure in the builder stage + // The current process requires tensordesc to determine before the opdesc of the operator is added to the graph + GE_RETURN_IF_ERROR(AddTensorDescToOpDescByIr(op, layer, op_type)); + GELOGI("After AddTensorDescToOpDescByIr op[%s] type[%s] have input size: %zu, output size: %zu", + op->GetName().c_str(), op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + // op parser execute + GE_RETURN_IF_ERROR(ParseOpParam(layer, op, op_parser)); + GELOGI("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + + // Caffe has also been plug-in at present. Here it is directly set to NCHW + // Set input and output format + GELOGI("Enter caffe parser. op name:%s, type:%s", op->GetName().c_str(), op->GetType().c_str()); + // inputDescsPtr and outputDescsPtr are guaranteed not to be nullptr + auto inputDescsPtr = op->GetAllInputsDescPtr(); + auto outputDescsPtr = op->GetAllOutputsDescPtr(); + ge::Format format = ge::FORMAT_NCHW; + + for (auto &inputDescPtr : inputDescsPtr) { + GE_CHECK_NOTNULL(inputDescPtr); + inputDescPtr->SetFormat(format); + inputDescPtr->SetOriginFormat(format); + } + for (auto &outputDescPtr : outputDescsPtr) { + GE_CHECK_NOTNULL(outputDescPtr); + outputDescPtr->SetFormat(format); + outputDescPtr->SetOriginFormat(format); + } + + ge::NodePtr node = graph->AddNode(op); + if (node == nullptr) { + GELOGE(FAILED, "call Graph add node failed, op name:%s, type:%s", op->GetName().c_str(), op->GetType().c_str()); + return FAILED; + } + + // Caffe's reshape is different from IR definition, which has only one input. + // In caffe process, after constructing reshape according to IR, the second input of reshape is empty. + // So a constant node needs to be added to reshape as the second input. + // AddConstInput is a function defined in caffe_op_parser, override in caffe_reshape_parser. 
+ std::shared_ptr caffe_op_parser = std::static_pointer_cast(op_parser); + GE_CHECK_NOTNULL(caffe_op_parser); + Status status; + status = caffe_op_parser->AddConstInput(node); + if (status != SUCCESS) { + GELOGE(FAILED, "add const input to node %s fail.", node->GetOpDesc()->GetName().c_str()); + return status; + } + + node_map[layer.name()] = node; + return SUCCESS; +} + +Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) { + GE_CHECK_NOTNULL(op_desc); + // Data node input and output tensordesc added in parserparam + if (op_desc->GetType() == ge::parser::DATA) { + return SUCCESS; + } + + for (int i = 0; i < layer.bottom_size(); i++) { + ge::GeTensorDesc input_tensor; + GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); + } + GELOGD("AddTensorInputDescToOpDesc, op name: %s, type: %s, input num: %d", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), layer.bottom_size()); + // Output number + int32_t output_tensor_num = layer.top_size(); + GELOGD("AddTensorOutputDescToOpDesc, op name: %s, type: %s, output num: %d", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), output_tensor_num); + for (int32_t j = 0; j < output_tensor_num; j++) { + ge::GeTensorDesc output_tensor; + GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor)); + } + + // yolo v2 YoloDetectionOutput + if (op_desc->GetType() == ge::parser::YOLODETECTIONOUTPUT) { + ge::GeTensorDesc input_tensor; + GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); + GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); + GELOGD( + "Current op type is YOLODETECTIONOUTPUT, add 2 additional inputs" + "while it's original input num is: %d", + layer.bottom_size()); + } + + // Netoutput node processing + if (op_desc->GetType() == ge::parser::NETOUTPUT) { + size_t input_output_tensor_num = 0; + if (!ge::GetParserContext().user_out_nodes.empty()) { + // User specified output + input_output_tensor_num = ge::GetParserContext().user_out_nodes.size(); + } else { + for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) { + auto b_iter = bottom_blobs_map_.find(t_iter->first); + // Find the output node of the network + if (b_iter == bottom_blobs_map_.end()) { + input_output_tensor_num += top_blobs_map_[t_iter->first].size(); + } + } + } + // add tensordesc + GELOGD( + "Current op type is NETOUTPUT, add additional input&output num: %zu." 
+ "while it's original input num is: %d, output num is: %d", + input_output_tensor_num, layer.bottom_size(), output_tensor_num); + for (size_t i = 0; i < input_output_tensor_num; i++) { + ge::GeTensorDesc input_tensor; + GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); + ge::GeTensorDesc output_tensor; + GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor)); + } + } + return SUCCESS; +} + +Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, + const string &op_type) { + if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { + op_desc = ge::MakeShared(layer.name(), op_type); + GE_CHECK_NOTNULL(op_desc); + Status ret = AddTensorDescToOpDesc(op_desc, layer); + if (ret != SUCCESS) { + GELOGE(FAILED, "op[%s] type[%s] AddTensorDescToOpDesc failed.", layer.name().c_str(), op_type.c_str()); + } + return ret; + } + + // Get opDesc by ir + string layer_name = layer.name(); + ge::Operator op_factory = ge::OperatorFactory::CreateOperator(layer_name, op_type); + if (op_factory.GetName() != layer.name()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11011", {"opname", "optype"}, {layer_name, op_type}); + GELOGE(FAILED, "IR for op[%s] optype[%s] is not registered.", layer_name.c_str(), op_type.c_str()); + return FAILED; + } else { + op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_factory); + GE_CHECK_NOTNULL(op_desc); + auto valid_size = layer.bottom_size(); + GELOGI("After GetOpDescFromOperator op[%s] type[%s] have all input size: %zu, caffe_input_size:%d output size: %zu", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), + op_desc->GetAllInputsSize(), valid_size, op_desc->GetOutputsSize()); + for (int i = 0; i < valid_size; i++) { + ge::GeTensorDesc input_tensor; + std::string input_name; + ge::graphStatus ret = ge::GRAPH_SUCCESS; + // Only two case is supported fow now when there are optional inputs + // x means optional, o means requierd input + // a. ooxxx, layer.bottom_size=number of o and x + // b. 
oxoxoxox, layer.bottom_size=number of o + if (static_cast(i) >= op_desc->GetInputsSize()) { + ret = op_desc->UpdateInputDesc(static_cast(i), input_tensor); + } else { + input_name = op_desc->GetValidInputNameByIndex(static_cast(i)); + ret = op_desc->UpdateInputDesc(input_name, input_tensor); + } + + if (ret != ge::GRAPH_SUCCESS) { + GELOGW("op [%s], type[%s], update input(%d) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, input_name.c_str()); + } else { + GELOGI("op [%s], type[%s], update input(%d) with name %s success", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, input_name.c_str()); + } + } + + for (int i = 0; i < layer.top_size(); i++) { + ge::GeTensorDesc output_tensor; + ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_tensor); + if (ret != ge::GRAPH_SUCCESS) { + GELOGW("op [%s], type[%s], update output(%d) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetOutputNameByIndex(i).c_str()); + } else { + GELOGI("op [%s], type[%s], update output(%d) with name %s success", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetOutputNameByIndex(i).c_str()); + } + } + } + return SUCCESS; +} + +Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + // Traversal input + for (auto b_iter = bottom_blobs_map_.begin(); b_iter != bottom_blobs_map_.end(); b_iter++) { + // Find the top blob corresponding to the bottom blob + auto t_iter = top_blobs_map_.find(b_iter->first); + // Unable to find the output corresponding to the input, error reported + if (t_iter == top_blobs_map_.end()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11012", {"bottom_blob", "layer", "bottom_index"}, + {b_iter->first, b_iter->second[0].first, std::to_string(b_iter->second[0].second)}); + GELOGE(FAILED, "Unknown bottom blob '%s', in layer '%s', bottom index:%d.", b_iter->first.c_str(), + b_iter->second[0].first.c_str(), b_iter->second[0].second); + return PARAM_INVALID; + } + + vector> &top_blob_layers = t_iter->second; + vector> &bottom_blob_layers = b_iter->second; + // 1.Traversal output, all input layers of the current blob + for (auto &top_blob_layer_pair : top_blob_layers) { + // 2.Traversal input, all output layers of the current blob + for (auto &bottom_blob_layer_pair : bottom_blob_layers) { + // Find the layer for this output + auto top_node_iter = node_map.find(top_blob_layer_pair.first); + // Find the layer for this input + auto bottom_node_iter = node_map.find(bottom_blob_layer_pair.first); + if (top_node_iter != node_map.end() && bottom_node_iter != node_map.end()) { + // Output node top_node_iter->second, + // Output index top_blob_layer_pair.second + // input node bottom_node_iter->second + // input index bottom_blob_layer_pair.second + GELOGD("Start add edge: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), + top_blob_layer_pair.second, bottom_node_iter->second->GetName().c_str(), + bottom_blob_layer_pair.second); + ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(top_blob_layer_pair.second); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = bottom_node_iter->second->GetInDataAnchor(bottom_blob_layer_pair.second); + GE_CHECK_NOTNULL(in_archor_ptr); + GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage( + "E11013", {"opname1", "opname2"}, + 
{top_node_iter->second->GetName(), bottom_node_iter->second->GetName()}); + GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", + top_node_iter->second->GetName().c_str(), bottom_node_iter->second->GetName().c_str()); + return INTERNAL_ERROR;); + } + GE_IF_BOOL_EXEC(top_node_iter == node_map.end(), ErrorManager::GetInstance().ATCReportErrMessage( + "E11014", {"opname"}, {top_blob_layer_pair.first}); + GELOGE(INTERNAL_ERROR, "Failed to find top layer name: %s.", top_blob_layer_pair.first.c_str()); + return ge::FAILED;) + GE_IF_BOOL_EXEC( + top_node_iter == node_map.end(), + ErrorManager::GetInstance().ATCReportErrMessage("E11015", {"opname"}, {bottom_blob_layer_pair.first}); + GELOGE(INTERNAL_ERROR, "Failed to find bottom layer name: %s.", bottom_blob_layer_pair.first.c_str()); + return ge::FAILED;) + } + } + } + + return SUCCESS; +} + +bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { + bool ret = false; + auto iter = ge::GetParserContext().out_nodes_map.find(op_name); + if (iter != ge::GetParserContext().out_nodes_map.end()) { + std::vector tmp_index_v; + for (int32_t id : iter->second) { + if (index == id) { + ret = true; + } else { + tmp_index_v.emplace_back(id); + } + } + // To prevent specifying network output again in the build phase, need to delete the output node in the map list. + if (ret) { + ge::GetParserContext().out_nodes_map.erase(op_name); + ge::GetParserContext().out_nodes_map.emplace(op_name, tmp_index_v); + } + } + return ret; +} + +Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT); + if (net_output_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Can not find netoutput node."); + return INTERNAL_ERROR; + } + uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize(); + int32_t index = 0; + const std::vector> &user_out_nodes = ge::GetParserContext().user_out_nodes; + for (const auto &out_pair : user_out_nodes) { + auto node_iter = node_map.find(out_pair.first); + GELOGI("Add to output, node name: %s", out_pair.first.c_str()); + if (node_iter != node_map.end()) { + if ((static_cast(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) || + (static_cast(index) >= net_output_num)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"}, + {out_pair.first.c_str(), std::to_string(out_pair.second), + std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index), + std::to_string(net_output_num)}); + GELOGE(INTERNAL_ERROR, + "Add op %s to NetOutput faild, current node output index:%d should < %u. 
NetOutput" + "input_index:%d should < %u.", + out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index, + net_output_num); + return INTERNAL_ERROR; + } + GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(), + out_pair.second, net_output_node->GetName().c_str(), index); + ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_archor_ptr); + if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"}, + {node_iter->second->GetName(), net_output_node->GetName()}); + GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(), + net_output_node->GetName().c_str()); + return INTERNAL_ERROR; + } + ++index; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E11017", {"opname"}, {out_pair.first}); + GELOGE(PARAM_INVALID, "Can not find out_node:%s, you should check --out_nodes.", out_pair.first.c_str()); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + ge::NodePtr node = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT); + + GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front."); + + int32_t index = 0; + for (int32_t i = 0; i < proto_message.layer_size(); i++) { + const domi::caffe::LayerParameter &layer = proto_message.layer(i); + + if (!CheckValidLayer(layer)) { + continue; + } + + for (int i = 0; i < layer.top_size(); i++) { + string top = layer.top(i); + // Handling 'inplace' scenarios + if (IsInplaceTopBlob(layer, top)) { + top = RemapTopNameByLayer(layer, top, i); + } + + auto t_iter = top_blobs_map_.find(top); + + GE_RETURN_WITH_LOG_IF_FALSE(t_iter != top_blobs_map_.end(), "Failed to find top: %s, layer name:%s", top.c_str(), + layer.name().c_str()); + + // Find the bottom blob corresponding to the top blob + auto b_iter = bottom_blobs_map_.find(t_iter->first); + if (b_iter != bottom_blobs_map_.end() && !IsOutputTop(layer.name(), i)) { + continue; + } + + // If not found, add to the output side of the output + // Find the layer for this output + auto top_node_iter = node_map.find(layer.name()); + GELOGI("output in top_blob: %s", layer.name().c_str()); + if (top_node_iter != node_map.end()) { + // add edge + // Output node, output index, input node, input index + GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i, + node->GetName().c_str(), index); + ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_archor_ptr); + GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage( + "E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()}); + GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].", + top_node_iter->second->GetName().c_str(), node->GetName().c_str()); + return INTERNAL_ERROR;); + index++; + } + } + } + + return SUCCESS; +} + +bool 
CaffeModelParser::CheckValidLayer(const domi::caffe::LayerParameter &layer) { + if (layer.include_size() != 0) { + bool filter_flag = false; + for (int32_t j = 0; j < layer.include_size(); j++) { + // Determine whether there is a data node for train in a Caffe model + if (layer.include(j).phase() == domi::caffe::TRAIN) { + filter_flag = true; + break; + } + } + + if (filter_flag) { + // If the phase of the data node's include is train, the data node ignores + return false; + } + } + + return true; +} + +bool CaffeModelParser::IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name) { + for (auto &bottom_name : layer.bottom()) { + if (top_name == bottom_name) { + return true; + } + } + return false; +} + +std::string CaffeModelParser::RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, + int index) { + return (top_name + "_" + layer.name() + "_" + std::to_string(index)); +} + +Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { + // Add layer in the model to PreChecker and check the general parameters + PreChecker::Instance().SetModelName(net.name()); + for (int i = 0; i < net.layer_size(); i++) { + const LayerParameter &layer = net.layer(i); + + // Skip training related layers and python layers + if (!CheckValidLayer(layer) || layer.type() == kPython) { + continue; + } + + // validate opname + string mode = "^[A-Za-z0-9./_-]+$"; + if (!ge::parser::ValidateStr(layer.name(), mode)) { + ErrorManager::GetInstance().ATCReportErrMessage("E11018", {"opname"}, {layer.name()}); + GELOGE(ge::FAILED, + "Parse caffe pbtxt validate op[%s] failed, opname can only contain " + "'a-z' 'A-Z' '0-9' '-' '.' '_' '/'", + layer.name().c_str()); + return ge::FAILED; + } + + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&layer, layer.name(), layer.type()), + "Add layer to PreChecker failed, layer name: %s.", layer.name().c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&layer) != SUCCESS, return FAILED, + "Check op[%s] failed, name repeat in caffe prototxt.", layer.name().c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&layer) != SUCCESS, return FAILED, + "Check op[%s]'s optype failed, type is not supported.", layer.name().c_str()); + } + + return SUCCESS; +} + +Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { + bool has_error = false; + + GE_CHK_BOOL_RET_STATUS(data != nullptr, FAILED, "model data is nullptr."); + GE_CHK_BOOL_RET_STATUS(graph != nullptr, FAILED, "graph is nullptr."); + + domi::caffe::NetParameter proto_message; + + // Get Caffe network model information + if (!ge::parser::ReadProtoFromMem(data, static_cast(size), &proto_message)) { + GELOGE(FAILED, "read proto from text ret fail"); + return FAILED; + } + + GE_CHK_BOOL_RET_STATUS( + !(proto_message.layer_size() == 0 && proto_message.layers_size() > 0), FAILED, + "The model file is consisted of layers-structure which is deprecated in caffe and unsupport in OMG. 
" + "It is recommended to convert layers-structure to layer-structure by caffe tool."); + GE_CHK_BOOL_RET_STATUS((proto_message.layer_size() != 0), FAILED, + "net layer num is zero, prototxt file may be invalid."); + + // Set network name + GE_IF_BOOL_EXEC((proto_message.has_name()), graph->SetName(proto_message.name())); + + // Add layer in the model to PreChecker, and perform general checks + GE_RETURN_IF_ERROR(PreCheck(proto_message)); + has_error = PreChecker::Instance().HasError(); + + if (ReorderInput(proto_message) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Reorder input failed."); + return INTERNAL_ERROR; + } + + bool input_data_flag = false; + + // Process input of type input + CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; + GELOGE(FAILED, "ParseInput ret fail.")); + + // build output layer + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT); + layer->set_type(ge::parser::NETOUTPUT); + + int32_t layer_count = proto_message.layer_size(); + std::map inplace_blob_name_remapping; + // Map of operator name and occurrence times + std::map layer_name_map; + + // + std::map> layer_params_map; + // same param name set + // std::map, std::vector> params_share_map; + for (int32_t i = 0; i < layer_count; i++) { + domi::caffe::LayerParameter &layer = const_cast(proto_message.layer(i)); + + GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", + layer.name().c_str(), layer.type().c_str()); + + CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; + GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); + + // All layer names cannot be duplicate + // 20181208 Modified to support the existence of duplicate operators in Caffe model + GE_IF_BOOL_EXEC(layer_name_map.find(layer.name()) != layer_name_map.end(), + // duplicate operator modification + string new_name = layer.name() + "_same_" + std::to_string(layer_name_map[layer.name()]); + // Times accumulation of duplicate operators + layer_name_map[layer.name()]++; + // Set the name in proto and layer + domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(i); + duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + + // Insert the new operator name, the number of times of duplicate name is recorded as 1 + layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); + + // Do not exit immediately when there is an error, wait until all errors are collected before exiting + Status ret = AddNode(layer, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Caffe parser add node fail."); + has_error = true; + continue; + } + + // parse ParamSpec + std::vector v_param_names; + for (int i = 0; i < layer.param_size(); i++) { + const domi::caffe::ParamSpec ¶m = layer.param(i); + GE_IF_BOOL_EXEC((param.has_name()), v_param_names.emplace_back(param.name())); + } + + // Save the layer with param name parameter to map + GE_IF_BOOL_EXEC((v_param_names.size() > 0), layer_params_map.emplace(layer.name(), v_param_names)); + + GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping), + "Caffe parser add blobs to map ret fail."); + } + // Find a layer with the same param name and save it to graph + GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map), + "Caffe parser find share param 
layers map ret fail."); + + // Exit if an error occurs + GE_IF_BOOL_EXEC(has_error, return FAILED); + + GE_CHK_BOOL_RET_STATUS(top_blobs_map_.size() > 0, FAILED, "current net has no output!"); + + GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); + + if (!(ge::GetParserContext().user_out_nodes.empty())) { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); + } else { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); + } + GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); + + auto nodes = graph->GetDirectNode(); + GELOGI("graph node size = %zu.", nodes.size()); + for (auto &node : nodes) { + GELOGI("node name = %s.", node->GetName().c_str()); + for (auto &out_node : node->GetOutDataNodes()) { + GELOGI("out node name = %s.", out_node->GetName().c_str()); + } + } + + return SUCCESS; +} + +Status CaffeModelParser::Parse(const char *model_path, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_path); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + Status ret = Parse(model_path, compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parser model for graph %s failed.", graph.GetName().c_str()); + return ret; + } + + GELOGI("Parser model for graph %s success.", graph.GetName().c_str()); + return SUCCESS; +} + +void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer) { + string name = layer.name(); + vector tops; + for (auto top : layer.top()) { + tops.push_back(top); + } + auto it = layer_tops_map_.find(name); + if (it == layer_tops_map_.end()) { + layer_tops_map_[name] = tops; + } + return; +} + +Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { + bool has_error = false; + GE_CHECK_NOTNULL(model_path); + GE_CHECK_NOTNULL(graph); + GELOGI("Caffe Parse model file %s", model_path); + + PreChecker::Instance().Clear(); + + domi::caffe::NetParameter proto_message; + + // Get Caffe network model information + if (ReadModelWithoutWarning(model_path, &proto_message) != SUCCESS) { + GELOGE(FAILED, "read caffe model from text ret fail, model path: %s.", model_path); + return FAILED; + } + + // parse network model by custom proto and get custom operators + string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; + string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; + Status result = CustomProtoParse(model_path, custom_proto_path, caffe_proto_path, custom_operator_); + if (result != SUCCESS) { + GELOGE(FAILED, "Parse model by custom proto failed, model: %s.", model_path); + return FAILED; + } + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + proto_message.layer_size() == 0 && proto_message.layers_size() > 0, + ErrorManager::GetInstance().ATCReportErrMessage("E11021", {"realpath"}, {model_path}); + return FAILED, + "The model file[%s] is consisted of layers-structure which is deprecated in Caffe " + "and unsupported in ATC. 
The \"layers\" should be changed to \"layer\".", + model_path); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto_message.layer_size() == 0), + ErrorManager::GetInstance().ATCReportErrMessage("E11022"); + return FAILED, "net layer num is zero, prototxt file may be invalid."); + + // Set network name + GE_IF_BOOL_EXEC((proto_message.has_name() && !proto_message.name().empty()), graph->SetName(proto_message.name())); + + // Add layer in the model to PreChecker, and perform general checks + GE_RETURN_IF_ERROR(PreCheck(proto_message)); + + if (PreChecker::Instance().HasError()) { + GELOGE(INTERNAL_ERROR, "Precheck failed. Please read check report."); + return FAILED; + } + + if (ReorderInput(proto_message) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Reorder input failed."); + return INTERNAL_ERROR; + } + + bool input_data_flag = false; + + // Process input of type input + CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; + GELOGE(FAILED, "ParseInput ret fail.")); + + // build output layer + domi::caffe::LayerParameter *layer = proto_message.add_layer(); + GE_CHECK_NOTNULL(layer); + layer->set_name(graph->GetName() + "_" + ge::parser::NODE_NAME_NET_OUTPUT); + layer->set_type(ge::parser::NETOUTPUT); + + int32_t layer_count = proto_message.layer_size(); + + if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { + GELOGW("The out_put info has top_name items."); + GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), + "Caffe parser parse output node-top info failed."); + ge::GetParserContext().user_out_nodes_top_vec.clear(); + } + + std::map inplace_blob_name_remapping; + // Map of operator name and occurrence times + std::map layer_name_map; + + // + std::map> layer_params_map; + // same param name set + for (int32_t i = 0; i < layer_count; i++) { + domi::caffe::LayerParameter &layer = const_cast(proto_message.layer(i)); + SaveOrigionLayerTops(layer); + GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", + layer.name().c_str(), layer.type().c_str()); + + CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; + GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); + + // All layer names cannot be duplicate + // Modified to support the existence of duplicate operators in Caffe model + GE_IF_BOOL_EXEC(layer_name_map.find(layer.name()) != layer_name_map.end(), + // duplicate operator modification + string new_name = layer.name() + "_same_" + std::to_string(layer_name_map[layer.name()]); + // Times accumulation of duplicate operators + layer_name_map[layer.name()]++; + // Set the name in proto and layer + domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(i); + duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) + + // Insert the new operator name, the number of times of duplicate name is recorded as 1 + layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); + + // Do not exit immediately when there is an error, wait until all errors are collected before exiting + Status ret = AddNode(layer, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Caffe parser add node fail."); + has_error = true; + continue; + } + + // parse ParamSpec + std::vector v_param_names; + for (int i = 0; i < layer.param_size(); i++) { + const domi::caffe::ParamSpec ¶m = layer.param(i); + GE_IF_BOOL_EXEC((param.has_name()), v_param_names.emplace_back(param.name())); + } + + // 
Save the layer with param name parameter to map + GE_IF_BOOL_EXEC((v_param_names.size() > 0), layer_params_map.emplace(layer.name(), v_param_names)); + + GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping), + "Caffe parser add blobs to map ret fail."); + } + // Find a layer with the same param name and save it to graph + GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map), + "Caffe parser find share param layers map ret fail."); + + // Exit if an error occurs + GE_IF_BOOL_EXEC(has_error, return FAILED); + + GE_CHK_BOOL_RET_STATUS(top_blobs_map_.size() > 0, FAILED, "current net has no output!"); + + GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); + + if (!(ge::GetParserContext().user_out_nodes.empty())) { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); + } else { + GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); + } + GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); + GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed."); + + auto nodes = graph->GetDirectNode(); + GELOGI("graph node size = %zu.", nodes.size()); + for (auto &node : nodes) { + GELOGI("node name = %s.", node->GetName().c_str()); + for (auto &out_node : node->GetOutDataNodes()) { + GELOGI("out node name = %s.", out_node->GetName().c_str()); + } + } + + return SUCCESS; +} + +Status CaffeModelParser::FindShareParamLayers(const std::map> &layer_params_map) { + for (auto p_iter = layer_params_map.begin(); p_iter != layer_params_map.end(); ++p_iter) { + for (auto p2_iter = p_iter; p2_iter != layer_params_map.end(); ++p2_iter) { + if (p_iter->first != p2_iter->first && p_iter->second == p2_iter->second) { + if (params_share_map.find(p_iter->second) == params_share_map.end()) { // Unsaved layer + vector tmp_v; + tmp_v.push_back(p_iter->first); + tmp_v.push_back(p2_iter->first); + params_share_map.emplace(p_iter->second, tmp_v); + } else { + vector::iterator iter = + find(params_share_map[p_iter->second].begin(), params_share_map[p_iter->second].end(), p2_iter->first); + if (iter == params_share_map[p_iter->second].end()) { + params_share_map[p_iter->second].push_back(p2_iter->first); + } + } + } + } + } + return SUCCESS; +} + +Status CaffeModelParser::ToJson(const char *model_file, const char *json_file) { + GE_CHK_BOOL_RET_STATUS(model_file != nullptr, FAILED, "model_file is nullptr."); + GE_CHK_BOOL_RET_STATUS(json_file != nullptr, FAILED, "json_file is nullptr."); + domi::caffe::NetParameter net; + nlohmann::json j; + + GE_RETURN_WITH_LOG_IF_FALSE(ReadModelWithoutWarning(model_file, &net) == SUCCESS, + "ReadModelWithoutWarning failed, Please Check file:%s.", model_file); + Pb2Json::Message2Json(net, set(), j, true); + return ModelSaver::SaveJsonToFile(json_file, j); +} + +Status CaffeModelParser::ReorderInput(domi::caffe::NetParameter &net) { + int layer_size = net.layer_size(); + for (int i = 0; i < layer_size; ++i) { + domi::caffe::LayerParameter *layer = net.mutable_layer(i); + const std::vector &move_input_vec = + domi::OpRegistry::Instance()->GetRemoveInputConfigure(layer->type()); + if (move_input_vec.empty()) { + continue; + } + for (const auto &it : move_input_vec) { + if (it.moveType == domi::OMG_INPUT_REORDER) { + auto inputs = layer->bottom(); + if (static_cast(inputs.size()) != it.input_order.size()) { + GELOGE(INTERNAL_ERROR, 
"Size of input is mismatched, new order size is %zu, input size is %d.", + it.input_order.size(), inputs.size()); + return INTERNAL_ERROR; + } + for (size_t j = 0; j < it.input_order.size(); ++j) { + int new_index = it.input_order[j]; + if (new_index < 0 || new_index >= inputs.size()) { + GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", layer->name().c_str(), new_index); + return INTERNAL_ERROR; + } + layer->set_bottom(j, inputs[new_index]); + } + GELOGI("The input sequence of the node has been rearranged, node name:%s.", layer->name().c_str()); + } + } + } + return SUCCESS; +} + +Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { + if (data == nullptr) { + GELOGE(PARAM_INVALID, "Caffe weights data is nullptr"); + return PARAM_INVALID; + } + + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Caffe weights graph is nullptr"); + return PARAM_INVALID; + } + + // Resolve proto file to netparameter + NetParameter proto; + bool success = ge::parser::ReadProtoFromArray(reinterpret_cast(data), static_cast(size), &proto); + if (!success) { + GELOGE(domi::PARSE_WEIGHTS_FAILED, "ReadProto from Memory fail"); + return domi::PARSE_WEIGHTS_FAILED; + } + + // Convert netparameter to opdef and save to graph + Status status = ConvertNetParameter(proto, graph); + GE_IF_BOOL_EXEC(status != SUCCESS, GELOGE(FAILED, "Parse weights ConvertNetParameter failed, status=%d", status); + return domi::PARSE_WEIGHTS_FAILED;); + + return SUCCESS; +} + +Status CaffeWeightsParser::Parse(const char *file, ge::Graph &graph) { + GE_CHECK_NOTNULL(file); + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + Status ret = Parse(file, compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parser weight for graph %s failed.", graph.GetName().c_str()); + return ret; + } + + GELOGI("Parser weight for graph %s success.", graph.GetName().c_str()); + return SUCCESS; +} + +Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) { + if (file == nullptr) { + GELOGE(FAILED, "Caffe weights parse fail, Parameter file invalid"); + return PARAM_INVALID; + } + + if (graph == nullptr) { + GELOGE(FAILED, "Caffe weights parse fail, Parameter graph invalid"); + return PARAM_INVALID; + } + + GELOGI("Parse weights file:%s", file); + + string caffe_proto_path = ge::GetParserContext().caffe_proto_path + "caffe.proto"; + string custom_proto_path = ge::GetParserContext().custom_proto_path + "custom.proto"; + ProtoFileParser proto_file_parser; + + GELOGD("caffe_proto_path:%s custom_proto_path:%s", caffe_proto_path.c_str(), custom_proto_path.c_str()); + string fusion_proto_file; + string custom_proto_file = ge::parser::RealPath(custom_proto_path.c_str()); + if (custom_proto_file.empty()) { + GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str()); + fusion_proto_file = caffe_proto_path; + } else { + if (proto_file_parser.CombineProtoFile(caffe_proto_path.c_str(), custom_proto_path.c_str(),\ + fusion_proto_file) != SUCCESS) { + GELOGE(FAILED, "Create tmp fusion proto file from caffe and custom proto failed."); + return FAILED; + } + } + + string fusion_proto_path = ge::parser::RealPath(fusion_proto_file.c_str()); + GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str()); + if (fusion_proto_path.empty()) { + GELOGE(FAILED, "Fusion proto file path [%s]-[%s] is not real existed.", fusion_proto_file.c_str(), + fusion_proto_path.c_str()); + return 
FAILED; + } + + string fusion_proto_name; + if (CheckPathValid(file, fusion_proto_file, fusion_proto_path, fusion_proto_name) != SUCCESS) { + GELOGE(FAILED, "CheckPathValid of weight file[%s] and tmp proto[%s] failed.", file, + fusion_proto_file.c_str()); + return FAILED; + } + + GELOGI("Start to parse weight: %s by fusion proto: %s.", file, fusion_proto_file.c_str()); + Status status = ParseWeightByFusionProto(file, fusion_proto_path, fusion_proto_name, graph); + if (status != SUCCESS) { + GELOGE(FAILED, "Parse weight by fusion proto failed."); + return status; + } + + status = CheckNodes(graph); + if (status != SUCCESS) { + GELOGE(ge::GRAPH_FAILED, "Check Nodes failed, status=%u", status); + return domi::PARSE_WEIGHTS_FAILED; + } + + return SUCCESS; +} + +Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, const string &fusion_proto_path, + const string &fusion_proto_name, ge::ComputeGraphPtr &graph) { + google::protobuf::compiler::DiskSourceTree source_tree; + source_tree.MapPath(kProjectRoot, fusion_proto_path); + google::protobuf::compiler::Importer importer(&source_tree, nullptr); + importer.Import(fusion_proto_name.c_str()); + GELOGI("Import fusion proto %s success, proto_name %s.", fusion_proto_path.c_str(), fusion_proto_name.c_str()); + + const google::protobuf::Descriptor *descriptor = importer.pool()->FindMessageTypeByName(kBeginningMessageType); + if (descriptor == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"Does not find domi.caffe.NetParameter in google::protobuf::Descriptor."}); + GELOGE(FAILED, "Does not find domi.caffe.NetParameter in google::protobuf::Descriptor, \ + which may be caused by problematic fusion proto."); + return FAILED; + } + google::protobuf::DynamicMessageFactory factory; + const google::protobuf::Message *proto = factory.GetPrototype(descriptor); + GE_CHECK_NOTNULL(proto); + google::protobuf::Message *message = proto->New(); + GE_CHECK_NOTNULL(message); + + if (!ge::parser::ReadProtoFromBinaryFile(weight_path, message)) { + delete message; + message = nullptr; + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"ReadProtoFromBinaryFile based on fusion proto failed."}); + GELOGE(FAILED, "ReadProtoFromBinaryFile %s failed.", weight_path); + return FAILED; + } + + GELOGI("Start to parse weight file: %s.", weight_path); + const google::protobuf::Descriptor *layer_descriptor = importer.pool()->FindMessageTypeByName(kLayerMessageType); + if (layer_descriptor == nullptr) { + delete message; + message = nullptr; + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"Does not find domi.caffe.LayerParameter in google::protobuf::Descriptor"}); + GELOGE(FAILED, "Does not find domi.caffe.LayerParameter in google::protobuf::Descriptor"); + return FAILED; + } + + if (CheckLayersSize(message) != SUCCESS) { + delete message; + message = nullptr; + return FAILED; + } + + if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { + delete message; + message = nullptr; + ErrorManager::GetInstance().ATCReportErrMessage( + "E19021", {"reason"}, {"ParseLayerParameter failed."}); + GELOGE(FAILED, "ParseLayerParameter failed."); + return FAILED; + } + + delete message; + message = nullptr; + GELOGI("Parse weight: %s by proto: %s success.", weight_path, fusion_proto_path.c_str()); + return SUCCESS; +} + +Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, + const google::protobuf::Message *message, + 
ge::ComputeGraphPtr &graph) { + auto field_name = layer_descriptor->FindFieldByName(kFieldName); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); + auto field_type = layer_descriptor->FindFieldByName(kFieldType); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); + + const google::protobuf::Reflection *reflection = message->GetReflection(); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); + vector field_desc; + reflection->ListFields(*message, &field_desc); + + NetParameter tmp_net; + for (auto &field : field_desc) { + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); + // Only care about layers + GE_CHECK_NOTNULL(field); + if (field->name() != kLayerName) { + continue; + } + if (!field->is_repeated()) { + ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"}, + {field->name().c_str(), "LayerParameter should be repeated"}); + GELOGE(FAILED, "LayerParameter should be repeated."); + return FAILED; + } + + int field_size = reflection->FieldSize(*message, field); + GELOGI("Total Layer num of model file is %d", field_size); + for (int i = 0; i < field_size; ++i) { + const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); + + LayerParameter *layer = tmp_net.add_layer(); + if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { + GELOGE(FAILED, "Convert message to layer proto failed."); + return FAILED; + } + + const string &layer_name = layer->name(); + if (skiped_layer_type_.find(layer->type()) != skiped_layer_type_.end()) { + GELOGI("Skip layer %s", layer_name.c_str()); + continue; + } + + GELOGI("Parse layer %s", layer_name.c_str()); + auto ret = ConvertLayerParameter(layer, graph); + if (ret != SUCCESS) { + return ret; + } + } + } + return SUCCESS; +} + +Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, + google::protobuf::Message *layer) { + const google::protobuf::Reflection *layer_reflection = message->GetReflection(); + CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); + vector field_desc; + layer_reflection->ListFields(*message, &field_desc); + + for (auto &field : field_desc) { + GE_CHECK_NOTNULL(field); + if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { + GELOGE(FAILED, "Parse field %s failed.", field->name().c_str()); + return FAILED; + } + } + return SUCCESS; +} + +Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, + google::protobuf::Message *layer) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + domi::caffe::LayerParameter *layer_proto = reinterpret_cast(layer); + string filed_name = field->name(); +#define CASE_FIELD_NAME(kName, method) \ + if (filed_name == kField##kName) { \ + string value = reflection->GetString(*message, field); \ + GELOGD("Parse result(%s : %s)", filed_name.c_str(), value.c_str());\ + layer_proto->set_##method(value); \ + return SUCCESS; \ + } + CASE_FIELD_NAME(Name, name); + CASE_FIELD_NAME(Type, type); +#undef CASE_FIELD_NAME +#define CASE_FIELD_NAME_REPEATED(kName, method) \ + if (filed_name == kField##kName) { \ + int field_size = reflection->FieldSize(*message, field); \ + for (int i = 0; i < 
+      string value = reflection->GetRepeatedString(*message, field, i);  \
+      layer_proto->add_##method(value);                                  \
+    }                                                                    \
+    return SUCCESS;                                                      \
+  }
+  CASE_FIELD_NAME_REPEATED(Bottom, bottom);
+  CASE_FIELD_NAME_REPEATED(Top, top);
+#undef CASE_FIELD_NAME_REPEATED
+  if (filed_name == kFieldBlobs) {
+    int field_size = reflection->FieldSize(*message, field);
+    for (int i = 0; i < field_size; ++i) {
+      BlobProto *item_message = layer_proto->add_blobs();
+      const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i);
+      if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) {
+        GELOGE(FAILED, "ParseLayerField of field: %s failed.", field->name().c_str());
+        return FAILED;
+      }
+    }
+    return SUCCESS;
+  }
+  if (filed_name == kFieldConvParam) {
+    const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
+    ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param();
+    ConvertConvParamProto(&sub_message, conv_param);
+  }
+  if (filed_name == kFieldInnerPro) {
+    const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
+    InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param();
+    ConvertInnerProdcutProto(&sub_message, inner_product);
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message,
+                                             google::protobuf::Message *blobs) {
+  const google::protobuf::Reflection *blobs_reflection = message->GetReflection();
+  CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message");
+  vector<const google::protobuf::FieldDescriptor *> field_desc;
+  blobs_reflection->ListFields(*message, &field_desc);
+
+  domi::caffe::BlobProto *blobs_proto = reinterpret_cast<domi::caffe::BlobProto *>(blobs);
+
+  for (auto &field : field_desc) {
+    GE_CHECK_NOTNULL(field);
+    string feild_name = field->name();
+#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name)            \
+    if (feild_name == #kName) {                                                    \
+      int field_size = blobs_reflection->FieldSize(*message, field);               \
+      for (int i = 0; i < field_size; ++i) {                                       \
+        valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \
+        blobs_proto->add_##name(value);                                            \
+      }                                                                            \
+      continue;                                                                    \
+    }
+    CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data);
+    CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff);
+    CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data);
+    CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff);
+    CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data);
+    CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data);
+#undef CASE_BLOBS_FIELD_NAME_REPEATED
+#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name)                     \
+    if (feild_name == #kName) {                                                    \
+      valuetype value = blobs_reflection->Get##method(*message, field);            \
+      blobs_proto->set_##name(value);                                              \
+      continue;                                                                    \
+    }
+    CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data);
+    CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num);
+    CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels);
+    CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height);
+    CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width);
+#undef CASE_BLOBS_FIELD_NAME
+    if (feild_name == kFieldShape) {
+      const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field);
+      domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape();
+      ConvertBlobShapeProto(&sub_message, blob_shape);
+    }
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message,
+                                                 google::protobuf::Message *dest_message) {
+  const google::protobuf::Reflection *reflection = message->GetReflection();
+  CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
+  vector<const google::protobuf::FieldDescriptor *> field_desc;
+  reflection->ListFields(*message, &field_desc);
+
+  domi::caffe::BlobShape *shape_proto = reinterpret_cast<domi::caffe::BlobShape *>(dest_message);
+
+  for (auto &field : field_desc) {
+    if (field->name() != kFieldDim) {
+      continue;
+    }
+    int field_size = reflection->FieldSize(*message, field);
+    for (int i = 0; i < field_size; ++i) {
+      int64_t value = reflection->GetRepeatedInt64(*message, field, i);
+      shape_proto->add_dim(value);
+    }
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message,
+                                                 google::protobuf::Message *dest_message) {
+  const google::protobuf::Reflection *reflection = message->GetReflection();
+  CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
+  vector<const google::protobuf::FieldDescriptor *> field_desc;
+  reflection->ListFields(*message, &field_desc);
+
+  domi::caffe::ConvolutionParameter *conv_param_proto =
+      reinterpret_cast<domi::caffe::ConvolutionParameter *>(dest_message);
+
+  for (auto &field : field_desc) {
+    if (field->name() != kFieldBiasTerm) {
+      continue;
+    }
+    bool value = reflection->GetBool(*message, field);
+    conv_param_proto->set_bias_term(value);
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message,
+                                                    google::protobuf::Message *dest_message) {
+  const google::protobuf::Reflection *reflection = message->GetReflection();
+  CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
+  vector<const google::protobuf::FieldDescriptor *> field_desc;
+  reflection->ListFields(*message, &field_desc);
+
+  domi::caffe::InnerProductParameter *inner_product_proto =
+      reinterpret_cast<domi::caffe::InnerProductParameter *>(dest_message);
+
+  for (auto &field : field_desc) {
+    if (field->name() != kFieldBiasTerm) {
+      continue;
+    }
+    bool value = reflection->GetBool(*message, field);
+    inner_product_proto->set_bias_term(value);
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) {
+  const google::protobuf::Reflection *reflection = message->GetReflection();
+  CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message");
+  vector<const google::protobuf::FieldDescriptor *> field_desc;
+  reflection->ListFields(*message, &field_desc);
+
+  int num_layer = 0;
+  int num_layers = 0;
+
+  for (auto &field : field_desc) {
+    CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message");
+    // Only care about layers
+    if (field->name() != kLayerName && field->name() != kLayersName) {
+      continue;
+    }
+    if (!field->is_repeated()) {
+      ErrorManager::GetInstance().ATCReportErrMessage("E11032", {"name", "reason"},
+                                                      {field->name().c_str(), "LayerParameter should be repeated"});
+      GELOGE(FAILED, "LayerParameter should be repeated.");
+      return FAILED;
+    }
+
+    int field_size = reflection->FieldSize(*message, field);
+    if (field->name() == kLayerName) {
+      num_layer = field_size;
+    } else {
+      num_layers = field_size;
+    }
+  }
+
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0,
+                                 ErrorManager::GetInstance().ATCReportErrMessage("E11023");
+                                 return FAILED,
+                                 "The weight file consists of layers-structure which is deprecated in Caffe "
+                                 "and unsupported in ATC. The \"layers\" should be changed to \"layer\".");
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024");
+                                 return FAILED, "Weight layer num is zero, weight file may be invalid.");
+
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertLayerParameter(const google::protobuf::Message *layer_message,
+                                                 ge::ComputeGraphPtr &graph) {
+  vector<string> need_share_layers;
+  const domi::caffe::LayerParameter *layer = reinterpret_cast<const domi::caffe::LayerParameter *>(layer_message);
+  const string &layer_name = layer->name();
+  const string &layer_type = layer->type();
+  for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) {
+    if (find(p_iter->second.begin(), p_iter->second.end(), layer_name) != p_iter->second.end()) {
+      GELOGI("layer:%s need share weights !", layer_name.c_str());
+      need_share_layers = p_iter->second;
+    }
+  }
+
+  if (need_share_layers.size() == 0) {
+    need_share_layers.push_back(layer_name);
+  }
+
+  for (auto share_iter = need_share_layers.begin(); share_iter != need_share_layers.end(); ++share_iter) {
+    // Find created nodes
+    string layer_name = *share_iter;
+    GE_IF_BOOL_EXEC(layer_name_record_map_.find(layer_name) != layer_name_record_map_.end(),
+                    string temp_layer_name = layer_name;
+                    // duplicate operator modification
+                    layer_name = temp_layer_name + "_same_" + std::to_string(layer_name_record_map_[temp_layer_name]);
+                    // Times accumulation of duplicate operators
+                    layer_name_record_map_[temp_layer_name]++;
+                    // Set the name in proto and layer
+    )
+    ge::NodePtr node = graph->FindNode(layer_name);
+    layer_name_record_map_.insert(std::make_pair(layer_name, kNumOne));
+    if (node == nullptr) {
+      // If there are redundant layers in the weight file, they should be skipped rather than returned with an error.
+      GELOGI("Layer %s not found in graph", layer_name.c_str());
+      continue;
+    }
+
+    // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later.
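+    // Map the Caffe layer type to a GE op type via caffe_op_map; layers whose type has no mapping
+    // are logged and skipped instead of failing the whole weights parse.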
+    auto iter = caffe_op_map.find(layer_type);
+    if (iter == caffe_op_map.end()) {
+      GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer_type.c_str(), layer_name.c_str());
+      continue;
+    }
+    GELOGD("Caffe layer name: %s , layer type: %s.", layer_name.c_str(), layer_type.c_str());
+    string op_type = iter->second;
+
+    // create OpParser
+    std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
+    GE_CHECK_NOTNULL(factory);
+    std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
+
+    GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
+        (op_parser.get() == nullptr),
+        ErrorManager::GetInstance().ATCReportErrMessage("E11025", {"opname", "optype"}, {layer_name, op_type});
+        return FAILED, "Op[%s] create OpParser failed, optype is %s", layer_name.c_str(), op_type.c_str());
+
+    // Parsing weight information through op parser
+    Status status = op_parser->ParseWeights(layer_message, node);
+    GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
+        (status != SUCCESS), ErrorManager::GetInstance().ATCReportErrMessage("E11026", {"opname"}, {layer_name});
+        return status, "Parse op weights for op[%s] failed", layer_name.c_str());
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) {
+  GE_CHECK_NOTNULL(graph);
+  for (const ge::NodePtr &node : graph->GetAllNodes()) {
+    GE_CHECK_NOTNULL(node);
+    auto op_desc = node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    for (const auto &in_anchor_ptr : node->GetAllInDataAnchors()) {
+      if (op_desc->GetType() == ge::parser::DATA || op_desc->GetType() == ge::parser::CONSTANT) {
+        continue;
+      }
+      auto index = in_anchor_ptr->GetIdx();
+      auto input_desc = op_desc->MutableInputDesc(index);
+      if (in_anchor_ptr->GetPeerAnchors().empty() && input_desc != nullptr) {
+        if (layer_name_record_map_.find(node->GetName()) == layer_name_record_map_.end()) {
+          ErrorManager::GetInstance().ATCReportErrMessage("E11029", {"opname"}, {node->GetName()});
+          GELOGE(ge::GRAPH_FAILED, "Op[%s] in model file does not exist in weight file.", node->GetName().c_str());
+          PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID,
+                                                           "Node does not exist in weight file.");
+        } else {
+          ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"},
+                                                          {node->GetName(), std::to_string(in_anchor_ptr->GetIdx())});
+          GELOGE(ge::GRAPH_FAILED, "Op[%s]'s input %d is not linked.", node->GetName().c_str(),
+                 in_anchor_ptr->GetIdx());
+          string check_msg = "input " + to_string(in_anchor_ptr->GetIdx()) + " is not linked in weight file";
+          PreChecker::Instance().RefreshErrorMessageByName(node->GetName(), PreChecker::PARAM_INVALID, check_msg);
+        }
+        return FAILED;
+      }
+    }
+  }
+  return SUCCESS;
+}
+
+Status CaffeWeightsParser::ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph) {
+  GE_CHECK_NOTNULL(graph);
+  int num_layer = param.layer_size();
+  int num_layers = param.layers_size();
+
+  // Operator name and occurrence map, handle duplicate operators
+  std::map<std::string, int32_t> layer_name_map;
+
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(num_layer == 0 && num_layers > 0,
+                                 ErrorManager::GetInstance().ATCReportErrMessage("E11023");
+                                 return FAILED,
+                                 "The weight file consists of layers-structure which is deprecated in Caffe "
+                                 "and unsupported in ATC. The \"layers\" should be changed to \"layer\".");
+  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((num_layer == 0), ErrorManager::GetInstance().ATCReportErrMessage("E11024");
+                                 return FAILED, "weight layer num is zero, weight file may be invalid.");
+
+  for (int i = 0; i < num_layer; ++i) {
+    const LayerParameter &layer = param.layer(i);
+    const string &layer_name = layer.name();
+
+    // Skip some layer types
+    if (skiped_layer_type_.find(layer.type()) != skiped_layer_type_.end()) {
+      GELOGI("Skip layer %s", layer_name.c_str());
+      continue;
+    }
+
+    GELOGI("Parse layer %s", layer_name.c_str());
+
+    vector<string> need_share_layers;
+
+    for (auto p_iter = params_share_map.begin(); p_iter != params_share_map.end(); ++p_iter) {
+      if (find(p_iter->second.begin(), p_iter->second.end(), layer_name) != p_iter->second.end()) {
+        GELOGI("Layer: %s need share weights !", layer_name.c_str());
+        need_share_layers = p_iter->second;
+      }
+    }
+
+    if (need_share_layers.size() == 0) {
+      need_share_layers.push_back(layer_name);
+    }
+
+    for (auto share_iter = need_share_layers.begin(); share_iter != need_share_layers.end(); ++share_iter) {
+      // Find created nodes
+      string layer_name = *share_iter;
+      GE_IF_BOOL_EXEC(layer_name_map.find(layer_name) != layer_name_map.end(), string temp_layer_name = layer_name;
+                      // duplicate operator modification
+                      layer_name = temp_layer_name + "_same_" + std::to_string(layer_name_map[temp_layer_name]);
+                      // Times accumulation of duplicate operators
+                      layer_name_map[temp_layer_name]++;
+                      // Set the name in proto and layer
+      )
+      ge::NodePtr node = graph->FindNode(layer_name);
+      layer_name_map.insert(std::make_pair(layer_name, kNumOne));
+      if (node == nullptr) {
+        // If there are redundant layers in the weight file, they should be skipped rather than returned with an error.
+        GELOGI("Layer %s not found in graph", layer_name.c_str());
+        continue;
+      }
+
+      // The weight processing also needs to judge the duplicate operator, which is reserved here and processed later.
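+      // As in ConvertLayerParameter above, look up the GE op type for this layer and delegate the
+      // actual weight parsing to the OpParser created by OpParserFactory.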
+      auto iter = caffe_op_map.find(layer.type());
+      if (iter == caffe_op_map.end()) {
+        GELOGW("Unrecognized layer type %s , layer name: %s, layer ignored.", layer.type().c_str(), layer_name.c_str());
+        continue;
+      }
+      GELOGD("Caffe layer name: %s , layer type: %s.", layer_name.c_str(), layer.type().c_str());
+      string op_type = iter->second;
+
+      // create OpParser
+      std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::CAFFE);
+      GE_CHECK_NOTNULL(factory);
+      std::shared_ptr<OpParser> op_parser = factory->CreateOpParser(op_type);
+
+      GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
+          (op_parser.get() == nullptr),
+          ErrorManager::GetInstance().ATCReportErrMessage("E11025", {"opname", "optype"}, {layer_name, op_type});
+          return FAILED, "Op[%s] create OpParser failed, optype is %s", layer_name.c_str(), op_type.c_str());
+
+      // Parsing weight information through op parser
+      Status status = op_parser->ParseWeights(&layer, node);
+      GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
+          (status != SUCCESS), ErrorManager::GetInstance().ATCReportErrMessage("E11026", {"opname"}, {layer_name});
+          return status, "Parse op weights for op[%s] failed", layer_name.c_str());
+    }
+  }
+
+  return SUCCESS;
+}
+
+Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) {
+  auto netout = graph->FindFirstNodeMatchType(ge::parser::NETOUTPUT);
+  GE_CHECK_NOTNULL(netout);
+  for (const auto &in_anchor : netout->GetAllInDataAnchors()) {
+    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
+    GE_CHECK_NOTNULL(peer_out_data_anchor);
+    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
+    GE_CHECK_NOTNULL(peer_out_data_node);
+    int idx = peer_out_data_anchor->GetIdx();
+    string node_name = peer_out_data_node->GetName();
+    auto layer_iter = layer_tops_map_.find(node_name);
+    if (layer_iter != layer_tops_map_.end()) {
+      ge::GetParserContext().out_top_names.push_back(layer_iter->second[idx]);
+      GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str());
+    } else {
+      GELOGW("The out node [%s] can not find its top.", node_name.c_str());
+    }
+  }
+  return SUCCESS;
+}
+
+Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) {
+  return SUCCESS;
+}
+Status CaffeModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
+                                                domi::GetGraphCallback callback,
+                                                ge::ComputeGraphPtr &graph) {
+  return SUCCESS;
+}
+}  // namespace ge
+
+namespace domi {
+  REGISTER_MODEL_PARSER_CREATOR(CAFFE, ge::CaffeModelParser);
+  REGISTER_WEIGHTS_PARSER_CREATOR(CAFFE, ge::CaffeWeightsParser);
+}
diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h
new file mode 100644
index 0000000..ef3d1f1
--- /dev/null
+++ b/parser/caffe/caffe_parser.h
@@ -0,0 +1,433 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARSER_CAFFE_CAFFE_PARSER_H_
+#define PARSER_CAFFE_CAFFE_PARSER_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "external/graph/operator.h"
+#include "omg/parser/op_parser.h"
+#include "omg/parser/model_parser.h"
+#include "omg/parser/weights_parser.h"
+#include "proto/caffe/caffe.pb.h"
+#include "proto/om.pb.h"
+
+namespace ge {
+using domi::caffe::NetParameter;
+using std::map;
+using std::set;
+using std::string;
+using std::unordered_map;
+using std::vector;
+static std::map<std::string, std::vector<std::string>> params_share_map;
+
+class CaffeModelParser : public domi::ModelParser {
+ public:
+  CaffeModelParser() {}
+  virtual ~CaffeModelParser() {}
+
+  /**
+   * @ingroup domi_omg
+   * @brief Parse the relevant data from the model file and save it to graph
+   * @param [in] file Path of model file
+   * @param [in|out] graph graph for saving model information
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status Parse(const char *file, ge::Graph &graph) override;
+  Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
+
+  /**
+   * @ingroup domi_omg
+   * @brief Convert model files to JSON format
+   * @param [in] model_file Path of model file
+   * @param [out] json_file Converted JSON file path
+   * @return SUCCESS parse successfully
+   * @return others parse failed
+   */
+  Status ToJson(const char *model_file, const char *json_file) override;
+  /**
+   * @ingroup domi_omg
+   * @brief Parse the relevant data from the model file and save it to graph
+   * @param [in] proto input caffe model proto
+   * @param [in|out] graph graph for saving model information
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override;
+  Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, domi::GetGraphCallback callback,
+                                ge::ComputeGraphPtr &graph) override;
+  /*
+   * @ingroup domi_omg
+   * @brief Mapping CAFFE's datatype to GE's datatype
+   * @param [in] type, datatype types of operators in CAFFE networks
+   * @return ge::DataType
+   */
+  ge::DataType ConvertToGeDataType(const uint32_t type) override { return ge::DT_FLOAT; }
+
+  Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override {
+    return domi::SUCCESS;
+  }
+
+ private:
+  Status Parse(const char *file, ge::ComputeGraphPtr &graph);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Add the Layer in the model to the PreChecker
+   * @param [in] net caffe net information
+   * @return SUCCESS build successfully
+   * @return FAILED build failed
+   */
+  Status PreCheck(const domi::caffe::NetParameter &net);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Parsing input related information from model files
+   * @param [in] proto_message caffe net information
+   * @param [in|out] net_input_name Used to store the acquired input name information
+   * @param [in|out] net_input_data Used to store the acquired input data information
+   * @return SUCCESS build successfully
+   * @return FAILED build failed
+   */
+  Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse model by custom proto and save info to operators
+   * @param [in] model_path, file path of model(prototxt file)
+   * @param [in] custom_proto, file path of custom proto
+   * @param [in] caffe_proto, file path of caffe proto
+   * @param [out] operators, operators saving custom info
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto,
+                          std::vector<ge::Operator> &operators);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse model by custom proto and save info to operators
+   * @param [in] model_path, file path of model(prototxt file)
+   * @param [in] custom_proto_path, file path of custom proto
+   * @param [in] custom_proto_name, custom proto name
+   * @param [out] operators, operators saving custom info
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path,
+                                    const string &custom_proto_name, std::vector<ge::Operator> &operators);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse caffe proto file
+   * @param [in] proto_file, file path of caffe proto
+   * @param [out] identifier_op_map, identifier and op map
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status ParseProtoFile(const string &proto_file, std::map<int32_t, std::string> &identifier_op_map);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Save identifier op map info
+   * @param [in] line, line of proto
+   * @param [out] identifier_op_map, identifier and op map
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status SaveIdentifierOpMapInfo(const string &line, std::map<int32_t, std::string> &identifier_op_map);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Get op identifier
+   * @param [in] line, line of proto
+   * @param [out] identifier, identifier of op
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  Status GetIdentifier(const std::string &line, int32_t &identifier);
+  /*
+   * @ingroup domi_omg
+   * @brief Read caffe model and shield google warning
+   * @param [in] model_path, file path of model(prototxt file)
+   * @param [out] message, message saving custom info
+   * @return SUCCESS read file successfully
+   * @return FAILED read file failed
+   */
+  Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Read caffe model and save it to message
+   * @param [in] model_path, file path of model(prototxt file)
+   * @param [out] message, message saving custom info
+   * @return SUCCESS read file successfully
+   * @return FAILED read file failed
+   */
+  Status ReadCaffeModelFromText(const char *model_path, google::protobuf::Message *message);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse layer message and save custom info to operators
+   * @param [in] layer_descriptor, layer description of message
+   * @param [in] message, message of model
+   * @param [out] operators, operators saving custom info
+   * @return SUCCESS parse layer successfully
+   * @return FAILED parse layer failed
+   */
+  Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
+                             const google::protobuf::Message *message, std::vector<ge::Operator> &operators);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Create custom operator by op_name and op_type
+   * @param [in] op_name, name of operator
+   * @param [in] op_type, type of operator
+   * @param [in] message, message of model
+   * @param [in] index, index of field
+   * @param [out] operators, operators saving custom info
+   * @return SUCCESS create operator successfully
+   * @return FAILED create operator failed
+   */
+  Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message,
+                              int index, std::vector<ge::Operator> &operators);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse message and set operator attrs
+   * @param [in] message, message of model
+   * @param [in/out] depth, depth of recursion
+   * @param [out] ops, operator saving custom info
+   * @return SUCCESS parse message successfully
+   * @return FAILED parse message failed
+   */
+  Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse field and set operator attrs
+   * @param [in] reflection, reflection of message
+   * @param [in] message, message of model
+   * @param [in] field, field of message
+   * @param [in/out] depth, depth of recursion
+   * @param [out] ops, operator saving custom info
+   * @return SUCCESS parse field successfully
+   * @return FAILED parse field failed
+   */
+  Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
+                    const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
+
+  /*
+   * @ingroup domi_omg
+   * @brief Parse repeated field and set operator attrs
+   * @param [in] reflection, reflection of message
+   * @param [in] message, message of model
+   * @param [in] field, field of message
+   * @param [in/out] depth, depth of recursion
+   * @param [out] ops, operator saving custom info by vector
+   * @return SUCCESS parse field successfully
+   * @return FAILED parse field failed
+   */
+  Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
+                            const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Add blob information to the bottom_blobs_map and top_blobs_map_
+   * @param [in] layer layer information
+   * @param [in|out] inplace_blob_name_remapping save blob information
+   * @return Status
+   */
+  Status AddBlobsToMap(const domi::caffe::LayerParameter &layer,
+                       std::map<std::string, std::string> &inplace_blob_name_remapping);
+  /**
+   * @ingroup domi_omg
+   * @brief Add node information to graph
+   * @param [in] layer layer information
+   * @param [in|out] graph graph for saving model information
+   * @return SUCCESS add successfully
+   * @return FAILED add failed
+   */
+  Status AddNode(const domi::caffe::LayerParameter &layer, ge::ComputeGraphPtr &graph);
+  /**
+   * @ingroup domi_omg
+   * @brief Add edge information to graph
+   * @param [in|out] graph graph for saving model information
+   * @return SUCCESS add successfully
+   * @return FAILED add failed
+   */
+  Status AddEdges(ge::ComputeGraphPtr &graph);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Add edge information to graph
+   * @param [in|out] graph graph for saving model information
+   * @return SUCCESS add successfully
+   * @return FAILED add failed
+   */
+  Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Check if the current layer is valid
+   * @return true valid
+   * @return false invalid
+   */
+  bool CheckValidLayer(const domi::caffe::LayerParameter &layer);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Check whether the top of the current layer is 'Inplace'
+   * @return true is 'Inplace'
+   * @return false is not 'Inplace'
+   */
+  bool IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Check whether the top of the current layer is user's specified output top
+   * @return true yes
+   * @return false no
+   */
+  bool IsOutputTop(const string &op_name, int32_t index);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Find a layer set with the same param
+   * @param [in] Param name set of each layer
+   * @param [in|out] Layer set of the same param
+   * @return Status
+   */
+  Status FindShareParamLayers(const std::map<std::string, std::vector<std::string>> &);
+
+  Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer);
+
+  Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
+                                   const string &op_type);
+
+  Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph);
+
+  std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index);
+
+  Status GetCustomOp(const domi::caffe::LayerParameter &layer, vector<ge::Operator> &operators);
+
+  bool IsOpAttrEmpty(const ge::Operator &op, const std::string &type);
+
+  Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op,
+                      std::shared_ptr<OpParser> &op_parser);
+
+  Status GetLeafNodeTops(ge::ComputeGraphPtr &graph);
+
+  void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer);
+
+  Status ReorderInput(domi::caffe::NetParameter &net);
+
+  void AddOutputInfoToContext(string layer_name, int32_t top_index);
+
+  Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message);
+
+  std::map<std::string, ge::NodePtr> node_map;
+
+  // key: blob name, value: layer name and index
+  std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> bottom_blobs_map_;
+
+  // key: blob name, value: layer name and index
+  std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> top_blobs_map_;
+
+  std::vector<ge::Operator> custom_operator_;
+  std::map<std::string, std::vector<std::string>> layer_tops_map_;
+};
+
+/**
+ * @ingroup domi_omg
+ * @brief Caffe weight parser
+ */
+class CaffeWeightsParser : public domi::WeightsParser {
+ public:
+  /**
+   * @ingroup domi_omg
+   * @brief Parse weight data from file and save to graph
+   * @param [in] file Path of weight file after training
+   * @param [in|out] graph Save weight information after parsing
+   * @return SUCCESS parse successfully
+   * @return PARAM_INVALID param invalid
+   * @return PARSE_WEIGHTS_FAILED parse failed
+   */
+  Status Parse(const char *file, ge::Graph &graph) override;
+
+  Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
+
+ private:
+  Status CheckNodes(ge::ComputeGraphPtr &graph);
+  /**
+   * @ingroup domi_omg
+   * @brief Convert netparameter to modeldef and save in graph
+   * @param [in] param Caffe network parameters to be converted
+   * @param [in|out] graph Save weight information after parsing
+   * @return SUCCESS parse successfully
+   * @return FAILED parse failed
+   */
+  static Status ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph);
+
+  Status Parse(const char *file, ge::ComputeGraphPtr &graph);
+
+  Status ParseWeightByFusionProto(const char *model_path, const string &custom_proto_path,
+                                  const string &custom_proto_name, ge::ComputeGraphPtr &graph);
+
+  Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
+                             const google::protobuf::Message *message,
+                             ge::ComputeGraphPtr &graph);
+
+  Status ConvertLayerParameter(const google::protobuf::Message *layer_message,
+                               ge::ComputeGraphPtr &graph);
+
+  Status CheckLayersSize(const google::protobuf::Message *message);
+
+  Status ConvertLayerProto(const google::protobuf::Message *message,
+                           google::protobuf::Message *layer);
+
+  Status ParseLayerField(const google::protobuf::Reflection *reflection,
+                         const google::protobuf::Message *message,
+                         const google::protobuf::FieldDescriptor *field,
+                         google::protobuf::Message *layer);
+
+  Status ConvertBlobsProto(const google::protobuf::Message *message,
+                           google::protobuf::Message *blobs);
+
+  Status ConvertBlobShapeProto(const google::protobuf::Message *message,
+                               google::protobuf::Message *dest_message);
+
+  Status ConvertInnerProdcutProto(const google::protobuf::Message *message,
+                                  google::protobuf::Message *dest_message);
+
+  Status ConvertConvParamProto(const google::protobuf::Message *message,
+                               google::protobuf::Message *dest_message);
+  /**
+   * @ingroup domi_omg
+   * @brief Layer types to be ignored in weight resolution
+   */
+  static const set<string> skiped_layer_type_;
+  std::map<std::string, int32_t> layer_name_record_map_;
+};
+}  // namespace ge
+
+#endif  // PARSER_CAFFE_CAFFE_PARSER_H_
diff --git a/parser/caffe/caffe_reshape_parser.cc b/parser/caffe/caffe_reshape_parser.cc
new file mode 100644
index 0000000..c32c2c1
--- /dev/null
+++ b/parser/caffe/caffe_reshape_parser.cc
@@ -0,0 +1,143 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "parser/caffe/caffe_reshape_parser.h"
+#include
+#include "common/debug/log.h"
+#include "common/ge/ge_util.h"
+#include "common/op/op_parser_util.h"
+#include "common/util.h"
+#include "framework/common/debug/ge_log.h"
+#include "graph/utils/graph_utils.h"
+#include "parser/common/op_parser_factory.h"
+#include "framework/omg/parser/parser_types.h"
+#include "proto/om.pb.h"
+
+using namespace ge::parser;
+using domi::CAFFE;
+
+namespace ge {
+namespace {
+const int kAnchorIndexZero = 0;
+const int kAnchorIndexOne = 1;
+}  // namespace
+
+Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
+  GE_CHECK_NOTNULL(op_src);
+  GE_CHECK_NOTNULL(op);
+  const LayerParameter *layer = DOMI_DYNAMIC_CAST<const LayerParameter *>(op_src);
+  if (layer == nullptr) {
+    GELOGE(FAILED, "Reshape Dynamic cast op_src to LayerParameter failed");
+    return FAILED;
+  }
+
+  GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str());
+  const ReshapeParameter &reshape_parameter = layer->reshape_param();
+
+  GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, RESHAPE_AXIS_DEFAULT_VALUE)),
+                  GELOGW("SetInt failed for op %s.", op->GetName().c_str()););  // no need to return
+  GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, RESHAPE_NUM_AXES_DEFAULT_VALUE)),
+                  GELOGW("SetInt failed for op %s.", op->GetName().c_str()););  // no need to return
+
+  if (!reshape_parameter.has_shape()) {
+    GELOGE(FAILED, "Reshape has no shape info, ret fail");
+    return FAILED;
+  }
+  const BlobShape &blob_shape = reshape_parameter.shape();
+  std::vector<int64_t> dims;
+  for (int i = 0; i < blob_shape.dim_size(); i++) {
+    dims.push_back(blob_shape.dim(i));
+  }
+
+  if (reshape_parameter.has_axis()) {
+    GE_LOGW_IF(reshape_parameter.axis() == -1,
+               "axis with -1 may lead to calculation errors when input less than 4 dims.");
+    GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, reshape_parameter.axis())),
+                    GELOGW("SetInt failed for op %s.", op->GetName().c_str()););  // no need to return
+  }
+  if (reshape_parameter.has_num_axes()) {
+    GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, reshape_parameter.num_axes())),
+                    GELOGW("SetInt failed for op %s.", op->GetName().c_str()););  // no need to return
+  }
+  GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetListInt(op, RESHAPE_ATTR_SHAPE, dims)),
+                  GELOGW("SetListInt failed for op %s.", op->GetName().c_str()););  // no need to return
+  return SUCCESS;
+}
+
+Status CaffeReshapeParser::ParseWeights(const Message *op_src, ge::OpDescPtr &op) {
+  (void)op_src;
+  (void)op;
+  return SUCCESS;
+}
+
+Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) {
+  GE_CHECK_NOTNULL(node);
+  auto owner_graph = node->GetOwnerComputeGraph();
+  if (owner_graph == nullptr) {
+    GELOGE(FAILED, "node's graph is empty, name: %s", node->GetName().c_str());
+    return FAILED;
+  }
+  ge::OpDescPtr op = node->GetOpDesc();
+  GE_CHECK_NOTNULL(op);
+  vector<int64_t> attr_shape;
+  GE_IF_BOOL_EXEC(!(ge::AttrUtils::GetListInt(op, RESHAPE_ATTR_SHAPE, attr_shape)),
+                  GELOGW("GetListInt failed for op %s.", op->GetName().c_str()););  // no need to return
+  size_t dims_size = attr_shape.size();
+
+  // construct GeTensorDesc
+  ge::GeTensorDesc const_desc = ge::GeTensorDesc();
+  std::vector<int64_t> shape_vec = {static_cast<int64_t>(dims_size)};
+  ge::GeShape shape(shape_vec);
+  const_desc.Update(shape, ge::FORMAT_NCHW, ge::DT_INT64);
+  ge::graphStatus state = op->UpdateInputDesc(RESHAPE_ATTR_SHAPE, const_desc);
+  if (state != ge::GRAPH_SUCCESS) {
+    GELOGE(FAILED, "Update input_shape desc failed.");
+    return FAILED;
+  }
+
+  // construct GeTensorPtr
+  ge::GeTensorPtr constTensor = ge::MakeShared<ge::GeTensor>();
+  GE_CHECK_NOTNULL(constTensor);
+  constTensor->SetTensorDesc(const_desc);
+
+  std::unique_ptr<int64_t[]> data(new (std::nothrow) int64_t[dims_size]());
+  GE_CHECK_NOTNULL(data);
+  for (size_t i = 0; i < dims_size; ++i) {
+    data[i] = attr_shape[i];
+  }
+  GE_IF_BOOL_EXEC(
+      constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS,
+      GELOGW("SetData failed for GeTensor."););  // no need to return
+
+  // construct const node and add edge
+  auto const_opdesc = ge::OpDescUtils::CreateConstOp(constTensor);
+  GE_CHECK_NOTNULL(const_opdesc);
+  auto const_node = owner_graph->AddNodeFront(const_opdesc);
+  GE_CHECK_NOTNULL(const_node);
+  ge::OutDataAnchorPtr out_archor_ptr = const_node->GetOutDataAnchor(kAnchorIndexZero);
+  GE_CHECK_NOTNULL(out_archor_ptr);
+  ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(kAnchorIndexOne);
+  GE_CHECK_NOTNULL(in_archor_ptr);
+  state = ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr);
+  if (state != ge::GRAPH_SUCCESS) {
+    GELOGE(FAILED, "AddEdge from Node %s to Node %s failed", const_node->GetName().c_str(), node->GetName().c_str());
+    return domi::FAILED;
+  }
+  return SUCCESS;
+}
+
+REGISTER_OP_PARSER_CREATOR(CAFFE, RESHAPE, CaffeReshapeParser);
+}  // namespace ge
diff --git a/parser/caffe/caffe_reshape_parser.h b/parser/caffe/caffe_reshape_parser.h
new file mode 100644
index 0000000..9051f77
--- /dev/null
+++ b/parser/caffe/caffe_reshape_parser.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ +#define PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ + +#include "parser/caffe/caffe_op_parser.h" + +namespace ge { +class CaffeReshapeParser : public CaffeOpParser { + public: + /** + * @ingroup domi_omg + * @brief parse params of the operation + * @param [in] op_src params to be parsed + * @param [out] op_dest params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; + + /** + * @ingroup domi_omg + * @brief parse weight of the operation + * @param [in] op_src params to be parsed + * @param [out] op_dest params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseWeights(const Message *op_src, ge::OpDescPtr &op); + + /** + * @ingroup domi_omg + * @brief add const input node + * @param [in] node to add const input + * @param [out] node after add const input + * @return SUCCESS add const input successfully + * @return FAILED add const input failed + * @author + */ + Status AddConstInput(ge::NodePtr &node) override; +}; +} // namespace ge + +#endif // PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_ diff --git a/parser/caffe/proto/caffe/caffe.proto b/parser/caffe/proto/caffe/caffe.proto new file mode 100644 index 0000000..3f45aae --- /dev/null +++ b/parser/caffe/proto/caffe/caffe.proto @@ -0,0 +1,1821 @@ +syntax = "proto2"; + +package domi.caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + optional bytes int8_data = 10; + repeated int32 int32_data = 11 [packed = true]; + repeated uint64 uint64_data = 12 [packed = true]; + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. 
+ enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. 
+ // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. + // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. 
+ optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. 
+ optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. 
+ optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; + optional ROIPoolingParameter roi_pooling_param = 161; + optional YoloParameter yolo_param = 199; + optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; + optional ProposalParameter proposal_param = 201; + optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; + optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; + optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; + optional 
QuantParameter quant_param = 208; + optional CondTakeParameter condtake_param = 233; + optional MatrixInverseParameter matrix_inverse_param = 210; + optional WarpPerspectiveParameter warp_perspective_param = 234; + optional BatchMatMulParameter batch_matmul_param = 235; + optional SpatialTransformerParameter st_param = 5000; + optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. + enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. 
+ optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. + optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. 
+ optional FillerParameter filler = 3; + optional bool bias_from_blob = 4 [default = true]; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. + optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. 
The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. 
+ repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). 
Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. 
+ // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. 
+ // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) + // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... 
} + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. 
+ optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; + optional bool scale_from_blob = 6 [default = true]; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer slices blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if |x| < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. 
+ optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + QUANT = 208; + DEQUANT = 209; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional 
PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. 
+ optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope parameters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} + +// Message that stores parameters used by DetectionOutputLayer +//message DetectionOutputParameter { +// optional int32 num_classes = 1 [default = 21]; +// optional float nms_threshold = 2 [default = 0.3]; +// optional int32 top_k = 3; +// optional float confidence_threshold = 4 [default = 0.8]; +//} + +// Message that store parameters used by PriorBoxLayer +message PriorBoxParameter { + // Encode/decode type. + enum CodeType { + CORNER = 1; + CENTER_SIZE = 2; + CORNER_SIZE = 3; + } + // Minimum box size (in pixels). Required! + repeated float min_size = 1; + // Maximum box size (in pixels). Required! + repeated float max_size = 2; + // Various of aspect ratios. Duplicate ratios will be ignored. + // If none is provided, we use default ratio 1. + repeated float aspect_ratio = 3; + // If true, will flip each aspect ratio. + // For example, if there is aspect ratio "r", + // we will generate aspect ratio "1.0/r" as well. + optional bool flip = 4 [default = true]; + // If true, will clip the prior so that it is within [0, 1] + optional bool clip = 5 [default = false]; + // Variance for adjusting the prior bboxes. + repeated float variance = 6; + // By default, we calculate img_height, img_width, step_x, step_y based on + // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely + // provided. + // Explicitly provide the img_size. + optional uint32 img_size = 7; + // Either img_size or img_h/img_w should be specified; not both. 
+ optional uint32 img_h = 8; + optional uint32 img_w = 9; + + // Explicitly provide the step size. + optional float step = 10; + // Either step or step_h/step_w should be specified; not both. + optional float step_h = 11; + optional float step_w = 12; + + // Offset to the top left corner of each cell. + optional float offset = 13 [default = 0.5]; +} + +// Message that stores parameters used by PermuteLayer +message PermuteParameter { + // The new orders of the axes of data. Notice it should be within + // the same range as the input data, and it starts from 0. + // Do not provide repeated order. + repeated uint32 order = 1; +} + +message NormalizeParameter { + optional bool across_spatial = 1 [default = true]; + // Initial value of scale. Default is 1.0 for all + optional FillerParameter scale_filler = 2; + // Whether or not scale parameters are shared across channels. + optional bool channel_shared = 3 [default = true]; + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 4 [default = 1e-10]; +} + +// needed by ssd +message SaveOutputParameter { + // Output directory. If not empty, we will save the results. + optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, you must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with the same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. 
+ optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that stores parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; + optional float objectness_score = 22; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
+ optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; + optional int32 roi_end_mode = 5 [default = 0]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; // Category of classification + optional uint32 coords = 2 [default = 4]; // Coordinates of box + optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + repeated int32 axis = 1; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional float scale = 1[default = 1]; + optional int32 stride = 2[default = 2]; + optional int32 stride_h = 3[default = 2]; + optional int32 stride_w = 4[default=2]; +} +message ROIPoolingParameter { + required int32 pooled_h = 1; + required int32 pooled_w = 2; + optional float spatial_scale = 3 [default=0.0625]; + optional float spatial_scale_h = 4; + optional float spatial_scale_w = 5; +} + +message YoloParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 coords = 2 [default = 4]; + optional int32 classes = 3 [default = 80]; + optional string yolo_version = 4 [default = "V3"]; + optional bool softmax = 5 [default = false]; + optional bool background = 6 [default = false]; + optional bool softmaxtree = 7 [default = false]; +} + +message YoloV3DetectionOutputParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; +} + +message YoloV3DetectionOutputV2Parameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 
coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; + optional int32 out_box_dim = 15 [default = 3]; +} + +message ProposalParameter { + optional float feat_stride = 1 [default = 16]; + optional float base_size = 2 [default = 16]; + optional float min_size = 3 [default = 16]; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6 [default = 3000]; + optional int32 post_nms_topn = 7 [default = 304]; + optional float iou_threshold = 8 [default = 0.7]; + optional bool output_actual_rois_num = 9 [default = false]; +} + +message FSRDetectionOutputParameter { + required int32 num_classes = 1; + required float score_threshold = 2; + required float iou_threshold = 3; + optional int32 batch_rois = 4 [default = 1]; +} + +message SSDDetectionOutputParameter { + required int32 num_classes= 1 [default = 2]; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional float iou_threshold = 4 [default = 0.3]; + optional int32 top_k = 5 [default = 200]; + optional float eta = 6 [default = 1.0]; + optional bool variance_encoded_in_target = 7 [default = false]; + optional int32 code_type = 8 [default = 1]; + optional int32 keep_top_k = 9 [default = -1]; + optional float confidence_threshold = 10 [default = 0.0]; +} +message YoloV2DetectionOutputParameter { + optional int32 boxes = 1 [default = 5]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases = 9; + optional int32 coords = 10 [default = 4]; + optional bool resize_origin_img_to_net = 11 [default = false]; +} + +message QuantParameter { + optional float scale = 2; + optional bytes offset = 3; +} + +message BatchMatMulParameter{ + optional bool adj_x1 = 1 [default = false]; + optional bool adj_x2 = 2 [default = false]; +} + +message CondTakeParameter { + required string mode = 1; + required float val = 2; + optional float eps = 3 [default = 1e-06]; +} + +message MatrixInverseParameter { + optional bool adjoint = 1 [default = false]; +} + +message WarpPerspectiveParameter { + required int32 out_height = 1; + required int32 out_width = 2; + optional float constant = 3; + optional string border_type = 4 [default = 'BORDER_CONSTANT']; +} + +message SpatialTransformerParameter { + // How to use the parameter passed by localisation network + optional string transform_type = 1 [default = "affine"]; + // What is the sampling technique + optional string sampler_type = 2 [default = "bilinear"]; + + // If not set,stay same with the input dimension H and W + optional int32 output_H = 3; + optional int32 output_W = 4; + // If false, only compute dTheta, DO NOT compute dU + optional bool to_compute_dU = 5 [default = true]; + + // The default value for some parameters + optional double theta_1_1 = 6; + optional double theta_1_2 = 7; + optional double theta_1_3 = 8; + optional double theta_2_1 = 9; + optional double theta_2_2 = 10; + optional double theta_2_3 = 11; +} diff --git a/parser/caffe/proto/ge_ir.proto b/parser/caffe/proto/ge_ir.proto new file mode 100644 index 0000000..e7bfe0c --- /dev/null +++ b/parser/caffe/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package 
ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
+message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto verion + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef + + map attr = 11; // Extended field +} + diff --git a/parser/caffe/proto/om.proto b/parser/caffe/proto/om.proto new file mode 100644 index 0000000..e15e5f8 --- /dev/null +++ b/parser/caffe/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + 
bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. 
+ CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/parser/caffe/proto/task.proto b/parser/caffe/proto/task.proto new file mode 100644 index 0000000..d0c0984 --- /dev/null +++ b/parser/caffe/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt new file mode 100644 index 0000000..77c831b --- /dev/null +++ b/parser/common/CMakeLists.txt @@ -0,0 +1,76 @@ +set(SRC_LIST + "parser_factory.cc" + "data_op_parser.cc" + "op_parser_factory.cc" 
+ "pre_checker.cc" + "register_tbe.cc" + "parser_api.cc" + "parser_inner_ctx.cc" + "proto_file_parser.cc" + "acl_graph_parser_util.cc" + "tbe_plugin_loader.cc" + "model_saver.cc" + "../tensorflow/tensorflow_custom_parser_adapter.cc" + "../tensorflow/tensorflow_fusion_custom_parser_adapter.cc" + "../tensorflow/tensorflow_fusion_op_parser.cc" + "../tensorflow/tensorflow_util.cc" + "convert/pb2json.cc" + "op_def/ir_pb_converter.cc" + "op_def/defs.cc" + "op_def/op_schema.cc" + "op_def/operator.cc" + "op_map.cc" + "parser_types.cc" + "pass_manager.cc" + "parser_fp16_t.cc" + "thread_pool.cc" +) + +############ libparser_common.so ############ +add_library(parser_common SHARED ${SRC_LIST}) + +target_compile_options(parser_common PRIVATE + -Werror +) + +target_compile_definitions(parser_common PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 +) + +target_include_directories(parser_common PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${TOP_DIR}/framework/domi + ${TOP_DIR}/framework/domi/common + ${TOP_DIR}/framework/domi/parser + ${TOP_DIR}/inc + ${TOP_DIR}/inc/common/util + ${TOP_DIR}/inc/external + ${TOP_DIR}/inc/external/graph + ${TOP_DIR}/inc/framework + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge +) + +target_link_libraries(parser_common PRIVATE + $ + -Wl,--no-as-needed + graph + protobuf + register + c_sec + slog + mmpa + error_manager + -Wl,--as-needed + json + -lrt + -ldl +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS parser_common OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc new file mode 100644 index 0000000..0a16d38 --- /dev/null +++ b/parser/common/acl_graph_parser_util.cc @@ -0,0 +1,492 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#include "parser/common/acl_graph_parser_util.h" + +#include +#include +#include +#include +#include + +#include "common/string_util.h" +#include "common/debug/log.h" +#include "common/op/ge_op_utils.h" +#include "ge/ge_api_types.h" +#include "graph/opsproto_manager.h" +#include "omg/parser/parser_inner_ctx.h" +#include "tbe_plugin_loader.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/register_tbe.h" +#include "framework/omg/parser/parser_types.h" +#include "common/util/error_manager/error_manager.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::FileInputStream; +using google::protobuf::io::ZeroCopyInputStream; +using namespace ge::parser; + +namespace { +/// The maximum length of the file. +/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 +const int kMaxFileSizeLimit = INT_MAX; +const int kMaxBuffSize = 256; +const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. 
+const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M + +static string GetSoPath() { + Dl_info dl_info; + if (dladdr(reinterpret_cast(&GetSoPath), &dl_info) == 0) { + GELOGW("Failed to read so_path!"); + return string(); + } else { + std::string so_path = dl_info.dli_fname; + char path[PATH_MAX] = {0}; + if (so_path.length() >= PATH_MAX) { + GELOGW("File path is too long!"); + return string(); + } + if (realpath(so_path.c_str(), path) == nullptr) { + GELOGW("Failed to get realpath of %s", so_path.c_str()); + return string(); + } + + so_path = path; + so_path = so_path.substr(0, so_path.rfind('/') + 1); + return so_path; + } +} + +static void GetOpsProtoPath(string &opsproto_path) { + GELOGD("Start to get ops proto path schedule."); + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + string path = path_env; + string file_path = ge::parser::RealPath(path.c_str()); + if (file_path.empty()) { + GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str()); + return; + } + opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); + GELOGI("Get opsproto so path from env : %s", path.c_str()); + return; + } + string path_base = GetSoPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); +} +} // namespace + +namespace ge { +domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, + std::vector> &output_nodes_info) { + ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); + if (tmpDescPtr == nullptr) { + GELOGE(domi::FAILED, "Get outnode op desc fail."); + return domi::FAILED; + } + size_t size = tmpDescPtr->GetOutputsSize(); + if (node->GetType() != NETOUTPUT) { + for (size_t index = 0; index < size; ++index) { + output_nodes_info.push_back(std::make_pair(node, index)); + } + } else { + const auto in_anchors = node->GetAllInDataAnchors(); + for (auto in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(domi::FAILED, "Get leaf node op desc fail."); + return domi::FAILED; + } + auto out_node = out_anchor->GetOwnerNode(); + output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + } + } + return SUCCESS; +} + +void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { + output_nodes_name.clear(); + if (ge::GetParserContext().out_top_names.empty()) { + // tf process, no top name. 
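+    // Each output is recorded as "<node_name>:<output_index>", e.g. "conv1:0" (node name here is illustrative).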
+ for (const auto output_node_info : output_nodes_info) { + std::string node_name = output_node_info.first->GetName(); + int32_t index = output_node_info.second; + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + return; + } + // caffe process reserved place; +} + +domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) { + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + if (compute_graph == nullptr) { + GELOGE(FAILED, "compute_graph is nullptr."); + return FAILED; + } + + std::vector> output_nodes_info; + std::vector output_nodes_name; + + for (ge::NodePtr node : compute_graph->GetDirectNode()) { + if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) { + Status ret = AclGrphParseUtil::GetOutputLeaf(node, output_nodes_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "find leaf fail."); + return FAILED; + } + } + } + + AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + compute_graph->SetGraphOutNodesInfo(output_nodes_info); + ge::GetParserContext().net_out_nodes = output_nodes_name; + GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); + return SUCCESS; +} + +domi::Status AclGrphParseUtil::LoadOpsProtoLib() { + string opsproto_path; + GetOpsProtoPath(opsproto_path); + GELOGI("Get opsproto path is %s", opsproto_path.c_str()); + OpsProtoManager *manager = OpsProtoManager::Instance(); + map option_tmp; + option_tmp.emplace(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); + bool is_proto_init = manager->Initialize(option_tmp); + if (!is_proto_init) { + GELOGE(FAILED, "Load ops_proto lib failed, ops proto path is invalid."); + return FAILED; + } + return SUCCESS; +} + +void AclGrphParseUtil::SaveCustomCaffeProtoPath() { + GELOGD("Enter save custom caffe proto path."); + std::string path_base = GetSoPath(); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; + + string custom_op_path; + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + custom_op_path = path + "/framework/custom/caffe/"; + GELOGI("Get custom proto path from env : %s", path_env); + GetParserContext().custom_proto_path = custom_op_path; + return; + } + custom_op_path = path_base + "ops/framework/custom/caffe/"; + ge::GetParserContext().custom_proto_path = custom_op_path; + return; +} + +// Initialize PARSER, load custom op plugin +// options will be used later for parser decoupling +domi::Status AclGrphParseUtil::AclParserInitialize(const std::map &options) { + GELOGT(TRACE_INIT, "AclParserInitialize start"); + // check init status + if (parser_initialized) { + GELOGW("AclParserInitialize is called more than once"); + return SUCCESS; + } + + // load custom op plugin + TBEPluginLoader::Instance().LoadPluginSo(options); + + // load and save custom op proto for prediction + (void)LoadOpsProtoLib(); + SaveCustomCaffeProtoPath(); + + auto op_registry = domi::OpRegistry::Instance(); + if (op_registry == nullptr) { + GELOGE(FAILED, "Get OpRegistry instance failed"); + return FAILED; + } + + std::vector registrationDatas = op_registry->registrationDatas; + GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); + for (OpRegistrationData ®_data : registrationDatas) { + (void)OpRegistrationTbe::Instance()->Finalize(reg_data, false); + 
domi::OpRegistry::Instance()->Register(reg_data); + } + + // set init status + if (!parser_initialized) { + // Initialize success, first time calling initialize + parser_initialized = true; + } + + GELOGT(TRACE_STOP, "AclParserInitialize finished"); + return SUCCESS; +} +namespace parser { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { + if (path == nullptr) { + GELOGE(ge::FAILED, "path pointer is NULL."); + return ""; + } + if (strlen(path) >= PATH_MAX) { + ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); + GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX); + return ""; + } + // Nullptr is returned when the path does not exist or there is no permission + // Return absolute path when path is accessible + std::string res; + char resolved_path[PATH_MAX] = {0}; + if (realpath(path, resolved_path) != nullptr) { + res = resolved_path; + } + + return res; +} + +// Get file length +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null."); + + std::string real_path = RealPath(input_file.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); + unsigned long long file_length = 0; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, + {input_file, strerror(errno)}); + return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), + ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); + return -1, "File[%s] size is 0, not valid.", input_file.c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, + ErrorManager::GetInstance().ATCReportErrMessage( + "E19016", {"filepath", "filesize", "maxlen"}, + {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); + return -1, "File[%s] size %lld is out of limit: %d.", + input_file.c_str(), file_length, kMaxFileSizeLimit); + return static_cast(file_length); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { + struct timeval tv{}; + int ret = gettimeofday(&tv, nullptr); + GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); + auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds + return static_cast(total_use_time); +} + +static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, + return false, "incorrect parameter. 
nullptr == proto"); + + coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold); + return proto->ParseFromCodedStream(&coded_stream); +} + +/** @ingroup domi_common + * @brief Read all data from binary file + * @param [in] file_name File path + * @param [out] buffer The address of the output memory, which needs to be released by the caller + * @param [out] length Output memory size + * @return false fail + * @return true success + */ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, + int &length) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr"); + + std::string real_path = RealPath(file_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name); + + std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); + if (!file.is_open()) { + GELOGE(ge::FAILED, "Read file %s failed.", file_name); + return false; + } + + length = static_cast(file.tellg()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0"); + + file.seekg(0, std::ios::beg); + + *buffer = new(std::nothrow) char[length](); + GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed."); + + file.read(*buffer, length); + file.close(); + return true; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), + return false, + "Input parameter file or proto is nullptr!"); + + std::string real_path = RealPath(file); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), + return false, "pb file path '%s' not valid", file); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); + + std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); + GELOGE(ge::FAILED, "Open real path[%s] failed.", file); + return false; + } + + google::protobuf::io::IstreamInputStream istream(&fs); + google::protobuf::io::CodedInputStream coded_stream(&istream); + + bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); + + fs.close(); + + if (!ret) { + ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); + GELOGE(ge::FAILED, "Parse file[%s] failed.", file); + return ret; + } + + return ret; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, + "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); + + google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast(const_cast(data)), size); + return ReadProtoFromCodedInputStream(coded_stream, proto); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, + google::protobuf::Message *message) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, + "incorrect parameter. 
nullptr == file || nullptr == message"); + + std::string real_path = RealPath(file); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), + ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, + {file, strerror(errno)}); + return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, + strerror(errno)); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); + + std::ifstream fs(real_path.c_str(), std::ifstream::in); + + if (!fs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); + GELOGE(ge::FAILED, + "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); + return false; + } + + google::protobuf::io::IstreamInputStream input(&fs); + bool ret = google::protobuf::TextFormat::Parse(&input, message); + GE_IF_BOOL_EXEC(!ret, + ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); + GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " + "please check whether the file is a valid protobuf format file.", file)); + fs.close(); + + return ret; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, + google::protobuf::Message *message) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, + "incorrect parameter. data is nullptr || message is nullptr"); + std::string str(data, static_cast(size)); + std::istringstream fs(str); + + google::protobuf::io::IstreamInputStream input(&fs); + bool ret = google::protobuf::TextFormat::Parse(&input, message); + GE_IF_BOOL_EXEC( + !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); + + return ret; +} + +/// +/// @brief get the Original Type of FrameworkOp +/// @param [in] node +/// @param [out] type +/// @return Status +/// +Status GetOriginalType(const ge::NodePtr &node, string &type) { + GE_CHECK_NOTNULL(node); + type = node->GetType(); + GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); + GE_CHECK_NOTNULL(node->GetOpDesc()); + bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + if (!ret) { + GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str()); + return INTERNAL_ERROR; + } + GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { + char ebuff[kMaxBuffSize]; + regex_t reg; + int cflags = REG_EXTENDED | REG_NOSUB; + int ret = regcomp(®, mode.c_str(), cflags); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGW("regcomp failed, reason: %s", ebuff); + regfree(®); + return true; + } + + ret = regexec(®, str.c_str(), 0, nullptr, 0); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); + regfree(®); + return false; + } + + regfree(®); + return true; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { + std::time_t now = std::time(nullptr); + std::tm *ptm = std::localtime(&now); + if (ptm == nullptr) { + GELOGE(ge::FAILED, "Localtime failed."); + return ""; + } + + const int kTimeBufferLen = 32; + char buffer[kTimeBufferLen + 1] = {0}; + // format: 20171122042550 + std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); + return std::string(buffer); +} +} // namespace parser +} 
// namespace ge diff --git a/parser/common/acl_graph_parser_util.h b/parser/common/acl_graph_parser_util.h new file mode 100644 index 0000000..fad1c15 --- /dev/null +++ b/parser/common/acl_graph_parser_util.h @@ -0,0 +1,161 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +#ifndef ACL_GRAPH_PARSE_UTIL_ +#define ACL_GRAPH_PARSE_UTIL_ + +#include +#include +#include +#include + +#include "framework/omg/parser/parser_types.h" +#include "register/register_error_codes.h" +#include "graph/utils/graph_utils.h" + +namespace ge { + +using google::protobuf::Message; + +class AclGrphParseUtil { + public: + AclGrphParseUtil() {} + virtual ~AclGrphParseUtil() {} + domi::Status LoadOpsProtoLib(); + void SaveCustomCaffeProtoPath(); + domi::Status AclParserInitialize(const std::map &options); + domi::Status SetDefaultOutputNode(ge::Graph &graph); + + private: + bool parser_initialized = false; + domi::Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info); + void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name); +}; + +namespace parser { +/// +/// @ingroup: domi_common +/// @brief: get length of file +/// @param [in] input_file: path of file +/// @return long: File length. If the file length fails to be obtained, the value -1 is returned. +/// +extern long GetFileLength(const std::string &input_file); + +/// +/// @ingroup domi_common +/// @brief Absolute path for obtaining files. +/// @param [in] path of input file +/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned +/// +std::string RealPath(const char *path); + +/// +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in microseconds (US) +/// +/// +uint64_t GetCurrentTimestamp(); + +/// +/// @ingroup domi_common +/// @brief Reads all data from a binary file. +/// @param [in] file_name path of file +/// @param [out] buffer Output memory address, which needs to be released by the caller. +/// @param [out] length Output memory size +/// @return false fail +/// @return true success +/// +bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); + +/// +/// @ingroup domi_common +/// @brief proto file in bianary format +/// @param [in] file path of proto file +/// @param [out] proto memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromBinaryFile(const char *file, Message *proto); + +/// +/// @ingroup domi_common +/// @brief Reads the proto structure from an array. +/// @param [in] data proto data to be read +/// @param [in] size proto data size +/// @param [out] proto Memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromArray(const void *data, int size, Message *proto); + +/// +/// @ingroup domi_proto +/// @brief Reads the proto file in the text format. 
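+/// Minimal usage sketch (assumes a text-format prototxt and a message type generated
+/// from the protos in this repo, e.g. ge::proto::ModelDef):
+///   ge::proto::ModelDef model;
+///   bool ok = ge::parser::ReadProtoFromText("/path/to/model.pbtxt", &model);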
+/// @param [in] file path of proto file +/// @param [out] message Memory for storing the proto file +/// @return true success +/// @return false fail +/// +bool ReadProtoFromText(const char *file, google::protobuf::Message *message); + +bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); + +/// +/// @brief get the Original Type of FrameworkOp +/// @param [in] node +/// @param [out] type +/// @return Status +/// +domi::Status GetOriginalType(const ge::NodePtr &node, string &type); + +/// +/// @ingroup domi_common +/// @brief Check whether the file path meets the whitelist verification requirements. +/// @param [in] filePath file path +/// @param [out] result +/// +bool ValidateStr(const std::string &filePath, const std::string &mode); + +/// +/// @ingroup domi_common +/// @brief Obtains the current time string. +/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 +/// +std::string CurrentTimeInStr(); +} // namespace parser +} // namespace ge + +/*lint --emacro((773),GE_TIMESTAMP_START)*/ +/*lint -esym(773,GE_TIMESTAMP_START)*/ +#define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp() + +#define PARSER_TIMESTAMP_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ + GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#endif // ACL_GRAPH_PARSE_UTIL_ diff --git a/parser/common/convert/pb2json.cc b/parser/common/convert/pb2json.cc new file mode 100644 index 0000000..af13ed2 --- /dev/null +++ b/parser/common/convert/pb2json.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// File: pb2json.h +// Description: This imply file for protobuf message and json interconversion + +#include "common/convert/pb2json.h" +#include +#include +#include "securec.h" +#include "framework/common/fmk_types.h" +#include "framework/common/debug/ge_log.h" + +using std::set; +using std::string; + +namespace ge { +namespace { +const int kSignificantDigits = 10; +} +// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, + const set &black_fields, Json &json, + bool enum2str) { + auto descriptor = message.GetDescriptor(); + auto reflection = message.GetReflection(); + if (descriptor == nullptr || reflection == nullptr) { + return; + } + + auto count = descriptor->field_count(); + + for (auto i = 0; i < count; ++i) { + const auto field = descriptor->field(i); + if (field == nullptr) { + return; + } + + // Do not display weight data + if (black_fields.find(field->name()) != black_fields.end()) { + continue; + } + + if (field->is_repeated()) { + if (reflection->FieldSize(message, field) > 0) { + RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); + } + continue; + } + + if (!reflection->HasField(message, field)) { + continue; + } + + OneField2Json(message, field, reflection, black_fields, json, enum2str); + } +} + +void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const set &black_fields, Json &json, + bool enum2str) { + switch (field->type()) { + case ProtobufFieldDescriptor::TYPE_MESSAGE: { + const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); + if (0 != tmp_message.ByteSize()) { + Message2Json(tmp_message, black_fields, json[field->name()], enum2str); + } + break; + } + + case ProtobufFieldDescriptor::TYPE_BOOL: + json[field->name()] = reflection->GetBool(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_ENUM: { + auto *enum_value_desc = reflection->GetEnum(message, field); + Enum2Json(enum_value_desc, field, enum2str, json); + break; + } + + case ProtobufFieldDescriptor::TYPE_INT32: + case ProtobufFieldDescriptor::TYPE_SINT32: + case ProtobufFieldDescriptor::TYPE_SFIXED32: + json[field->name()] = reflection->GetInt32(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_UINT32: + case ProtobufFieldDescriptor::TYPE_FIXED32: + json[field->name()] = reflection->GetUInt32(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_INT64: + case ProtobufFieldDescriptor::TYPE_SINT64: + case ProtobufFieldDescriptor::TYPE_SFIXED64: + json[field->name()] = reflection->GetInt64(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_UINT64: + case ProtobufFieldDescriptor::TYPE_FIXED64: + json[field->name()] = reflection->GetUInt64(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_FLOAT: + char str[kSignificantDigits]; + if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ + json[field->name()] = str; + } else { + json[field->name()] = reflection->GetFloat(message, field); + } + + break; + + case ProtobufFieldDescriptor::TYPE_STRING: + json[field->name()] = reflection->GetString(message, field); + break; + + case ProtobufFieldDescriptor::TYPE_BYTES: { + string field_name = field->name(); + string type_bytes = reflection->GetString(message, field); + json[field_name] = 
TypeBytes2String(field_name, type_bytes); + break; + } + + default: + break; + } +} + +string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { + if (field_name != "offset") { + return type_bytes; + } + string result = ""; + for (char temp_value : type_bytes) { + uint8_t *value = 0; + value = reinterpret_cast(&temp_value); + char str[kSignificantDigits]; + if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ + GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); + continue; + } + result += str; + } + return result; +} + +void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const set &black_fields, Json &json, + bool enum2str) { + if ((field == nullptr) || (reflection == nullptr)) { + Message2Json(message, black_fields, json, enum2str); + return; + } + + for (auto i = 0; i < reflection->FieldSize(message, field); ++i) { + Json tmp_json; + switch (field->type()) { + case ProtobufFieldDescriptor::TYPE_MESSAGE: { + const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); + if (0 != tmp_message.ByteSize()) { + Message2Json(tmp_message, black_fields, tmp_json, enum2str); + } + } break; + + case ProtobufFieldDescriptor::TYPE_BOOL: + tmp_json = reflection->GetRepeatedBool(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_ENUM: { + auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i); + RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json); + } break; + + case ProtobufFieldDescriptor::TYPE_INT32: + case ProtobufFieldDescriptor::TYPE_SINT32: + case ProtobufFieldDescriptor::TYPE_SFIXED32: + tmp_json = reflection->GetRepeatedInt32(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_UINT32: + case ProtobufFieldDescriptor::TYPE_FIXED32: + tmp_json = reflection->GetRepeatedUInt32(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_INT64: + case ProtobufFieldDescriptor::TYPE_SINT64: + case ProtobufFieldDescriptor::TYPE_SFIXED64: + tmp_json = reflection->GetRepeatedInt64(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_UINT64: + case ProtobufFieldDescriptor::TYPE_FIXED64: + tmp_json = reflection->GetRepeatedUInt64(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_FLOAT: + tmp_json = reflection->GetRepeatedFloat(message, field, i); + break; + + case ProtobufFieldDescriptor::TYPE_STRING: + case ProtobufFieldDescriptor::TYPE_BYTES: + tmp_json = reflection->GetRepeatedString(message, field, i); + break; + + default: + break; + } + json += tmp_json; + } +} + +void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, + bool enum2str, Json &json) { + if (enum_value_desc != nullptr) { + if (field == nullptr) { + return; + } + if (enum2str) { + json[field->name()] = enum_value_desc->name(); + } else { + json[field->name()] = enum_value_desc->number(); + } + } +} + +void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) { + if (enum_value_desc != nullptr) { + if (enum2str) { + json = enum_value_desc->name(); + } else { + json = enum_value_desc->number(); + } + } +} +} // namespace ge diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h new file mode 100644 index 0000000..7bc55b1 --- /dev/null +++ b/parser/common/convert/pb2json.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// File: pb2json.h +// Description: This header file for protobuf message and json interconversion + +#ifndef PARSER_COMMON_CONVERT_PB2JSON_H_ +#define PARSER_COMMON_CONVERT_PB2JSON_H_ +#include +#include +#include +#include +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "nlohmann/json.hpp" + +namespace ge { +using Json = nlohmann::json; +using ProtobufMsg = ::google::protobuf::Message; +using ProtobufReflection = ::google::protobuf::Reflection; +using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; +using ProtobufDescriptor = ::google::protobuf::Descriptor; +using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; + +class Pb2Json { + public: + /** + * @ingroup domi_omg + * @brief Transfer protobuf object to JSON object + * @param [out] json Converted JSON object + * @return void success + * @author + */ + static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, + bool enum2str = false); + + protected: + static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, + Json &json, bool enum2str); + + static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, + bool enum2str, Json &json); + + static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); + + static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, Json &json, + bool enum2str); + + static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); +}; +} // namespace ge + +#endif // PARSER_COMMON_CONVERT_PB2JSON_H_ diff --git a/parser/common/data_op_parser.cc b/parser/common/data_op_parser.cc new file mode 100644 index 0000000..525c7c9 --- /dev/null +++ b/parser/common/data_op_parser.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/common/data_op_parser.h" +#include +#include "common/debug/log.h" +#include "common/op/ge_op_utils.h" +#include "common/math/math_util.h" +#include "common/util.h" +#include "graph/utils/type_utils.h" +#include "omg/omg.h" + +using namespace cce; +namespace { +const int kDataMemAlignSize = 32; +const int kTwoTimesAlign = 2; +const int kDynamicBatchInputSize = -1; +const uint32_t kScalarLength = 1; +} // namespace + +namespace ge { +FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector &shape, ge::OpDescPtr op) { + GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null"); + + const string &data_op_name = op->GetName(); + GetParserContext().input_dims.emplace(data_op_name, shape); + + int64_t attr_type = 0; + ge::DataType data_type; + if (ge::AttrUtils::GetInt(op, ge::DATA_ATTR_NAME_DATA_TYPE, attr_type)) { + data_type = static_cast(attr_type); + } else { + data_type = ge::DT_FLOAT; + } + + // convert input + vector def_format_shape(shape); + + ge::GeTensorDesc i_tensor_desc; + ge::GeTensorDesc o_tensor_desc; + const unordered_map &input_nodes_format_map = GetParserContext().input_nodes_format_map; + auto map_iter = input_nodes_format_map.find(data_op_name); + if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) { + // Input 5D NC1HWC0 + GE_RETURN_WITH_LOG_IF_ERROR(Init5DInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed"); + // Output + GE_RETURN_WITH_LOG_IF_ERROR(Init5DOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed"); + } else { + // No need to consider AIPP here, + // The adjustdatanodedesc function of model_builder will process the + // input_desc and output_desc of AIPP's data node. + // Without AIPP, the data of input float is kept in cctranstensor implementation. + // The cast operator can not run in the pvmodel simulation environment, + // so the input data conversion processing maintains the original state. + // To be modified after AICPU operators support pvmodel. 
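+    // Summary note: DT_FLOAT inputs keep the legacy float32 ND path below, while any
+    // other data type is initialized as a generic ND tensor of that type.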
+ if (data_type == ge::DT_FLOAT) { + // Input + GE_RETURN_WITH_LOG_IF_ERROR(InitInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed"); + // Output + GE_RETURN_WITH_LOG_IF_ERROR(InitOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed"); + } else { + // Input + GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, i_tensor_desc), + "Init ND InputTensor failed"); + // Output + GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, o_tensor_desc), + "Init ND Output Tensor failed"); + } + } + i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); + i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); + o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); + if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) { + GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str()); + return FAILED; + } + if (op->AddOutputDesc(o_tensor_desc) != ge::GRAPH_SUCCESS) { + GELOGE(domi::INTERNAL_ERROR, "AddOutputDesc failed for op %s.", op->GetName().c_str()); + return FAILED; + } + return SUCCESS; +} + +Status DataOpParser::Init5DInputTensor(const vector &shape, ge::GeTensorDesc &tensor_desc) { + tensor_desc.SetDataType(ge::DT_FLOAT16); + tensor_desc.SetFormat(static_cast(domi::DOMI_TENSOR_NC1HWC0)); + ge::TensorUtils::SetReuseInput(tensor_desc, false); + ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size()); + tensor_desc.SetShape(ge::GeShape(shape)); + + int64_t tensor_size = 0; + ge::graphStatus graph_status = ge::TensorUtils::GetTensorSizeInBytes(tensor_desc, tensor_size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "GetTensorSizeInBytes failed!"); + return domi::FAILED; + } + // Set the actual occupied space size + ge::TensorUtils::SetSize(tensor_desc, tensor_size); + return SUCCESS; +} + +Status DataOpParser::InitNDTensor(const vector &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc) { + // Fixed input ND + tensor_desc.SetFormat(static_cast(DOMI_TENSOR_ND)); + tensor_desc.SetDataType(data_type); + tensor_desc.SetOriginDataType(data_type); + ge::TensorUtils::SetReuseInput(tensor_desc, false); + ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size()); + tensor_desc.SetShape(ge::GeShape(shape)); + tensor_desc.SetOriginShape(ge::GeShape(shape)); + + int64_t size = kScalarLength; + if (!tensor_desc.GetShape().GetDims().empty()) { + size = tensor_desc.GetShape().GetShapeSize(); + } + uint32_t type_size = 0; + if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) { + FMK_INT64_UINT32_MULCHECK(size, type_size); + size *= type_size; + } else { + FMK_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); + size *= sizeof(float); + } + ge::TensorUtils::SetSize(tensor_desc, size); + return SUCCESS; +} + +Status DataOpParser::Init5DOutputTensor(const vector &shape, ge::GeTensorDesc &output) { + output.SetDataType(ge::DT_FLOAT16); + output.SetFormat(static_cast(domi::DOMI_TENSOR_NC1HWC0)); + ge::TensorUtils::SetReuseInput(output, false); + ge::TensorUtils::SetRealDimCnt(output, shape.size()); + output.SetShape(ge::GeShape(shape)); + + int64_t output_size = 0; + ge::graphStatus graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(output, output_size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "GetTensorMemorySizeInBytes failed!"); + return domi::FAILED; + } + // Set the actual occupied space size + ge::TensorUtils::SetSize(output, output_size); + 
return SUCCESS; +} + +Status DataOpParser::InitInputTensor(const vector &shape, ge::GeTensorDesc &input) { + input.SetFormat(static_cast(domiTensorFormat_t(DOMI_TENSOR_ND))); + input.SetDataType(ge::DT_FLOAT); + input.SetOriginDataType(ge::DT_FLOAT); + ge::TensorUtils::SetReuseInput(input, false); + + input.SetShape(ge::GeShape(shape)); + input.SetOriginShape(ge::GeShape(shape)); + int64_t size = 0; + // No need to check dynamic_batch_size since its first dim is -1. + if (input.GetShape().GetDim(0) != -1) { + size = input.GetShape().GetShapeSize(); + } + FMK_INT64_UINT32_MULCHECK(size, static_cast(sizeof(float))); + ge::TensorUtils::SetSize(input, size * sizeof(float)); + + return SUCCESS; +} + +Status DataOpParser::InitOutputTensor(const vector &shape, ge::GeTensorDesc &output) { + int64_t output_size = 0; + ge::GeShape output_shape = ge::GeShape(shape); + ge::Format format = ge::FORMAT_ND; + ge::DataType data_type = ge::DT_FLOAT; + output.SetFormat(format); + output.SetDataType(data_type); + ge::TensorUtils::SetReuseInput(output, false); + ge::TensorUtils::SetRealDimCnt(output, shape.size()); + output.SetShape(output_shape); + + ge::graphStatus graph_status = ge::TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "CalcTensorMemSize failed!"); + return FAILED; + } + + if (output_size == kDynamicBatchInputSize) { + GELOGI("After calc tensor memory size, output_mem_size = %ld", output_size); + return SUCCESS; + } + + int64_t size = output_size; + auto valid_max_size = INT64_MAX - kTwoTimesAlign * kDataMemAlignSize; + if (size > valid_max_size || size < 0) { + GELOGE(FAILED, "The updated mem size is out of data range [0, %ld]", valid_max_size); + return FAILED; + } else { + size = ((size + kTwoTimesAlign * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; + } + // Set the actual occupied space size + ge::TensorUtils::SetSize(output, size); + return SUCCESS; +} +} // namespace ge diff --git a/parser/common/data_op_parser.h b/parser/common/data_op_parser.h new file mode 100644 index 0000000..53bab18 --- /dev/null +++ b/parser/common/data_op_parser.h @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_COMMON_DATA_OP_PARSER_H_ +#define PARSER_COMMON_DATA_OP_PARSER_H_ + +#include +#include +#include "common/debug/log.h" +#include "common/op/attr_value_util.h" +#include "framework/omg/parser/parser_types.h" +#include "omg/omg_inner_types.h" +#include "proto/om.pb.h" + +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/tensor_utils.h" + +using google::protobuf::Message; +using std::vector; + +namespace ge { +/** + * @ingroup domi_omg + * @brief Provide a public interface for DataOp + * + */ +class DataOpParser { + public: + virtual ~DataOpParser() {} + + protected: + /** + * @ingroup domi_omg + * @brief Parse the Shape information of DataOp + * @param [in] shape 4D shape information (dimensions) + * @param [out] op Save converted shape information + * @return SUCCESS Parsing success + * @return FAILED Parsing failed + */ + static Status ParseShape(const vector &shape, ge::OpDescPtr op); + + private: + /** + * @ingroup domi_omg + * @brief Convert Input's Shape Information + * @param [in] shape 4D shape information (dimensions) + * @param [out] tensorDesc Save converted shape information + */ + static Status Init5DInputTensor(const vector &shape, ge::GeTensorDesc &tensorDesc); + + /** + * @ingroup domi_omg + * @brief Convert Shape of Output + * @param [in] shape 4D shape information (dimensions) + * @param [out] output Save converted shape information + * @return SUCCESS Convert success + * @return FAILED Convert failed + */ + static Status Init5DOutputTensor(const vector &shape, ge::GeTensorDesc &output); + + /** + * @ingroup domi_omg + * @brief Convert Shape of Input + * @param [in] shape 4D shape information (dimensions) + * @param [out] input Save converted shape information + */ + static Status InitInputTensor(const vector &shape, ge::GeTensorDesc &input); + + /** + * @ingroup domi_omg + * @brief Convert Shape of Output + * @param [in] shape 4D shape information (dimensions) + * @param [out] output Save converted shape information + * @return SUCCESS Convert success + * @return FAILED Convert failed + */ + static Status InitOutputTensor(const vector &shape, ge::GeTensorDesc &output); + + /** + * @ingroup domi_omg + * @brief Convert Shape to an ND Tensor + * @param [in] shape 4D shape information (dimensions) + * @param [in] data_type Data type of the tensor + * @param [out] desc Save converted shape information + * @return SUCCESS Convert success + * @return FAILED Convert failed + */ + static Status InitNDTensor(const vector &shape, ge::DataType data_type, ge::GeTensorDesc &desc); +}; +} // namespace ge + +#endif // PARSER_COMMON_DATA_OP_PARSER_H_ \ No newline at end of file diff --git a/parser/common/model_saver.cc b/parser/common/model_saver.cc new file mode 100644 index 0000000..fc810ad --- /dev/null +++ b/parser/common/model_saver.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "parser/common/model_saver.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "common/util/error_manager/error_manager.h" +#include "mmpa/mmpa_api.h" + +namespace { +const int kFileOpSuccess = 0; +} // namespace + +namespace ge { +namespace parser { +const uint32_t kInteval = 2; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, + const Json &model) { + Status ret = SUCCESS; + if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { + GELOGE(FAILED, "Check output file failed."); + return FAILED; + } + std::string model_str; + try { + model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore); + } catch (std::exception &e) { + ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); + GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); + return FAILED; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E19008"); + GELOGE(FAILED, "Failed to convert JSON to string."); + return FAILED; + } + + char real_path[PATH_MAX] = {0}; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!"); + if (realpath(file_path, real_path) == nullptr) { + GELOGI("File %s does not exist, it will be created.", file_path); + } + + // Open file + mode_t mode = S_IRUSR | S_IWUSR; + int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); + if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); + GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); + return FAILED; + } + const char *model_char = model_str.c_str(); + uint32_t len = static_cast(model_str.length()); + // Write data to file + mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); + if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); + // Need to print the error info of both mmWrite and mmClose, so return ret after mmClose + GELOGE(FAILED, "Write to file failed.
errno = %d, %s", mmpa_ret, strerror(errno)); + ret = FAILED; + } + // Close file + if (mmClose(fd) != EN_OK) { + GELOGE(FAILED, "Close file failed."); + ret = FAILED; + } + return ret; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) { + // Determine file path length + if (file_path.size() >= PATH_MAX) { + GELOGE(FAILED, "Path is too long:%zu", file_path.size()); + return FAILED; + } + + // Find the last separator + int path_split_pos = static_cast(file_path.size() - 1); + for (; path_split_pos >= 0; path_split_pos--) { + if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { + break; + } + } + + if (path_split_pos == 0) { + return SUCCESS; + } + + // If there is a path before the file name, create the path + if (path_split_pos != -1) { + if (CreateDirectory(std::string(file_path).substr(0, static_cast(path_split_pos))) != kFileOpSuccess) { + GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) { + GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); + auto dir_path_len = directory_path.length(); + if (dir_path_len >= PATH_MAX) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); + GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); + return -1; + } + char tmp_dir_path[PATH_MAX] = {0}; + for (size_t i = 0; i < dir_path_len; i++) { + tmp_dir_path[i] = directory_path[i]; + if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) { + if (access(tmp_dir_path, F_OK) != 0) { + int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 + if (ret != 0) { + if (errno != EEXIST) { + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", + directory_path.c_str()); + return ret; + } + } + } + } + } + int32_t ret = mmMkdir(const_cast(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 + if (ret != 0) { + if (errno != EEXIST) { + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + return ret; + } + } + return 0; +} + +} // namespace parser +} // namespace ge \ No newline at end of file diff --git a/parser/common/model_saver.h b/parser/common/model_saver.h new file mode 100644 index 0000000..bc31dba --- /dev/null +++ b/parser/common/model_saver.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_COMMON_FILE_SAVER_H_ +#define PARSER_COMMON_FILE_SAVER_H_ + +#include + +#include "ge/ge_api_error_codes.h" +#include "register/register_types.h" +#include "nlohmann/json.hpp" + +namespace ge { +namespace parser { +using Json = nlohmann::json; +using std::string; + +class ModelSaver { +public: + /** + * @ingroup domi_common + * @brief Save JSON object to file + * @param [in] file_path File output path + * @param [in] model json object + * @return Status result + */ + static Status SaveJsonToFile(const char *file_path, const Json &model); + +private: + /// + /// @ingroup domi_common + /// @brief Check validity of the file path + /// @return Status result + /// + static Status CheckPath(const string &file_path); + + static int CreateDirectory(const std::string &directory_path); +}; +} // namespace parser +} // namespace ge + +#endif //PARSER_COMMON_FILE_SAVER_H_ diff --git a/parser/common/module.mk b/parser/common/module.mk new file mode 100644 index 0000000..5a567c0 --- /dev/null +++ b/parser/common/module.mk @@ -0,0 +1,95 @@ +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +LOCAL_MODULE := libparser_common + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 +LOCAL_CFLAGS += -Werror +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +COMMON_LOCAL_SRC_FILES := \ + parser_factory.cc \ + data_op_parser.cc \ + op_parser_factory.cc \ + pre_checker.cc \ + register_tbe.cc \ + parser_api.cc \ + parser_inner_ctx.cc \ + proto_file_parser.cc \ + acl_graph_parser_util.cc \ + tbe_plugin_loader.cc \ + model_saver.cc \ + ../tensorflow/tensorflow_custom_parser_adapter.cc \ + ../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \ + ../tensorflow/tensorflow_fusion_op_parser.cc \ + ../tensorflow/tensorflow_util.cc \ + convert/pb2json.cc \ + op_def/ir_pb_converter.cc \ + op_def/defs.cc \ + op_def/op_schema.cc \ + op_def/operator.cc \ + op_map.cc \ + parser_types.cc \ + pass_manager.cc \ + parser_fp16_t.cc \ + thread_pool.cc \ + +FMK_COMMON_SRC_FILES := \ +# ../../common/fmk_error_codes.cc \ + ../../common/auth/cipher.cc \ + ../../common/context/ctx.cc \ + ../../graph/passes/pass_manager.cc \ + ../../graph/common/omg_util.cc \ + ../../common/types.cc \ + ../../common/auth/file_saver.cc \ + ../../common/util.cc \ + ../../common/model_saver.cc \ + ../../common/fp16_t.cc \ + ../../common/thread_pool.cc \ + +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) +LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) + +LOCAL_C_INCLUDES := \ + proto/om.proto \ + proto/insert_op.proto \ + proto/ge_ir.proto \ + proto/tensorflow/graph.proto \ + proto/tensorflow/node_def.proto \ + proto/tensorflow/tensor_shape.proto \ + proto/tensorflow/attr_value.proto \ + proto/tensorflow/function.proto \ + proto/tensorflow/op_def.proto \ + proto/tensorflow/resource_handle.proto \ + proto/tensorflow/tensor.proto \ + proto/tensorflow/types.proto \ + proto/tensorflow/versions.proto \ + $(LOCAL_PATH) \ + $(TOPDIR)inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/framework \ + $(TOPDIR)inc/common/util \ + $(TOPDIR)framework/domi \ + $(TOPDIR)framework/domi/common \ + $(TOPDIR)framework/domi/parser \ + $(TOPDIR)third_party/json/include \ + $(TOPDIR)third_party/protobuf/include \ + libc_sec/include \ + third_party/openssl/include/x86/include \ + +LOCAL_SHARED_LIBRARIES := \ + libprotobuf \ + libslog \ + libgraph \ + libmmpa \ + libc_sec \ + liberror_manager \ + libregister \ + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/parser/common/op_def/arg_op.cc 
b/parser/common/op_def/arg_op.cc new file mode 100644 index 0000000..7dda8b1 --- /dev/null +++ b/parser/common/op_def/arg_op.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/common/op_def/arg_op.h" +#include +#include "framework/common/fmk_types.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator::ArgOpOperator() : ParserOperator("Data") {} + +ArgOpOperator::~ArgOpOperator() {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Name(const std::string &name) { + (void)ParserOperator::Name(name); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Index(int64_t index) { + Attr("index", static_cast(index)); + + return *this; +} + +int64_t ArgOpOperator::GetIndex() const { return GetIntAttr("index"); } +} // namespace ge diff --git a/parser/common/op_def/arg_op.h b/parser/common/op_def/arg_op.h new file mode 100644 index 0000000..b867a91 --- /dev/null +++ b/parser/common/op_def/arg_op.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DOMI_OP_ARG_OP_H_ +#define DOMI_OP_ARG_OP_H_ +#include "parser/common/op_def/operator.h" + +namespace ge { +class ArgOpOperator : public ParserOperator { + public: + ArgOpOperator(); + + ~ArgOpOperator(); + + ArgOpOperator &Name(const std::string &name); + + ArgOpOperator &Index(int64_t index); + + int64_t GetIndex() const; +}; +} // namespace ge + +#endif // DOMI_OP_ARG_OP_H_ \ No newline at end of file diff --git a/parser/common/op_def/constant_op.cc b/parser/common/op_def/constant_op.cc new file mode 100644 index 0000000..a9993f7 --- /dev/null +++ b/parser/common/op_def/constant_op.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/op_def/constant_op.h" +#include +#include + +#include "graph/debug/ge_attr_define.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::ConstantOperator() : ParserOperator("Constant") {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::~ConstantOperator() {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::VectorAttr( + std::string key, std::vector &value) { + Attr(key, value); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) { + Attr(VAR_ATTR_DTYPE, (int64_t)t); + return *this; +} + +ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); } +} // namespace ge diff --git a/parser/common/op_def/constant_op.h b/parser/common/op_def/constant_op.h new file mode 100644 index 0000000..29549e5 --- /dev/null +++ b/parser/common/op_def/constant_op.h @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_CONSTANT_OP_H_ +#define DOMI_OP_CONSTANT_OP_H_ +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class ConstantOperator : public ParserOperator { + public: + ConstantOperator(); + ~ConstantOperator(); + + ConstantOperator &Name(const std::string &name); + ConstantOperator &VectorAttr(std::string key, std::vector &value); + + ConstantOperator &DType(ge::DataType t); + ge::DataType GetDType() const; +}; +} // namespace ge + +#endif // DOMI_OP_CONSTANT_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/defs.cc b/parser/common/op_def/defs.cc new file mode 100644 index 0000000..350710e --- /dev/null +++ b/parser/common/op_def/defs.cc @@ -0,0 +1,712 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/op_def/op_schema.h" + +namespace ge { +DOMI_OP_SCHEMA(Data).Output("y"); + +DOMI_OP_SCHEMA(Const).Output("y"); + +DOMI_OP_SCHEMA(ConvolutionDepthwise) + .Input("x") + .Input("w") + .Input("b", OpSchema::Optional) + .Output("y") + .Attr("group", AttributeType::INT, static_cast(1)) + .Attr("num_output", AttributeType::INT, static_cast(1)) + .Attr("pad_mode", AttributeType::INT, static_cast(0)) + .Attr("mode", AttributeType::INT, static_cast(1)) + .Attr("pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0}) + .Attr("stride", AttributeType::INTLIST, IntTuple{1, 1}) + .Attr("dilation", AttributeType::INTLIST, IntTuple{1, 1}) + .Attr("kernel", AttributeType::INTLIST, IntTuple{0, 0}) + .Attr("before_pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0}); + +DOMI_OP_SCHEMA(Region) + .Input("x") + .Output("y") + .Attr("casses", AttributeType::INT, static_cast(20)) + .Attr("coords", AttributeType::INT, static_cast(4)) + .Attr("boxes", AttributeType::INT, static_cast(1)) + .Attr("background", AttributeType::BOOL, static_cast(false)) + .Attr("softmax", AttributeType::BOOL, static_cast(false)) + .Attr("softmax_tree", AttributeType::BOOL, static_cast(false)) + .Attr("yolo_version", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(Gather) + .Input("params") + .Input("indices") + .Input("axis", OpSchema::Optional) + .Output("y") + .Attr("params_type", AttributeType::INT, static_cast(1)) + .Attr("indices_type", AttributeType::INT, static_cast(3)) + .Attr("validate_indices", AttributeType::BOOL, static_cast(true)); + +DOMI_OP_SCHEMA(ArgMax) + .Input("input") + .Output("output") + .Attr("axis", AttributeType::INT, static_cast(0)) + .Attr("keep_dims", AttributeType::BOOL, static_cast(true)) + .Attr("axis_type", AttributeType::INT, static_cast(3)) + .Attr("outmaxval", AttributeType::BOOL, static_cast(false)) + .Attr("topk", AttributeType::UINT, static_cast(1)); + +DOMI_OP_SCHEMA(Split) + .Input("x") + .Input("axis", OpSchema::Optional) + .Output("y") + .Attr("T", AttributeType::INT, static_cast(1)) + .Attr("num_split", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(SplitV) + .Input("x") + .Input("axis", OpSchema::Optional) + .Output("y") + .Attr("T", AttributeType::INT, static_cast(1)) + .Attr("Tlen", AttributeType::INT, static_cast(1)) + .Attr("num_split", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(Fill).Input("x").Input("value").Output("y").Attr("T", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(Rsqrt).Input("x").Output("y"); +DOMI_OP_SCHEMA(BiasAdd) + .Input("x") + .Input("bias") + .Output("y") + .Attr("format", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(Reverse) + .Input("x") + .Input("axis") + .Output("y") + .Attr("T", AttributeType::INT, static_cast(1)) + .Attr("Tidx", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(Unpack) + .Input("x") + .Output("y") + .Attr("T", AttributeType::INT, static_cast(1)) + .Attr("axis", AttributeType::INT, static_cast(0)) + .Attr("num", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(Yolo2Reorg) + .Input("x") + .Output("y") + .Attr("reverse", AttributeType::BOOL, static_cast(1)) + .Attr("stride", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(ReduceSum) + .Input("x") + .Output("y") + .Attr("Tidx", AttributeType::INT, static_cast(1)) + .Attr("keep_dims", AttributeType::BOOL, static_cast(1)); + +DOMI_OP_SCHEMA(Concat) + .Input("x") + .Output("y") + .Attr("Tidx", AttributeType::INT, static_cast(1)) + .Attr("N", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(ResizeBilinear) + .Input("x") + 
.Input("sizes") + .Output("y") + .Attr("output_dim_mode", AttributeType::INT, static_cast(1)) + .Attr("align_corners", AttributeType::BOOL, static_cast(1)) + .Attr("zoom_factor", AttributeType::INT, static_cast(1)) + .Attr("shrink_factor", AttributeType::INT, static_cast(1)) + .Attr("height", AttributeType::INT, static_cast(1)) + .Attr("width", AttributeType::INT, static_cast(1)) + .Attr("pad_begin", AttributeType::INT, static_cast(1)) + .Attr("pad_end", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(LRN) + .Input("x") + .Output("y") + .Attr("lrn_normregion", AttributeType::UINT, static_cast(0)) + .Attr("lrn_k", AttributeType::FLOAT, static_cast(1)) + .Attr("lrn_localsize", AttributeType::UINT, static_cast(5)) + .Attr("lrn_alpha", AttributeType::FLOAT, static_cast(1)) + .Attr("lrn_beta", AttributeType::FLOAT, static_cast(0.75)); + +DOMI_OP_SCHEMA(Maximum).Input("x").Input("w").Output("y"); + +DOMI_OP_SCHEMA(Slice) + .Input("x") + .Output("y") + .Attr("axis", AttributeType::INT, static_cast(2)) + .AttrRequired("offsets", AttributeType::INTLIST); + +DOMI_OP_SCHEMA(Pad) + .Input("x") + .Input("paddings") + .Input("constant_values", OpSchema::Optional) + .Output("y") + .Attr("T", AttributeType::INT, static_cast(1)) + .Attr("t_paddings", AttributeType::INT, static_cast(1)) + .Attr("mode", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(PadV2) + .Input("input") + .Output("output") + .Attr("constant_values", AttributeType::INT, static_cast(0)) + .AttrRequired("paddings", AttributeType::INTLIST); + +DOMI_OP_SCHEMA(MirrorPad) + .Input("input") + .Output("output") + .AttrRequired("paddings", AttributeType::INTLIST) + .Attr("mode", AttributeType::INT, static_cast(2)); + +DOMI_OP_SCHEMA(Upsample) + .Input("input") + .Input("scales") + .Output("output") + .Attr("mode", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(Cast) + .Input("x") + .Output("y") + .Attr("DstT", AttributeType::INT, static_cast(1)) + .Attr("SrcT", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(LogicalNot).Input("x").Output("y"); +DOMI_OP_SCHEMA(LogicalAnd).Input("x1").Input("x2").Output("y"); +DOMI_OP_SCHEMA(LogicalOr).Input("x1").Input("x2").Output("y"); +DOMI_OP_SCHEMA(Equal).Input("x1").Input("x2").Output("y").Attr("T", AttributeType::INT, static_cast(1)); + +DOMI_OP_SCHEMA(MatMul) + .Input("a") + .Input("b") + .Output("product") + .Attr("transposeX", AttributeType::BOOL, static_cast(false)) + .Attr("transposeW", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(RNN) + .Input("x") + .Input("cont") + .Input("xstatic", OpSchema::Optional) + .Input("w") // filter + .Input("b") // bias + .Input("seqlen") // T + .Input("hx") // Hx + .Input("cx") // cx + .Output("y") + .Output("cyfw") + .Output("hyfw") + .Output("cybw") + .Output("hybw") + .Attr("hidden_size", AttributeType::INT, static_cast(0)) + .Attr("num_layers", AttributeType::INT, static_cast(1)) + .Attr("support_cont", AttributeType::BOOL, static_cast(false)) + .Attr("support_xstatic", AttributeType::BOOL, static_cast(false)) + .Attr("input_mode", AttributeType::INT, static_cast(0)) + .Attr("direction_mode", AttributeType::INT, static_cast(0)) + .Attr("mode", AttributeType::INT, static_cast(0)) + .Attr("input_data_layout", AttributeType::INT, static_cast(0)) + .Attr("output_data_layout", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(FrameworkOp).Attr("framework_type", AttributeType::INT, static_cast(3)); +DOMI_OP_SCHEMA(Multinomial) + .Input("logits") + .Output("output") + .Attr("num_samples", AttributeType::INT, static_cast(0)) + 
.AttrRequired("seed", AttributeType::INT) + .AttrRequired("seed2", AttributeType::INT); +DOMI_OP_SCHEMA(ReverseSequence) + .Input("input") + .Input("seq_lengths") + .Output("output") + .AttrRequired("seq_dim", AttributeType::INT) + .AttrRequired("batch_dim", AttributeType::INT); + +DOMI_OP_SCHEMA(Interp) + .Input("x") + .Output("y") + .Attr("output_dim_mode", AttributeType::INT, static_cast(2)) + .Attr("zoom_factor", AttributeType::INT, static_cast(1)) + .Attr("shrink_factor", AttributeType::INT, static_cast(1)) + .Attr("height", AttributeType::INT, static_cast(0)) + .Attr("width", AttributeType::INT, static_cast(0)) + .Attr("pad_begin", AttributeType::INT, static_cast(0)) + .Attr("pad_end", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(ShuffleChannel).Input("x").Output("y").Attr("group", AttributeType::UINT, static_cast(1)); + +DOMI_OP_SCHEMA(Conv2DBackpropFilter) + .Input("x") + .Input("w") + .Input("b", OpSchema::Optional) + .Output("y") + .Attr("padding", AttributeType::INT, static_cast(1)) + .Attr("pads", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}) + .Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1}) + .Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1}); + +DOMI_OP_SCHEMA(Conv2DBackpropInput) + .Input("input_sizes") + .Input("filter") + .Input("out_backprop") + .Output("output") + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")) + .Attr("group", AttributeType::UINT, static_cast(1)) + .Attr("padding", AttributeType::INT, static_cast(0)) + .Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1}) + .Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1}) + .Attr("pad", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}); +DOMI_OP_SCHEMA(BiasAddGrad).Input("dy").Output("db").Attr("format", AttributeType::INT, static_cast(1)); +DOMI_OP_SCHEMA(ReluGrad).Input("dy").Input("x").Output("dx"); + +DOMI_OP_SCHEMA(MeanGrad).Input("dy").Output("dx"); + +DOMI_OP_SCHEMA(NonMaxSuppression) + .Input("boxes") + .Input("scores") + .Output("selected_indices") + .Attr("max_output_size", AttributeType::INT, static_cast(-1)) + .Attr("iou_threshold", AttributeType::FLOAT, static_cast(0.5)) + .Attr("score_threshold", AttributeType::FLOAT, static_cast(-1)); + +DOMI_OP_SCHEMA(CropAndResize) + .Input("image") + .Input("boxes") + .Input("box_ind") + .Output("crops") + .Attr("method", AttributeType::INT, static_cast(0)) + .Attr("extrapolation_value", AttributeType::FLOAT, static_cast(0)) + .Attr("crop_size_h", AttributeType::INT, static_cast(0)) + .Attr("crop_size_w", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(TopKV2) + .Input("input") + .Input("k") + .Output("value") + .Output("indices") + .AttrRequired("sorted", AttributeType::BOOL); + +DOMI_OP_SCHEMA(InvertPermutation).Input("x").Output("y"); + +DOMI_OP_SCHEMA(GatherV2) + .Input("params") + .Input("indices") + .Input("axis", OpSchema::Optional) + .Output("y") + .Attr("Tparams", AttributeType::INT, static_cast(0)) // default: DT_FLOAT + .Attr("Tindices", AttributeType::INT, static_cast(3)) // default: DT_INT32 + .Attr("Taxis", AttributeType::INT, static_cast(3)); // default: DT_INT32 + +DOMI_OP_SCHEMA(HighWay) + .Input("x") + .Input("tw") // filter + .Input("tb") // bias + .Input("uw") // filter + .Input("ub") // bias + .Output("y"); + +DOMI_OP_SCHEMA(Reciprocal).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Asinh).Input("input").Output("output"); + +DOMI_OP_SCHEMA(Acosh).Input("input").Output("output"); + +DOMI_OP_SCHEMA(Minimum).Input("x").Input("y").Output("output"); + 
+DOMI_OP_SCHEMA(Clip).Input("input").Input("min").Input("max").Output("output"); + +DOMI_OP_SCHEMA(FusedBatchNorm) + .Input("x") + .Input("scale") + .Input("offset") + .Input("mean") + .Input("variance") + .Output("y") + .Output("batch_mean") + .Output("batch_variance") + .Output("reserve_space_1") + .Output("reserve_space_2") + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")) + .Attr("epsilon", AttributeType::FLOAT, static_cast(0.0001)) + .Attr("is_training", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(FusedBatchNormGrad) + .Input("dy") + .Input("x") + .Input("bnscale") + .Input("save_mean") + .Input("save_variance") + .Output("dx") + .Output("result_bn_scale_diff") + .Output("result_bn_bias_diff") + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")) + .Attr("epsilon", AttributeType::FLOAT, static_cast(0.0)) + .Attr("is_training", AttributeType::BOOL, static_cast(true)); + +DOMI_OP_SCHEMA(MaxPoolWithArgmax) + .Input("x") + .Output("y") + .Output("argmax") + .AttrRequired("window", AttributeType::INTLIST) + .AttrRequired("stride", AttributeType::INTLIST) + .AttrRequired("pad_mode", AttributeType::INT) + .AttrRequired("ceil_mode", AttributeType::BOOL) + .AttrRequired("data_mode", AttributeType::INT); + +DOMI_OP_SCHEMA(MaxPoolGradWithArgmax) + .Input("input") + .Input("grad") + .Output("output") + .AttrRequired("window", AttributeType::INTLIST) + .AttrRequired("stride", AttributeType::INTLIST) + .AttrRequired("pad_mode", AttributeType::INT) + .AttrRequired("ceil_mode", AttributeType::BOOL) + .AttrRequired("data_mode", AttributeType::INT); + +DOMI_OP_SCHEMA(HcomBroadcast) + .AttrRequired("root_rank", AttributeType::INT) + .AttrRequired("group", AttributeType::STRING); + +DOMI_OP_SCHEMA(HcomAllReduce) + .Input("x") + .Output("y") + .AttrRequired("reduction", AttributeType::STRING) + .AttrRequired("group", AttributeType::STRING); + +DOMI_OP_SCHEMA(HcomAllGather) + .Input("x") + .Output("y") + .AttrRequired("rank_size", AttributeType::INT) + .AttrRequired("group", AttributeType::STRING); + +DOMI_OP_SCHEMA(SparseSoftmaxCrossEntropyWithLogits) + .Input("features") + .Input("labels") + .Output("loss") + .Output("backprop") + .AttrRequired("T", AttributeType::INT) + .Attr("Tlabels", AttributeType::INT, static_cast(9)); + +DOMI_OP_SCHEMA(Snapshot).Input("input").Output("output").AttrRequired("T", AttributeType::INT); + +DOMI_OP_SCHEMA(ReduceProd) + .Input("bottom") + .Output("top") + .AttrRequired("axes", AttributeType::INTLIST) + .Attr("keep_dims", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(ReduceAll) + .Input("x") + .Output("y") + .AttrRequired("axes", AttributeType::INTLIST) + .Attr("keep_dims", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(ReduceMax) + .Input("x") + .Output("y") + .AttrRequired("axis", AttributeType::INTLIST) + .Attr("keep_dims", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(AddN).Input("x").Output("y"); + +DOMI_OP_SCHEMA(ShapeN) + .Input("x") + .Output("y") + .AttrRequired("N", AttributeType::INT) + .AttrRequired("in_type", AttributeType::INT) + .AttrRequired("dtype", AttributeType::INT); + +DOMI_OP_SCHEMA(ReduceMin) + .Input("x") + .Output("y") + .AttrRequired("axis", AttributeType::INTLIST) + .Attr("keep_dims", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(Sqrt).Input("x").Output("y"); + +DOMI_OP_SCHEMA(L2Loss).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Multiply).Input("x").Input("y").Output("z"); + +DOMI_OP_SCHEMA(Add).Input("x").Output("y"); + 
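For orientation, every schema in defs.cc follows the same fluent builder pattern: name the inputs and outputs, mark optional inputs with OpSchema::Optional, give defaulted attributes a type and value via Attr, and make mandatory attributes explicit via AttrRequired. A minimal sketch of a hypothetical new schema in that style is shown below; MyCustomOp and its attribute names are illustrative only, and the sketch assumes the DOMI_OP_SCHEMA macro, the AttributeType enum and the IntTuple alias behave as in the definitions above.

// Hypothetical example only -- not part of the committed defs.cc.
DOMI_OP_SCHEMA(MyCustomOp)
    .Input("x")                                                   // required input
    .Input("bias", OpSchema::Optional)                            // optional input
    .Output("y")
    .Attr("axis", AttributeType::INT, static_cast<int64_t>(0))    // attribute with a default value
    .Attr("pads", AttributeType::INTLIST, IntTuple{0, 0, 0, 0})   // list attribute with a default value
    .AttrRequired("data_format", AttributeType::STRING);          // attribute the caller must supply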
+DOMI_OP_SCHEMA(Constant).Output("y"); + +DOMI_OP_SCHEMA(ApplyMomentum) + .Input("variable") + .Input("accumulation") + .Input("learningRate") + .Input("gradient") + .Input("momuntum") + .Input("fp16variable") + .Attr("algo", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(AvgPoolGrad) + .Input("shape") + .Input("grad") + .Output("output") + .Attr("padding", AttributeType::INT, static_cast(0)) + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")) + .Attr("strides", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}) + .Attr("ksize", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0}); + +DOMI_OP_SCHEMA(Lars) + .Input("w") + .Input("g") + .Input("weight_decay") + .Output("y") + .Attr("hyperpara", AttributeType::FLOAT, static_cast(0.001)) + .Attr("epsilon", AttributeType::FLOAT, static_cast(0.00001)); + +DOMI_OP_SCHEMA(AssignSub) + .Input("variable") + .Input("input") + .Input("output") + .Attr("mode", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(AssignAdd) + .Input("variable") + .Input("input") + .Output("output") + .Attr("mode", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(SpaceToBatchND).Input("input").Input("block_shape").Input("paddings").Output("output"); + +DOMI_OP_SCHEMA(Variable) + .Output("variable") + .Attr("container", AttributeType::STRING, static_cast("")) + .Attr("shared_name", AttributeType::STRING, static_cast("")) + .AttrRequired("dtype", AttributeType::INT); + +DOMI_OP_SCHEMA(Assign).Input("variable").Input("value").Output("variable"); + +DOMI_OP_SCHEMA(VarIsInitializedOp).Input("variable").Output("value"); + +DOMI_OP_SCHEMA(NoOp).Attr("algo", AttributeType::INT, static_cast(0)); + +DOMI_OP_SCHEMA(LogTimeStamp) + .Attr("logid", AttributeType::STRING, static_cast("")) + .Attr("notify", AttributeType::BOOL, static_cast(false)); + +DOMI_OP_SCHEMA(ResizeNearestNeighbor) + .Input("images") + .Output("resized_images") + .Attr("align_corners", AttributeType::BOOL, static_cast(false)) + .AttrRequired("height", AttributeType::INT) + .AttrRequired("width", AttributeType::INT); + +DOMI_OP_SCHEMA(BatchToSpaceND).Input("input").Input("block_shape").Input("crops").Output("output"); + +DOMI_OP_SCHEMA(Assert).Input("x").Input("w").Output("y"); + +DOMI_OP_SCHEMA(Pow).Input("x").Input("y").Output("z"); + +DOMI_OP_SCHEMA(GreaterEqual).Input("x1").Input("x2").Output("y"); + +DOMI_OP_SCHEMA(SpaceToDepth) + .Input("input") + .Output("output") + .Attr("block_size", AttributeType::INT, static_cast(0)) + .AttrRequired("T", AttributeType::INT) + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")); + +DOMI_OP_SCHEMA(DepthToSpace) + .Input("input") + .Output("output") + .Attr("block_size", AttributeType::INT, static_cast(0)) + .AttrRequired("T", AttributeType::INT) + .Attr("data_format", AttributeType::STRING, static_cast("NHWC")); + +DOMI_OP_SCHEMA(Rint).Input("input").Output("output").AttrRequired("T", AttributeType::INT); + +DOMI_OP_SCHEMA(ExtractImagePatches) + .Input("images") + .Output("y") + .AttrRequired("ksizes", AttributeType::INTLIST) + .AttrRequired("strides", AttributeType::INTLIST) + .AttrRequired("rates", AttributeType::INTLIST) + .AttrRequired("padding", AttributeType::STRING); + +DOMI_OP_SCHEMA(Atan).Input("x").Output("output"); + +DOMI_OP_SCHEMA(Atanh).Input("x").Output("output"); + +DOMI_OP_SCHEMA(Acos).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Asin).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Log) + .Input("x") + .Output("output") + .AttrRequired("scale", AttributeType::INT) + .AttrRequired("shift", AttributeType::INT) + 
.AttrRequired("base", AttributeType::INT); + +DOMI_OP_SCHEMA(Neg).Input("input").Output("output"); + +DOMI_OP_SCHEMA(Tan).Input("x").Output("output"); + +DOMI_OP_SCHEMA(Round).Input("x").Output("output"); + +DOMI_OP_SCHEMA(Exp) + .Input("x") + .Output("y") + .Attr("scale", AttributeType::FLOAT, static_cast(1)) + .Attr("shift", AttributeType::FLOAT, static_cast(0)) + .Attr("base", AttributeType::FLOAT, static_cast(-1)); + +DOMI_OP_SCHEMA(Less).Input("x").Input("y").Output("output"); + +DOMI_OP_SCHEMA(LessEqual).Input("x").Input("y").Output("output"); + +DOMI_OP_SCHEMA(OneHot).Input("indices").Input("depth").Input("on_value").Input("off_value").Output("output"); + +DOMI_OP_SCHEMA(ZerosLike).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Where).Input("x").Output("y"); + +DOMI_OP_SCHEMA(RefSwitch).Input("x").Output("y"); + +DOMI_OP_SCHEMA(FakeQuantWithMinMaxVars) + .Input("x") + .Input("min") + .Input("max") + .Output("y") + .Attr("narrow_range", AttributeType::BOOL, static_cast(false)) + .Attr("num_bits", AttributeType::INT, static_cast(8)); + +DOMI_OP_SCHEMA(Sinh).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Cosh).Input("x").Output("y"); + +DOMI_OP_SCHEMA(Floor).Input("x").Output("output"); + +DOMI_OP_SCHEMA(RandomUniform).Input("input").Output("output"); + +DOMI_OP_SCHEMA(BatchMatMul).Input("x").Input("y").Output("output"); + +DOMI_OP_SCHEMA(FloorMod).Input("x").Input("y").Output("output"); + +DOMI_OP_SCHEMA(SquaredDifference).Input("x").Input("y").Output("output"); + +DOMI_OP_SCHEMA(LayerNorm).Input("x").Output("output").AttrRequired("Epsilon", AttributeType::FLOAT); + +DOMI_OP_SCHEMA(SSDPostProcessor) + .Input("trueImgShape") + .Input("boxEncoding") + .Input("anchors") + .Input("clsPred") + .Output("detectBoxes") + .Output("detectScores") + .Output("detectNum") + .Output("detectClasses") + .AttrRequired("numClasses", AttributeType::INT) + .AttrRequired("scoreThreshold", AttributeType::FLOAT) + .AttrRequired("iouThreshold", AttributeType::FLOAT) + .AttrRequired("maxDetectionsPerClass", AttributeType::INT) + .AttrRequired("maxTotalDetections", AttributeType::INT) + .AttrRequired("boxTypeNum", AttributeType::UINT) + .AttrRequired("scaleFactors_0", AttributeType::UINT) + .AttrRequired("scaleFactors_1", AttributeType::UINT) + .AttrRequired("scaleFactors_2", AttributeType::UINT) + .AttrRequired("scaleFactors_3", AttributeType::UINT) + .AttrRequired("imgH", AttributeType::INT) + .AttrRequired("imgW", AttributeType::INT) + .AttrRequired("useStaticShape", AttributeType::BOOL) + .AttrRequired("convertScoresMode", AttributeType::INT); + +DOMI_OP_SCHEMA(RetinaPostProcessor) + .Input("anchors") + .Input("regression") + .Input("classification") + .Output("detectBoxes") + .Output("detectScores") + .Output("detectLabels") + .Output("detectNum") + .AttrRequired("numClasses", AttributeType::INT) + .AttrRequired("maxDetections", AttributeType::INT) + .AttrRequired("nmsThreshold", AttributeType::FLOAT) + .AttrRequired("scoreThreshold", AttributeType::FLOAT) + .AttrRequired("imgH", AttributeType::INT) + .AttrRequired("imgW", AttributeType::INT) + .AttrRequired("boxTypeNum", AttributeType::UINT) + .AttrRequired("means", AttributeType::FLOATLIST) + .AttrRequired("stds", AttributeType::FLOATLIST); + +DOMI_OP_SCHEMA(ROIInterPooling) + .Input("input") + .Input("input_1") + .Output("maxPool") + .AttrRequired("hStride", AttributeType::INT) + .AttrRequired("wStride", AttributeType::INT) + .AttrRequired("hKernel", AttributeType::INT) + .AttrRequired("wKernel", AttributeType::INT) + .AttrRequired("hResize", AttributeType::INT) 
+ .AttrRequired("wResize", AttributeType::INT) + .AttrRequired("hFeatureMap", AttributeType::INT) + .AttrRequired("wFeatureMap", AttributeType::INT); + +DOMI_OP_SCHEMA(FirstStageProcessor) + .Input("anchors") + .Input("boxEncoding") + .Input("clsPred") + .Input("trueImgShape") + .Output("detectBoxes") + .Output("detectScores") + .Output("detectLables") + .Output("detectNum") + .AttrRequired("scaleFactorsNum", AttributeType::INT) + .AttrRequired("iouThreshold", AttributeType::FLOAT) + .AttrRequired("scoreThreshold", AttributeType::FLOAT) + .AttrRequired("maxSizePerClass", AttributeType::INT) + .AttrRequired("maxTotalSize", AttributeType::INT) + .AttrRequired("imgH", AttributeType::INT) + .AttrRequired("imgW", AttributeType::INT) + .AttrRequired("boxTypeNum", AttributeType::UINT) + .AttrRequired("scaleFactors_0", AttributeType::UINT) + .AttrRequired("scaleFactors_1", AttributeType::UINT) + .AttrRequired("scaleFactors_2", AttributeType::UINT) + .AttrRequired("scaleFactors_3", AttributeType::UINT); + +DOMI_OP_SCHEMA(SecondStageProcessor) + .Input("anchors") + .Input("boxEncoding") + .Input("clsPred") + .Input("validBoxNum") + .Input("trueImgShape") + .Output("detectBoxes") + .Output("detectScores") + .Output("detectLables") + .Output("detectNum") + .AttrRequired("scaleFactorsNum", AttributeType::INT) + .AttrRequired("iouThreshold", AttributeType::FLOAT) + .AttrRequired("scoreThreshold", AttributeType::FLOAT) + .AttrRequired("maxSizePerClass", AttributeType::INT) + .AttrRequired("maxTotalSize", AttributeType::INT) + .AttrRequired("numClasses", AttributeType::INT) + .AttrRequired("scaleFactors_0", AttributeType::UINT) + .AttrRequired("scaleFactors_1", AttributeType::UINT) + .AttrRequired("scaleFactors_2", AttributeType::UINT) + .AttrRequired("scaleFactors_3", AttributeType::UINT); + +DOMI_OP_SCHEMA(StreamSwitch) + .Input("loopIndex") + .Input("itersPerLoop") + .AttrRequired("switch_condition", AttributeType::UINT) + .AttrRequired("true_branch_stream", AttributeType::INT); + +DOMI_OP_SCHEMA(StreamActive).AttrRequired("active_stream_list", AttributeType::INTLIST); + +DOMI_OP_SCHEMA(MemcpyAsync).Input("in").Output("out"); + +DOMI_OP_SCHEMA(CleanAddr) + .AttrRequired("automic_add_addr_start", AttributeType::INT) + .AttrRequired("automic_add_mem_size", AttributeType::INT); +} // namespace ge diff --git a/parser/common/op_def/fill_op.cc b/parser/common/op_def/fill_op.cc new file mode 100644 index 0000000..2228d26 --- /dev/null +++ b/parser/common/op_def/fill_op.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/op_def/fill_op.h" +#include "framework/common/fmk_types.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator::FillOperator() : ParserOperator("Fill") {} + +FMK_FUNC_DEV_VISIBILITY FillOperator::~FillOperator() {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::DataType(int64_t dataType) { + Attr("T", static_cast(dataType)); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Alpha(float alpha) { + Attr("alpha", static_cast(alpha)); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Beta(float beta) { + Attr("beta", static_cast(beta)); + return *this; +} + +int64_t FillOperator::GetDataType() const { return GetIntAttr("T"); } + +float FillOperator::GetAlpha() const { return GetFloatAttr("alpha"); } + +float FillOperator::GetBeta() const { return GetFloatAttr("beta"); } +} // namespace ge diff --git a/parser/common/op_def/fill_op.h b/parser/common/op_def/fill_op.h new file mode 100644 index 0000000..8b25ee8 --- /dev/null +++ b/parser/common/op_def/fill_op.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DOMI_OP_FILL_OP_H_ +#define DOMI_OP_FILL_OP_H_ +#include "parser/common/op_def/operator.h" + +namespace ge { +class FillOperator : public ParserOperator { + public: + FillOperator(); + + ~FillOperator(); + + FillOperator &DataType(int64_t dataType); + + FillOperator &Alpha(float alpha); + + FillOperator &Beta(float beta); + + int64_t GetDataType() const; + + float GetAlpha() const; + + float GetBeta() const; +}; +} // namespace ge + +#endif // DOMI_OP_FILL_OP_H_ diff --git a/parser/common/op_def/frameworkop_op.cc b/parser/common/op_def/frameworkop_op.cc new file mode 100644 index 0000000..a762599 --- /dev/null +++ b/parser/common/op_def/frameworkop_op.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/op_def/frameworkop_op.h" +#include +#include "framework/common/fmk_types.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator::FrameworkOpOperator() + : ParserOperator("FrameworkOp") {} + +FrameworkOpOperator::~FrameworkOpOperator() {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Name( + const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Index(int64_t index) { + Attr(RETVAL_ATTR_NAME_INDEX, static_cast(index)); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::NodeDefPkg( + const std::string &nodedef_pkg) { + Attr_bt(ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_pkg); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Frameworktype( + int64_t framework_type) { + Attr(ATTR_NAME_FRAMEWORK_FWK_TYPE, static_cast(framework_type)); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::TfOpDef( + const std::string &opdef_string) { + Attr(ATTR_NAME_FRAMEWORK_OP_DEF, opdef_string); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::OriginalType( + const std::string &type) { + Attr(ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::FuncDefPkg(const std::string &func_string) { + Attr_bt(ATTR_NAME_FRAMEWORK_FUNC_DEF, func_string); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY int64_t FrameworkOpOperator::GetFrameworkType() const { + return GetIntAttr(ATTR_NAME_FRAMEWORK_FWK_TYPE); +} + +FMK_FUNC_HOST_VISIBILITY std::string FrameworkOpOperator::GetNodeDefPkg() const { + return GetStringAttr(ATTR_NAME_FRAMEWORK_NODE_DEF); +} +} // namespace ge diff --git a/parser/common/op_def/frameworkop_op.h b/parser/common/op_def/frameworkop_op.h new file mode 100644 index 0000000..c01f0f7 --- /dev/null +++ b/parser/common/op_def/frameworkop_op.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ +#define DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ +#include "graph/debug/ge_attr_define.h" +#include "parser/common/op_def/operator.h" + +namespace ge { +class FrameworkOpOperator : public ParserOperator { + public: + FrameworkOpOperator(); + + ~FrameworkOpOperator(); + + FrameworkOpOperator &Name(const std::string &name); + + FrameworkOpOperator &OriginalType(const std::string &type); + + FrameworkOpOperator &NodeDefPkg(const std::string &nodedef_pkg); + + FrameworkOpOperator &Frameworktype(int64_t framework_type); + + FrameworkOpOperator &TfOpDef(const std::string &opdef_string); + + FrameworkOpOperator &Index(int64_t index); + + FrameworkOpOperator &FuncDefPkg(const std::string &func_string); + + int64_t GetFrameworkType() const; + + std::string GetNodeDefPkg() const; +}; +} // namespace ge + +#endif // DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_ diff --git a/parser/common/op_def/ir_pb_converter.cc b/parser/common/op_def/ir_pb_converter.cc new file mode 100644 index 0000000..8b5a8d4 --- /dev/null +++ b/parser/common/op_def/ir_pb_converter.cc @@ -0,0 +1,205 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/common/op_def/ir_pb_converter.h" +#include +#include +#include +#include +#include "google/protobuf/map.h" +#include "graph/ge_tensor.h" +#include "graph/buffer.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_types.h" +#include "framework/common/util.h" + +namespace ge { +static void ConvertList(const std::pair &op_attr_pair, ge::OpDescPtr op_def) { + domi::AttrDef_ListValue a_list = op_attr_pair.second.value_.list(); + + vector v_i; + for (int32_t i = 0; i < a_list.i_size(); i++) { + v_i.push_back((int64_t)a_list.i(i)); + } + if (v_i.size() > 0) { + (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); + return; + } + vector v_f; + for (int32_t i = 0; i < a_list.f_size(); i++) { + v_f.push_back(a_list.f(i)); + } + if (v_f.size() > 0) { + (void)ge::AttrUtils::SetListFloat(op_def, op_attr_pair.first, v_f); + return; + } + vector v_b; + for (int32_t i = 0; i < a_list.b_size(); i++) { + v_b.push_back(a_list.b(i)); + } + if (v_b.size() > 0) { + (void)ge::AttrUtils::SetListBool(op_def, op_attr_pair.first, v_b); + return; + } + vector v_u; + for (int32_t i = 0; i < a_list.u_size(); i++) { + v_u.push_back((int32_t)a_list.u(i)); + } + if (v_u.size() > 0) { + (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u); + return; + } + // set for empty list + (void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i); + GELOGI("set empty list for node %s attr %s", op_def->GetName().c_str(), op_attr_pair.first.c_str()); +} + +static void UpdateTensorForOpDesc(const ParserOperator &op, ge::OpDescPtr op_def) { + if (op_def == nullptr) { + return; + } + uint32_t in_index = 0; + for (const ge::GeTensorDesc &input_desc : op.GetInputTensorDesc()) { + if (in_index < op_def->GetInputsSize()) { + 
(void)op_def->UpdateInputDesc(in_index++, input_desc); + } else { + (void)op_def->AddInputDesc(input_desc); + in_index++; + } + } + + uint32_t out_index = 0; + for (const ge::GeTensorDesc &output_desc : op.GetOutputTensorDesc()) { + if (out_index < op_def->GetOutputsSize()) { + op_def->UpdateOutputDesc(out_index++, output_desc); + } else { + op_def->AddOutputDesc(output_desc); + out_index++; + } + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(const ParserOperator &op, + ge::OpDescPtr op_def) { + GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null."); + GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str()); + op_def->SetName(op.GetName()); + op_def->SetType(op.GetType()); + GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION)); + + UpdateTensorForOpDesc(op, op_def); + GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(), + op_def->GetInputsSize(), op_def->GetOutputsSize()); + + for (const auto &op_attr_pair : op.GetOpAttrs()) { + if (op_attr_pair.second.value_.has_list()) { + ConvertList(op_attr_pair, op_def); + } else { + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) { + auto &buffer = op_attr_pair.second.value_.bt(); + (void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first, + ge::Buffer::CopyFrom(reinterpret_cast(const_cast(buffer.data())), buffer.size())); + } + + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) { + (void)ge::AttrUtils::SetStr(op_def, op_attr_pair.first, op_attr_pair.second.value_.s()); + } + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kI) { + (void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.i()); + } + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kF) { + (void)ge::AttrUtils::SetFloat(op_def, op_attr_pair.first, op_attr_pair.second.value_.f()); + } + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kB) { + (void)ge::AttrUtils::SetBool(op_def, op_attr_pair.first, op_attr_pair.second.value_.b()); + } + if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kU) { + (void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.u()); + } + } + } + GE_CHK_BOOL_RET_STATUS(op.GetSchema()->Verify(op_def), domi::PARAM_INVALID, "Op schema verify failed, op name: %s", + op.GetName().c_str()); + + return domi::SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, + ParserOperator &op) { + GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null."); + op.Name(op_def->GetName()); + + map allattrs = op_def->GetAllAttrs(); + + for (const auto &attr : allattrs) { + ge::GeAttrValue::ValueType v_t = attr.second.GetValueType(); + switch (v_t) { + case ge::GeAttrValue::ValueType::VT_LIST_STRING: { + std::vector vec; + (void)ge::AttrUtils::GetListStr(op_def, attr.first, vec); + op.Attr(attr.first, vec); + break; + } + case ge::GeAttrValue::ValueType::VT_LIST_FLOAT: { + std::vector vec; + (void)ge::AttrUtils::GetListFloat(op_def, attr.first, vec); + op.Attr(attr.first, vec); + break; + } + case ge::GeAttrValue::ValueType::VT_LIST_BOOL: { + std::vector vec; + (void)ge::AttrUtils::GetListBool(op_def, attr.first, vec); + op.Attr(attr.first, vec); + break; + } + case ge::GeAttrValue::ValueType::VT_LIST_INT: { + std::vector vec; + (void)ge::AttrUtils::GetListInt(op_def, 
attr.first, vec); + op.Attr(attr.first, vec); + break; + } + case ge::GeAttrValue::ValueType::VT_STRING: { + string s = ""; + (void)ge::AttrUtils::GetStr(op_def, attr.first, s); + op.Attr(attr.first, s); + break; + } + case ge::GeAttrValue::ValueType::VT_FLOAT: { + float f = 0.0; + (void)ge::AttrUtils::GetFloat(op_def, attr.first, f); + op.Attr(attr.first, f); + break; + } + case ge::GeAttrValue::ValueType::VT_BOOL: { + bool b = false; + (void)ge::AttrUtils::GetBool(op_def, attr.first, b); + op.Attr(attr.first, b); + break; + } + case ge::GeAttrValue::ValueType::VT_INT: { + int64_t i = 0; + (void)ge::AttrUtils::GetInt(op_def, attr.first, i); + op.Attr(attr.first, i); + break; + } + default: + break; + } + } + + return domi::SUCCESS; +} +} // namespace ge diff --git a/parser/common/op_def/ir_pb_converter.h b/parser/common/op_def/ir_pb_converter.h new file mode 100644 index 0000000..47b92fe --- /dev/null +++ b/parser/common/op_def/ir_pb_converter.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H +#define DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H + +#include "framework/common/fmk_error_codes.h" +#include "common/op_def/op_schema.h" +#include "parser/common/op_def/operator.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "proto/om.pb.h" + +namespace ge { +domi::Status ConvertToOpDesc(const ParserOperator &op, ge::OpDescPtr op_def); + +domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, ParserOperator &op); +} // namespace ge + +#endif // DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H diff --git a/parser/common/op_def/no_op_op.cc b/parser/common/op_def/no_op_op.cc new file mode 100644 index 0000000..472242a --- /dev/null +++ b/parser/common/op_def/no_op_op.cc @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
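The ConvertToOpDesc/ConvertFromOpDesc pair above is the bridge between the fluent ParserOperator builder and GE's OpDesc. The following is a minimal round-trip sketch, not code from this change: the node name and attribute names are invented, it assumes OpDesc is default-constructible, and it assumes the chosen op type has a registered OpSchema (ConvertToOpDesc rejects operators without one).

// Hedged usage sketch for the converters declared in ir_pb_converter.h.
#include <memory>
#include "parser/common/op_def/ir_pb_converter.h"

domi::Status ConvertExample() {
  // The type must have a registered OpSchema, otherwise ConvertToOpDesc returns PARAM_INVALID.
  ge::ParserOperator op("NoOp");
  op.Name("example_node")
    .Attr("example_int_attr", static_cast<int64_t>(1))
    .Attr("example_str_attr", std::string("value"));

  ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>();  // assuming OpDesc's default ctor
  domi::Status ret = ge::ConvertToOpDesc(op, op_desc);     // copies name, type, tensors, attrs
  if (ret != domi::SUCCESS) {
    return ret;
  }

  ge::ParserOperator round_trip;
  return ge::ConvertFromOpDesc(op_desc, round_trip);       // copies the attributes back
}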
+ */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#include "common/op_def/no_op_op.h" +#include + +namespace ge { +FMK_FUNC_HOST_VISIBILITY NoOpOperator::NoOpOperator() : ParserOperator("NoOp") {} + +FMK_FUNC_HOST_VISIBILITY NoOpOperator::~NoOpOperator() {} + +FMK_FUNC_HOST_VISIBILITY NoOpOperator &NoOpOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} +} // namespace ge diff --git a/parser/common/op_def/no_op_op.h b/parser/common/op_def/no_op_op.h new file mode 100644 index 0000000..0208c90 --- /dev/null +++ b/parser/common/op_def/no_op_op.h @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_NO_OP_OP_H_ +#define DOMI_OP_NO_OP_OP_H_ +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class NoOpOperator : public ParserOperator { + public: + NoOpOperator(); + ~NoOpOperator(); + + NoOpOperator &Name(const std::string &name); +}; +} // namespace ge + +#endif // DOMI_OP_NO_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/op_schema.cc b/parser/common/op_def/op_schema.cc new file mode 100644 index 0000000..5882b44 --- /dev/null +++ b/parser/common/op_def/op_schema.cc @@ -0,0 +1,215 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common/op_def/op_schema.h" +#include +#include +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" + +namespace ge { +OpSchema::FormalParameter::FormalParameter(const std::string &name, FormalParameterOption param_option) + : name_(name), param_option_(param_option) {} + +OpSchema::FormalParameter::~FormalParameter() {} + +const std::string &OpSchema::FormalParameter::Name() const { return name_; } + +OpSchema::FormalParameterOption OpSchema::FormalParameter::Option() const { return param_option_; } + +OpSchema::OpSchema(const std::string &name) : name_(name) {} + +OpSchema::~OpSchema() {} + +OpSchema &OpSchema::Input(const std::string &name, FormalParameterOption param_option) { + inputs_.emplace_back(FormalParameter(name, param_option)); + return *this; +} + +OpSchema &OpSchema::Output(const std::string &name, FormalParameterOption param_option) { + outputs_.emplace_back(FormalParameter(name, param_option)); + return *this; +} + +OpSchema &OpSchema::Attr(const Attribute &attr) { + (void)attributes_.insert(std::make_pair(attr.name_, attr)); + return *this; +} + +#if defined(CFG_BUILD_DEBUG) +#define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \ + if (attrtype != attr_type) { \ + GELOGE(FAILED, "Attribute specification param_type mismatch, input attr type %u, required attr type %u.", \ + (uint32_t)attr_type, (uint32_t)attrtype); \ + return *this; \ + } \ + \ + domi::AttrDef a; \ + a.set_##field(default_value); \ + Attr(Attribute(name, attr_type, a)); \ + return *this; \ + } +#else +#define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \ + if (attrtype != attr_type) { \ + return *this; \ + } \ + domi::AttrDef a; \ + a.set_##field(default_value); \ + Attr(Attribute(name, attr_type, a)); \ + return *this; \ + } + +#endif + +#if defined(CFG_BUILD_DEBUG) +#define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector &default_value) { \ + if (attrtype != attr_type) { \ + GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \ + (uint32_t)attr_type, (uint32_t)attrtype); \ + return *this; \ + } \ + domi::AttrDef vec_a; \ + for (const auto &v : default_value) { \ + vec_a.mutable_list()->add_##field(v); \ + } \ + Attr(Attribute(name, attr_type, vec_a)); \ + return *this; \ + } \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple &default_value) { \ + if (attrtype != attr_type) { \ + GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \ + (uint32_t)attr_type, (uint32_t)attrtype); \ + return *this; \ + } \ + domi::AttrDef tuple_a; \ + for (const auto &v : default_value) { \ + tuple_a.mutable_list()->add_##field(v); \ + } \ + Attr(Attribute(name, attr_type, tuple_a)); \ + return *this; \ + } +#else +#define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector &default_value) { \ + if (attrtype != attr_type) { \ + return *this; \ + } \ + domi::AttrDef vec_a; \ + for (const auto &v : default_value) { \ + vec_a.mutable_list()->add_##field(v); \ + } \ + Attr(Attribute(name, attr_type, 
vec_a)); \ + return *this; \ + } \ + OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple &default_value) { \ + if (attrtype != attr_type) { \ + return *this; \ + } \ + domi::AttrDef tuple_a; \ + for (const auto &v : default_value) { \ + tuple_a.mutable_list()->add_##field(v); \ + } \ + Attr(Attribute(name, attr_type, tuple_a)); \ + return *this; \ + } + +#endif +ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, u, AttributeType::UINT) +ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeType::INT) +ATTR_SETTER_WITH_SINGLE_VALUE(bool, b, AttributeType::BOOL) +ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeType::FLOAT) +ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeType::STRING) + +ATTR_SETTER_WITH_LIST_VALUE(uint32_t, u, AttributeType::UINTLIST) +ATTR_SETTER_WITH_LIST_VALUE(int64_t, i, AttributeType::INTLIST) +ATTR_SETTER_WITH_LIST_VALUE(bool, b, AttributeType::BOOLLIST) +ATTR_SETTER_WITH_LIST_VALUE(float, f, AttributeType::FLOATLIST) +ATTR_SETTER_WITH_LIST_VALUE(std::string, s, AttributeType::STRINGLIST) + +OpSchema &OpSchema::AttrRequired(const std::string &name, AttributeType attr_type) { + Attr(Attribute(name, attr_type, true)); + return *this; +} + +bool OpSchema::HasDefaultAttr(const std::string &name) const { + auto it = attributes_.find(name); + if (it == attributes_.end()) { + return false; + } + + // required does not need a default value + return !it->second.required_; +} + +const domi::AttrDef &OpSchema::GetDefaultAttr(const std::string &name) const { + auto it = attributes_.find(name); + if (it == attributes_.end()) { + const static domi::AttrDef attr_def; + return attr_def; + } + return it->second.default_value_; +} + +bool OpSchema::Verify(const ge::OpDescPtr op_def) const { + if (op_def->GetType() != name_) { + GELOGE(FAILED, "Name not math, op schema name: %s, opdef type: %s.", name_.c_str(), op_def->GetType().c_str()); + return false; + } + + // Required field verification + for (const auto &pair : attributes_) { + const auto &attr = pair.second; + if (!attr.required_) { + continue; + } + if (!op_def->HasAttr(attr.name_)) { + GELOGE(FAILED, "Required attribute: %s of op: %s is missing.", attr.name_.c_str(), op_def->GetName().c_str()); + return false; + } + } + + return true; +} + +OpSchemaFactory &OpSchemaFactory::Instance() { + static OpSchemaFactory instance; + return instance; +} + +const OpSchema *OpSchemaFactory::Get(const std::string &op) const { + auto it = op_schema_map_.find(op); + if (it == op_schema_map_.end()) { + return nullptr; + } + return &it->second; +} + +OpSchemaRegistry::OpSchemaRegistry(OpSchema &op_schema) { + OpSchemaFactory &op_factory = OpSchemaFactory::Instance(); + + // save op_schema to the map + if (op_factory.op_schema_map_.count(op_schema.name_)) { + GELOGD("Failed to register op schema: %s., reason: already exist!", op_schema.name_.c_str()); + return; + } + + (void)op_factory.op_schema_map_.emplace(std::make_pair(op_schema.name_, op_schema)); +} +} // namespace ge diff --git a/parser/common/op_def/op_schema.h b/parser/common/op_def/op_schema.h new file mode 100644 index 0000000..48b5c3b --- /dev/null +++ b/parser/common/op_def/op_schema.h @@ -0,0 +1,175 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DOMI_COMMON_OP_SCHEMA_H +#define DOMI_COMMON_OP_SCHEMA_H + +#include +#include +#include +#include "common/tuple.h" +#include "graph/op_desc.h" +#include "proto/om.pb.h" +#include "framework/common/fmk_types.h" + +namespace ge { +enum class AttributeType { + UNDEFINED, + INT, + UINT, + BOOL, + FLOAT, + STRING, + BYTES, + + INTLIST, + UINTLIST, + BOOLLIST, + FLOATLIST, + STRINGLIST +}; + +class OpSchema; + +class OpSchemaRegistry; + +class FMK_FUNC_HOST_VISIBILITY OpSchema { + public: + // Formal parameter options. + enum FormalParameterOption { + // The input formal parameter is single and not optional. + // Number of this input is 1. + Single = 0, + // The input formal parameter is single and optional. + // Number of this input is 0 or 1. + Optional = 1, + // The input formal parameter is variadic. + // Number of this input is [1, n]. + Variadic = 2, + }; + + // Formal parameter represenation, including input/output name, typeStr, + // description, and type constraints. + class FormalParameter { + public: + // Constructor. + FormalParameter() = default; + + explicit FormalParameter(const std::string &name, FormalParameterOption param_option = Single); + + ~FormalParameter(); + + // Get formal parameter name. + const std::string &Name() const; + + // Get the parameter option, it could be Single, Optional or Variadic. + FormalParameterOption Option() const; + + private: + friend class OpSchema; + + // Formal parameter name. + std::string name_; + + // Formal parameter option. + FormalParameterOption param_option_; + }; + + explicit OpSchema(const std::string &name); + + ~OpSchema(); + + OpSchema &Input(const std::string &name, FormalParameterOption param_option = Single); + + OpSchema &Output(const std::string &name, FormalParameterOption param_option = Single); + + struct Attribute { + Attribute(const std::string &name, AttributeType type, bool required) + : name_(name), type_(type), required_(required) {} + + Attribute(const std::string &name, AttributeType type, domi::AttrDef default_value) + : name_(name), type_(type), required_(false), default_value_(default_value) {} + + const std::string name_; + AttributeType type_; + bool required_; + domi::AttrDef default_value_; + }; + + OpSchema &Attr(const Attribute &attr); + +// Register "optional" attribute with default value. +#define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \ + OpSchema &Attr(const std::string &name, AttributeType type, const TypeName &default_value); \ + OpSchema &Attr(const std::string &name, AttributeType type, const std::vector &default_value); \ + OpSchema &Attr(const std::string &name, AttributeType type, const Tuple &default_value); + + ATTR_SETTER_WITH_DEFAULT_VALUE(uint32_t) + ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t) + ATTR_SETTER_WITH_DEFAULT_VALUE(bool) + ATTR_SETTER_WITH_DEFAULT_VALUE(float) + ATTR_SETTER_WITH_DEFAULT_VALUE(std::string) + + // Register "required" attribute without default value. 
+ OpSchema &AttrRequired(const std::string &name, AttributeType type); + + bool HasDefaultAttr(const std::string &name) const; + + const domi::AttrDef &GetDefaultAttr(const std::string &name) const; + + // verify op_def + bool Verify(const ge::OpDescPtr op_def) const; + + private: + friend class OpSchemaRegistry; + + std::string name_; + + std::vector inputs_; + + std::vector outputs_; + + std::unordered_map attributes_; +}; + +class OpSchemaFactory { + public: + // this is a singleton object + static OpSchemaFactory &Instance(); + + const OpSchema *Get(const std::string &op) const; + + private: + OpSchemaFactory() = default; + ~OpSchemaFactory() = default; + + friend class OpSchemaRegistry; + // the op schema map + std::unordered_map op_schema_map_; +}; + +class FMK_FUNC_HOST_VISIBILITY OpSchemaRegistry { + public: + OpSchemaRegistry(OpSchema &op_schema); + ~OpSchemaRegistry() = default; +}; + +#define DOMI_OP_SCHEMA(name) DOMI_OP_SCHEMA_UNIQ_HELPER(__COUNTER__, name) +#define DOMI_OP_SCHEMA_UNIQ_HELPER(ctr, name) DOMI_OP_SCHEMA_UNIQ(ctr, name) +#define DOMI_OP_SCHEMA_UNIQ(ctr, name) \ + static OpSchemaRegistry op_schema_registry##ctr __attribute__((unused)) = OpSchema(#name) +} // namespace ge +#endif // DOMI_COMMON_OP_SCHEMA_H diff --git a/parser/common/op_def/operator.cc b/parser/common/op_def/operator.cc new file mode 100644 index 0000000..d18175d --- /dev/null +++ b/parser/common/op_def/operator.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
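The DOMI_OP_SCHEMA macro above registers a schema with OpSchemaFactory at static-initialization time, and Verify() later checks required attributes against it. A hypothetical registration might look as follows; the op name and attributes are invented for illustration and do not correspond to an op in this change.

// Hedged sketch: registering a schema through the macro defined above.
namespace ge {
DOMI_OP_SCHEMA(MyExampleOp)
    .Input("x")                                    // Single (the default) formal parameter
    .Output("y")
    .Attr("alpha", AttributeType::FLOAT, 1.0f)     // optional attribute with a default value
    .AttrRequired("mode", AttributeType::STRING);  // required attribute, enforced by Verify()
}  // namespace ge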
+ */ + +#include "operator.h" +#include +#include "framework/common/fmk_types.h" +#include "framework/common/util.h" +#include "framework/common/debug/ge_log.h" + +using ge::BoolTuple; +using ge::FloatTuple; +using ge::IntTuple; +using ge::StringTuple; +using ge::UintTuple; + +namespace ge { +ParserOperator::ParserOperator(const std::string &type) { + type_ = type; + op_schema_ = ge::OpSchemaFactory::Instance().Get(type); + if (op_schema_ == nullptr) { + GELOGW("Cannot find op schema of op type: %s", type.c_str()); + } +} + +ParserOperator &ParserOperator::Input(const ParserOperator &in_op, uint32_t index) { + if (index == 0) { + inputs_.push_back(in_op.GetName()); + } else { + inputs_.push_back(in_op.GetName() + ":" + std::to_string(index)); + } + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Name(const std::string &name) { + name_ = name; + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Type(const std::string &type) { + type_ = type; + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::InputTensorDesc( + const ge::GeTensorDesc &input_tensordesc) { + input_descs_.push_back(input_tensordesc); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::OutputTensorDesc( + const ge::GeTensorDesc &output_tensordesc) { + output_descs_.push_back(output_tensordesc); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( + std::string key, + std::vector &value) { + domi::AttrDef out; + auto it = op_attrs_.find(key); + if (it != op_attrs_.end()) { + out = it->second.value_; + } + for (auto &v : value) { + out.mutable_list()->add_i(v); + } + (void)op_attrs_.erase(key); + (void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out))); + return *this; +} +FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( + std::string key, + std::vector &value) { + domi::AttrDef out; + auto it = op_attrs_.find(key); + if (it != op_attrs_.end()) { + out = it->second.value_; + } + for (auto &v : value) { + out.mutable_list()->add_i(v); + } + (void)op_attrs_.erase(key); + (void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out))); + return *this; +} + +ParserOperator &ParserOperator::Attr(const OpAttribute &attr) { + auto it = op_attrs_.find(attr.name_); + if (it != op_attrs_.end()) { + (void)op_attrs_.erase(it); + } + (void)op_attrs_.insert(std::make_pair(attr.name_, attr)); + return *this; +} + +ParserOperator &ParserOperator::Attr_bt(const std::string &name, const std::string &value) { + domi::AttrDef a; + a.set_bt(value); + Attr(OpAttribute(name, a)); + return *this; +} + +#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field) \ + ParserOperator &ParserOperator::Attr(const std::string &name, const type &value) { \ + domi::AttrDef a; \ + a.set_##field(value); \ + Attr(OpAttribute(name, a)); \ + return *this; \ + } + +#define ATTR_SETTER_WITH_LIST_VALUE(type, field) \ + ParserOperator &ParserOperator::Attr(const std::string &name, const std::vector &value) { \ + domi::AttrDef a; \ + auto attr_list = a.mutable_list(); \ + for (size_t i = 0; i < value.size(); ++i) { \ + attr_list->add_##field(value[i]); \ + } \ + Attr(OpAttribute(name, a)); \ + return *this; \ + } \ + ParserOperator &ParserOperator::Attr(const std::string &name, const ge::Tuple &value) { \ + domi::AttrDef a; \ + auto attr_list = a.mutable_list(); 
\ + for (uint32_t i = 0; i < value.ndim(); ++i) { \ + attr_list->add_##field(value[i]); \ + } \ + Attr(OpAttribute(name, a)); \ + return *this; \ + } + +ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i) +ATTR_SETTER_WITH_SINGLE_VALUE(bool, b) +ATTR_SETTER_WITH_SINGLE_VALUE(float, f) +ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s) +ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, i) + +ATTR_SETTER_WITH_LIST_VALUE(int64_t, i) +ATTR_SETTER_WITH_LIST_VALUE(bool, b) +ATTR_SETTER_WITH_LIST_VALUE(float, f) +ATTR_SETTER_WITH_LIST_VALUE(std::string, s) +ATTR_SETTER_WITH_LIST_VALUE(uint32_t, i) + +#define ATTR_GET_SINGLE_VALUE(type, field, type_name) \ + type ParserOperator::Get##type_name##Attr(const std::string &name) const { \ + domi::AttrDef single_val; \ + auto it = op_attrs_.find(name); \ + if (it != op_attrs_.end()) { \ + single_val = it->second.value_; \ + } else { \ + if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \ + single_val = op_schema_->GetDefaultAttr(name); \ + } \ + } \ + return single_val.field(); \ + } +ATTR_GET_SINGLE_VALUE(uint32_t, i, Uint) +ATTR_GET_SINGLE_VALUE(int64_t, i, Int) +ATTR_GET_SINGLE_VALUE(float, f, Float) +ATTR_GET_SINGLE_VALUE(bool, b, Bool) +ATTR_GET_SINGLE_VALUE(std::string, s, String) + +#define ATTR_GET_TUPLE_VALUE(type, field, tuple_type_name) \ + tuple_type_name ParserOperator::Get##tuple_type_name##Attr(const std::string &name) const { \ + domi::AttrDef value; \ + auto it = op_attrs_.find(name); \ + if (it != op_attrs_.end()) { \ + value = it->second.value_; \ + } else { \ + if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \ + value = op_schema_->GetDefaultAttr(name); \ + } \ + } \ + const auto attr_def = value.list(); \ + std::size_t n = attr_def.field##_size(); \ + std::vector vec(n); \ + for (std::size_t i = 0; i < n; i++) { \ + vec[i] = attr_def.field(i); \ + } \ + return tuple_type_name(vec); \ + } + +ATTR_GET_TUPLE_VALUE(uint32_t, i, UintTuple) +ATTR_GET_TUPLE_VALUE(int64_t, i, IntTuple) +ATTR_GET_TUPLE_VALUE(float, f, FloatTuple) +ATTR_GET_TUPLE_VALUE(bool, b, BoolTuple) +ATTR_GET_TUPLE_VALUE(std::string, s, StringTuple) +} // namespace domi diff --git a/parser/common/op_def/operator.h b/parser/common/op_def/operator.h new file mode 100644 index 0000000..63d9ae3 --- /dev/null +++ b/parser/common/op_def/operator.h @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef DOMI_COMMON_OP_OPERATOR_H +#define DOMI_COMMON_OP_OPERATOR_H + +#include +#include +#include +#include "framework/common/fmk_types.h" +#include "common/op_def/op_schema.h" +#include "common/tuple.h" +#include "graph/ge_tensor.h" +#include "proto/om.pb.h" +namespace ge { +struct OpAttribute { + OpAttribute(const std::string &name, const domi::AttrDef &value) : name_(name), value_(value) {} + const std::string name_; + domi::AttrDef value_; +}; + +class FMK_FUNC_HOST_VISIBILITY ParserOperator { + public: + explicit ParserOperator(const std::string &type); + ParserOperator() { op_schema_ = nullptr; } + + virtual ~ParserOperator() { op_schema_ = nullptr; } + + ParserOperator &Input(const ParserOperator &in_op, uint32_t index = 0); + + ParserOperator &Attr(const OpAttribute &op_attr); + + ParserOperator &AttrVector(std::string key, std::vector &value); + ParserOperator &AttrVector(std::string key, std::vector &value); + + ParserOperator &Name(const std::string &name); + + ParserOperator &Type(const std::string &type); + + ParserOperator &InputTensorDesc(const ge::GeTensorDesc &input_tensordesc); + + ParserOperator &OutputTensorDesc(const ge::GeTensorDesc &output_tensordesc); + + ParserOperator &Attr_bt(const std::string &name, const std::string &value); + +// Register "optional" attribute with default value. +#define ATTR_SETTER_WITH_VALUE(TypeName) \ + ParserOperator &Attr(const std::string &name, const TypeName &value); \ + ParserOperator &Attr(const std::string &name, const std::vector &value); \ + ParserOperator &Attr(const std::string &name, const ge::Tuple &value) + + ATTR_SETTER_WITH_VALUE(uint32_t); + ATTR_SETTER_WITH_VALUE(int64_t); + ATTR_SETTER_WITH_VALUE(bool); + ATTR_SETTER_WITH_VALUE(float); + ATTR_SETTER_WITH_VALUE(std::string); + + const std::string &GetName() const { return name_; } + + const std::string &GetType() const { return type_; } + + const std::vector &GetInputs() const { return inputs_; } + + const std::vector &GetInputTensorDesc() const { return input_descs_; } + + const std::vector &GetOutputTensorDesc() const { return output_descs_; } + + const std::unordered_map GetOpAttrs() const { return op_attrs_; } + + bool HasAttr(const std::string &name) const { return op_attrs_.find(name) != op_attrs_.end(); } + + const ge::OpSchema *GetSchema() const { return op_schema_; } + + int64_t GetIntAttr(const std::string &name) const; + + uint32_t GetUintAttr(const std::string &name) const; + + float GetFloatAttr(const std::string &name) const; + + bool GetBoolAttr(const std::string &name) const; + + std::string GetStringAttr(const std::string &name) const; + + ge::IntTuple GetIntTupleAttr(const std::string &name) const; + + ge::UintTuple GetUintTupleAttr(const std::string &name) const; + + ge::FloatTuple GetFloatTupleAttr(const std::string &name) const; + + ge::BoolTuple GetBoolTupleAttr(const std::string &name) const; + + ge::StringTuple GetStringTupleAttr(const std::string &name) const; + + private: + const ge::OpSchema *op_schema_; + std::string name_; + std::string type_; + std::vector inputs_; + std::unordered_map op_attrs_; + std::vector input_descs_; + std::vector output_descs_; +}; +} // namespace domi +#endif // DOMI_COMMON_OP_OPERATOR_H diff --git a/parser/common/op_def/ref_switch_op.cc b/parser/common/op_def/ref_switch_op.cc new file mode 100644 index 0000000..6a45868 --- /dev/null +++ b/parser/common/op_def/ref_switch_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
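ParserOperator above is a small builder: each setter returns *this so the framework parsers can chain calls. A sketch of the intended use follows; the node name, attribute names and tensor descriptors are placeholders, and GeTensorDesc is assumed to be default-constructible.

// Hedged usage sketch for the ParserOperator builder declared above.
#include <vector>
#include "parser/common/op_def/operator.h"

ge::ParserOperator BuildExampleOperator() {
  ge::ParserOperator op("RefSwitch");              // the ctor looks up the OpSchema for this type
  std::vector<int64_t> dims = {1, 3, 224, 224};
  op.Name("example/ref_switch")
    .Attr("T", static_cast<int64_t>(0))            // single-value attribute
    .Attr("example_dims", dims)                    // list attribute (stored via add_i)
    .InputTensorDesc(ge::GeTensorDesc())           // placeholder tensor descriptors
    .OutputTensorDesc(ge::GeTensorDesc());
  return op;
}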
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#include "common/op_def/ref_switch_op.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) { + Attr("T", (int64_t)t); + return *this; +} +} // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/ref_switch_op.h b/parser/common/op_def/ref_switch_op.h new file mode 100644 index 0000000..baf2167 --- /dev/null +++ b/parser/common/op_def/ref_switch_op.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_REF_SWITCH_H_ +#define DOMI_OP_REF_SWITCH_H_ +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class RefSwitchOperator : public ParserOperator { + public: + RefSwitchOperator(); + ~RefSwitchOperator(); + + RefSwitchOperator &Name(const std::string &name); + RefSwitchOperator &T(ge::DataType t); +}; +} // namespace ge + +#endif // DOMI_OP_REF_SWITCH_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/shape_n_op.cc b/parser/common/op_def/shape_n_op.cc new file mode 100644 index 0000000..0e6e14f --- /dev/null +++ b/parser/common/op_def/shape_n_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#include "common/op_def/shape_n_op.h" +#include "graph/debug/ge_attr_define.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {} + +FMK_FUNC_HOST_VISIBILITY ShapeNOperator::~ShapeNOperator() {} + +FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) { + Attr(SHAPEN_ATTR_N, n); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); } + +FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) { + Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const { + return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE); +} + +FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) { + Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const { + return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE); +} +} // namespace ge diff --git a/parser/common/op_def/shape_n_op.h b/parser/common/op_def/shape_n_op.h new file mode 100644 index 0000000..bb69235 --- /dev/null +++ b/parser/common/op_def/shape_n_op.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_SHAPE_N_OP_H_ +#define DOMI_OP_SHAPE_N_OP_H_ +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class ShapeNOperator : public ParserOperator { + public: + ShapeNOperator(); + ~ShapeNOperator(); + + ShapeNOperator &Name(const std::string &name); + + ShapeNOperator &N(int64_t n); + int64_t GetN() const; + ShapeNOperator &InType(ge::DataType t); + ge::DataType GetInType() const; + ShapeNOperator &OutType(ge::DataType t); + ge::DataType GetOutType() const; +}; +} // namespace ge + +#endif // DOMI_OP_SHAPE_N_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/var_is_initialized_op_op.cc b/parser/common/op_def/var_is_initialized_op_op.cc new file mode 100644 index 0000000..e0e3d62 --- /dev/null +++ b/parser/common/op_def/var_is_initialized_op_op.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#include "common/op_def/var_is_initialized_op_op.h" +#include +#include + +namespace ge { +VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {} + +VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {} + +VarIsInitializedOpOperator &VarIsInitializedOpOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +VarIsInitializedOpOperator &VarIsInitializedOpOperator::VectorAttr(const std::string &key, + std::vector &value) { + Attr(key, value); + return *this; +} +} // namespace ge diff --git a/parser/common/op_def/var_is_initialized_op_op.h b/parser/common/op_def/var_is_initialized_op_op.h new file mode 100644 index 0000000..88b649f --- /dev/null +++ b/parser/common/op_def/var_is_initialized_op_op.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_VARISINITIALIZEDOP_H_ +#define DOMI_OP_VARISINITIALIZEDOP_H_ +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class VarIsInitializedOpOperator : public ParserOperator { + public: + VarIsInitializedOpOperator(); + ~VarIsInitializedOpOperator(); + + VarIsInitializedOpOperator &Name(const std::string &name); + VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector &value); +}; +} // namespace ge + +#endif // DOMI_OP_VARISINITIALIZEDOP_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_def/variable_op.cc b/parser/common/op_def/variable_op.cc new file mode 100644 index 0000000..2cf294e --- /dev/null +++ b/parser/common/op_def/variable_op.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/common/op_def/variable_op.h" + +#include "graph/debug/ge_attr_define.h" + +namespace ge { +VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {} + +VariableOperator::~VariableOperator() {} + +VariableOperator &VariableOperator::Name(const std::string &name) { + ParserOperator::Name(name); + return *this; +} + +VariableOperator &VariableOperator::Container(const std::string &container) { + Attr(VAR_ATTR_CONTAINER, container); + return *this; +} + +VariableOperator &VariableOperator::SharedName(const std::string &sharedname) { + Attr(VAR_ATTR_SHARED_NAME, sharedname); + return *this; +} + +VariableOperator &VariableOperator::Placement(const std::string &placement) { + Attr(ATTR_VARIABLE_PLACEMENT, placement); + return *this; +} + +VariableOperator &VariableOperator::SrcType(const int64_t &dtype) { + Attr(VAR_ATTR_DTYPE, dtype); + return *this; +} + +VariableOperator &VariableOperator::VarShape(const std::vector &shape_value) { + Attr(VAR_ATTR_SHAPE, shape_value); + return *this; +} + +int64_t VariableOperator::GetVarSrcType() const { return GetIntAttr(VAR_ATTR_DTYPE); } +} // namespace ge diff --git a/parser/common/op_def/variable_op.h b/parser/common/op_def/variable_op.h new file mode 100644 index 0000000..c9b85d3 --- /dev/null +++ b/parser/common/op_def/variable_op.h @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// AUTO GEN PLEASE DO NOT MODIFY IT +#ifndef DOMI_OP_VARIABLE_H_ +#define DOMI_OP_VARIABLE_H_ +#include +#include "parser/common/op_def/operator.h" +#include "framework/omg/parser/parser_types.h" + +namespace ge { +class VariableOperator : public ParserOperator { + public: + VariableOperator(); + ~VariableOperator(); + + VariableOperator &Name(const std::string &name); + + VariableOperator &Container(const std::string &container); + + VariableOperator &SharedName(const std::string &sharedname); + + VariableOperator &Placement(const std::string &placement); + + VariableOperator &SrcType(const int64_t &dtype); + + VariableOperator &VarShape(const std::vector &shape_value); + + int64_t GetVarSrcType() const; +}; +} // namespace ge + +#endif // DOMI_OP_VAR_H_ AUTO GEN PLEASE DO NOT MODIFY IT diff --git a/parser/common/op_map.cc b/parser/common/op_map.cc new file mode 100644 index 0000000..486b462 --- /dev/null +++ b/parser/common/op_map.cc @@ -0,0 +1,159 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
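VariableOperator above wraps the common variable attributes (container, shared name, placement, dtype, shape) behind named setters. The snippet below shows the call pattern; every value is illustrative, and the placement string in particular is only a placeholder.

// Hedged sketch of building a variable description with the setters above.
#include <vector>
#include "parser/common/op_def/variable_op.h"
#include "graph/types.h"

ge::VariableOperator MakeExampleVariable() {
  ge::VariableOperator var_op;
  std::vector<int64_t> shape = {1, 224, 224, 3};
  var_op.Name("example/var")
        .Container("")                                    // illustrative: often empty for TF variables
        .SharedName("example/var")
        .Placement("example_placement")                   // placeholder placement string
        .SrcType(static_cast<int64_t>(ge::DT_FLOAT))      // DT_FLOAT from graph/types.h
        .VarShape(shape);
  return var_op;
}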
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/op_map.h" + +#include +#include +#include + +#include "framework/omg/parser/parser_types.h" +#include "register/op_registry.h" + +using std::map; +using std::string; +using std::vector; +using namespace ge::parser; + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map caffe_op_map = { + {"Input", DATA}, + {"DummyData", DATA}, + {"Reshape", RESHAPE}, + {"Dropout", DROPOUT}, + {"NetOutput", NETOUTPUT}, +}; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map tensorflow_op_map = { + {"BroadcastGradientArgs", BROADCASTGRADIENTARGS}, + {"StopGradient", STOPGRADIENT}, + {"ExpandDims", EXPANDDIMS}, + {"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE}, + {"GuaranteeConst", GUARANTEECONST}, + {"BroadcastArgs", BROADCASTARGS}, + {"PreventGradient", PREVENTGRADIENT}, + {"Empty", EMPTY}, + {"Placeholder", DATA}, + {"ControlTrigger", CONTROLTRIGGER}, + {"_ParallelConcatStart", PARALLELCONCATSTART}, + {"Const", CONSTANT}, + {"FrameworkOp", FRAMEWORKOP}, + {"Reshape", RESHAPE}, + {"Squeeze", SQUEEZE}, + {"Enter", ENTER}, + {"RefEnter", REFENTER}, + {"Exit", EXIT}, + {"RefExit", REFEXIT}, + {"LoopCond", LOOPCOND}, + {"NextIteration", NEXTITERATION}, + {"RefNextIteration", REFNEXTITERATION}, + {"Identity", IDENTITY}, + {"IdentityN", IDENTITYN}, + {"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT}, + {"Size", SIZE}, + {"Shape", SHAPE}, + {"ShapeN", SHAPEN}, + {"Fill", FILL}, + {"Rank", RANK}, + {"Merge", MERGE}, + {"RefMerge", REFMERGE}, + {"Switch", SWITCH}, + {"RefSwitch", REFSWITCH}, + {"LayerNorm", LAYERNORM}, + {"RNN", RNN}, + {"_Arg", ARG}, + {"_Retval", FRAMEWORKOP}, + {"Bitcast", BITCAST}, + {"Snapshot", SNAPSHOT}, +}; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map tensorflow_train_op_map = { + {"BroadcastGradientArgs", BROADCASTGRADIENTARGS}, + {"StopGradient", STOPGRADIENT}, + {"ExpandDims", EXPANDDIMS}, + {"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE}, + {"TemporaryVariable", TEMPORARYVARIABLE}, + {"GuaranteeConst", GUARANTEECONST}, + {"BroadcastArgs", BROADCASTARGS}, + {"PreventGradient", PREVENTGRADIENT}, + {"Empty", EMPTY}, + {"ControlTrigger", CONTROLTRIGGER}, + {"_Arg", ARG}, + {"_ParallelConcatStart", PARALLELCONCATSTART}, + {"Const", CONSTANTOP}, + {"VariableV2", VARIABLE}, + {"VarHandleOp", VARHANDLEOP}, + {"VarIsInitializedOp", VARISINITIALIZEDOP}, + {"IsVariableInitialized", ISVARIABLEINITIALIZED}, + {"ReadVariableOp", READVARIABLEOP}, + {"Reshape", RESHAPE}, + {"Squeeze", SQUEEZE}, + {"NoOp", NOOP}, + {"Enter", ENTER}, + {"RefEnter", REFENTER}, + {"Exit", EXIT}, + {"RefExit", REFEXIT}, + {"LoopCond", LOOPCOND}, + {"NextIteration", NEXTITERATION}, + {"RefNextIteration", REFNEXTITERATION}, + {"Identity", IDENTITY}, + {"IdentityN", IDENTITYN}, + {"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT}, + {"Size", SIZE}, + {"Shape", SHAPE}, + {"ShapeN", SHAPEN}, + {"Rank", RANK}, + {"Merge", MERGE}, + {"Switch", SWITCH}, + {"LayerNorm", LAYERNORM}, + {"LayerNormGrad", LAYERNORMGRAD}, + {"Dropout", DROPOUT}, + {"Bitcast", BITCAST}, + {"Snapshot", SNAPSHOT}, +}; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map op_output_tensor_num = { + {SSDDETECTIONOUTPUT, 3}, + {REFINEDETDETECTIONOUTPUT, 3}, + {FSRDETECTIONOUTPUT, 2}, + {FASTERRCNNFIRSTSTAGEPOSTPROCESSOR, 4}, + {FASTERRCNNSECONDSTAGEPOSTPROCESSOR, 4}, + {YOLODETECTIONOUTPUT, 2}, + {FASTRCNNPREDICTIONS, 4}, + {RPNPROPOSALS, 3}, + {MAXPOOLWITHARGMAX, 2}, + 
{REGION, 3}, + {TOPKV2, 2}, + {LogTimeStamp, 0}, + /* training op */ + {MAXPOOLWITHARGMAX, 2}, + {FUSEDBATCHNORM, 5}, + {FUSEDBATCHNORMGRAD, 3}, + {SHAPEN, 0}, + {SSDPOSTPROCESSOR, 4}, + {LAYERNORM, 3}, + {LAYERNORMGRAD, 3}, + {SPARSESOFTMAXCROSSENTROPYWITHLOGITS, 2}, +}; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector local_framework_op_vec = { + "TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2", + "IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"}; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector is_dataset_op_vec = { + "TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2", + "IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"}; +} // namespace ge diff --git a/parser/common/op_map.h b/parser/common/op_map.h new file mode 100644 index 0000000..cae651c --- /dev/null +++ b/parser/common/op_map.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_OP_MAP_H_ +#define GE_COMMON_OP_MAP_H_ + +#include +#include +#include + +/*lint -e1073*/ +namespace ge { +// the operator type mapping table of caffe and mindspore +extern std::map caffe_op_map; + +// the operator type mapping table of TensorFlow and mindspore +extern std::map tensorflow_op_map; + +// the network training operator type mapping table of TensorFlow and mindspore +extern std::map tensorflow_train_op_map; + +// local framework op vec +extern std::vector local_framework_op_vec; + +// dataset op vec +extern std::vector is_dataset_op_vec; + +// output tensor num +extern std::map op_output_tensor_num; +} // namespace ge +/*lint +e1073*/ +#endif // GE_COMMON_OP_MAP_H_ diff --git a/parser/common/op_parser_factory.cc b/parser/common/op_parser_factory.cc new file mode 100644 index 0000000..23d95d7 --- /dev/null +++ b/parser/common/op_parser_factory.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
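The tables above (caffe_op_map, tensorflow_op_map, tensorflow_train_op_map, op_output_tensor_num) are plain std::map/std::vector globals that the parsers consult while translating framework graphs. A small lookup helper written only against the declarations in op_map.h might look like this; the fallback behaviour noted in the comment is an assumption, not something defined in this file.

// Hedged sketch: resolving an internal op type from the mapping tables above.
#include <string>
#include "common/op_map.h"

std::string MapTensorFlowType(const std::string &tf_type, bool train_graph) {
  const auto &table = train_graph ? ge::tensorflow_train_op_map : ge::tensorflow_op_map;
  const auto it = table.find(tf_type);
  if (it == table.end()) {
    return "";  // unmapped types are assumed to be handled elsewhere (e.g. as framework ops)
  }
  return it->second;
}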
+ */ + +#include "parser/common/op_parser_factory.h" +#include "common/debug/log.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/type_utils.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry *CustomParserAdapterRegistry::Instance() { + static CustomParserAdapterRegistry instance; + return &instance; +} + +FMK_FUNC_HOST_VISIBILITY void CustomParserAdapterRegistry::Register(const domi::FrameworkType framework, + CustomParserAdapterRegistry::CREATOR_FUN fun) { + if (funcs_.find(framework) != funcs_.end()) { + GELOGW("Framework type %s has already registed.", TypeUtils::FmkTypeToSerialString(framework).c_str()); + return; + } + funcs_[framework] = fun; + GELOGI("Register %s custom parser adapter success.", TypeUtils::FmkTypeToSerialString(framework).c_str()); + return; +} +FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry::CREATOR_FUN +CustomParserAdapterRegistry::GetCreateFunc(const domi::FrameworkType framework) { + if (funcs_.find(framework) == funcs_.end()) { + GELOGW("Framework type %s has not registed.", TypeUtils::FmkTypeToSerialString(framework).c_str()); + return nullptr; + } + return funcs_[framework]; +} + +FMK_FUNC_HOST_VISIBILITY std::shared_ptr OpParserFactory::Instance( + const domi::FrameworkType framework) { + // Each framework corresponds to one op parser factory, + // If instances are static data members of opparserfactory, the order of their construction is uncertain. + // Instances cannot be a member of a class because they may be used before initialization, resulting in a run error. + static std::map> instances; + + auto iter = instances.find(framework); + if (iter == instances.end()) { + std::shared_ptr instance(new (std::nothrow) OpParserFactory()); + if (instance == nullptr) { + GELOGE(INTERNAL_ERROR, "Create op parser factory failed."); + return nullptr; + } + instances[framework] = instance; + return instance; + } + + return iter->second; +} + +FMK_FUNC_HOST_VISIBILITY std::shared_ptr OpParserFactory::CreateOpParser(const std::string &op_type) { + // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. + auto iter = op_parser_creator_map_.find(op_type); + if (iter != op_parser_creator_map_.end()) { + return iter->second(); + } + + GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported type: %s", op_type.c_str()); + return nullptr; +} + +FMK_FUNC_HOST_VISIBILITY std::shared_ptr OpParserFactory::CreateFusionOpParser(const std::string &op_type) { + // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser. 
+ auto iter = fusion_op_parser_creator_map_.find(op_type); + if (iter != fusion_op_parser_creator_map_.end()) { + return iter->second(); + } + + GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported fusion op type: %s", op_type.c_str()); + return nullptr; +} + +// This function is only called within the constructor of the global opparserregisterar object, +// and does not involve concurrency, so there is no need to lock it +FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun, + bool is_fusion_op) { + std::map *op_parser_creator_map = &op_parser_creator_map_; + if (is_fusion_op) { + op_parser_creator_map = &fusion_op_parser_creator_map_; + } + + GELOGD("OpParserFactory::RegisterCreator: op type:%s, is_fusion_op:%d.", type.c_str(), is_fusion_op); + (*op_parser_creator_map)[type] = fun; +} + +FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) { + if (is_fusion_op) { + auto iter = fusion_op_parser_creator_map_.find(op_type); + if (iter != fusion_op_parser_creator_map_.end()) { + return true; + } + } else { + auto iter = op_parser_creator_map_.find(op_type); + if (iter != op_parser_creator_map_.end()) { + return true; + } + } + return false; +} +} // namespace ge diff --git a/parser/common/op_parser_factory.h b/parser/common/op_parser_factory.h new file mode 100644 index 0000000..bf867df --- /dev/null +++ b/parser/common/op_parser_factory.h @@ -0,0 +1,198 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_OP_PARSER_FACTORY_H_ +#define PARSER_COMMON_OP_PARSER_FACTORY_H_ + +#include +#include +#include +#include +#include +#include +#include "common/ge/ge_util.h" +#include "framework/omg/parser/parser_types.h" +#include "framework/common/debug/ge_log.h" +#include "omg/omg_inner_types.h" +#include "external/register/register.h" + +using domi::CAFFE; + +namespace ge { +class OpParser; + +/** + * @ingroup domi_omg + * @brief Used to create OpParser + * + */ +class OpParserFactory { + public: + /** + * @ingroup domi_omg + * @brief Returns the OpParserFactory instance corresponding to the Framework + * @return OpParserFactory object + */ + static std::shared_ptr Instance(const domi::FrameworkType framework); + + /** + * @ingroup domi_omg + * @brief Create OpParser based on input type + * @param [in] op_type Op type + * @return Created OpParser + */ + std::shared_ptr CreateOpParser(const std::string &op_type); + + /** + * @ingroup domi_omg + * @brief Create fusion OpParser based on input type + * @param [in] op_type Op type + * @return Created OpParser + */ + std::shared_ptr CreateFusionOpParser(const std::string &op_type); + + // The Factory instance is automatically released by shared_ptr. + // The shared_ptr internally calls the destructor indirectly. + // If the destructor is not public, it will generate a compilation error. 
+ // Another solution is to specify the deleter for shared_ptr, and set the deleter as a friend of the current class. + // But this method is more complicated to implement. + ~OpParserFactory() {} + + bool OpParserIsRegistered(const std::string &op_type, bool is_fusion_op = false); + + protected: + /** + * @ingroup domi_omg + * @brief OpParser creation function + * @return Created OpParser + */ + // typedef shared_ptr (*CREATOR_FUN)(void); + using CREATOR_FUN = std::function(void)>; + + /** + * @ingroup domi_omg + * @brief Factory instances can only be created automatically, not new methods, so the constructor is not public. + */ + OpParserFactory() {} + + /** + * @ingroup domi_omg + * @brief Register creation function + * @param [in] type Op type + * @param [in] fun OpParser creation function + */ + void RegisterCreator(const std::string &type, CREATOR_FUN fun, bool is_fusion_op = false); + + private: + /** + * @ingroup domi_omg + * @brief Each Op corresponds to a Creator function + */ + std::map op_parser_creator_map_; // lint !e1073 + std::map fusion_op_parser_creator_map_; + + friend class OpParserRegisterar; + friend class domi::OpRegistrationData; + friend class OpRegistrationTbe; +}; + +/** + * @ingroup domi_omg + * @brief For registering Creator functions for different types of Op + * + */ +class OpParserRegisterar { + public: + /** + * @ingroup domi_omg + * @brief Constructor + * @param [in] framework Framework type + * @param [in] op_type Op type + * @param [in] fun Creator function corresponding to Op + */ + OpParserRegisterar(const domi::FrameworkType framework, const std::string &op_type, OpParserFactory::CREATOR_FUN fun, + bool is_fusion_op = false) { + OpParserFactory::Instance(framework)->RegisterCreator(op_type, fun, is_fusion_op); + } + ~OpParserRegisterar() {} +}; + +// Used to save the functions created by the xxxCustomParserAdapter class +class CustomParserAdapterRegistry { + public: + static CustomParserAdapterRegistry *Instance(); + using CREATOR_FUN = std::function(void)>; + void Register(const domi::FrameworkType framework, CREATOR_FUN fun); + CREATOR_FUN GetCreateFunc(const domi::FrameworkType framework); + + private: + map funcs_; + + friend class CustomParserAdapterRegistrar; +}; + +// Register Creator function for the custom custom operator ParserAdapter +class CustomParserAdapterRegistrar { + public: + CustomParserAdapterRegistrar(const domi::FrameworkType framework, CustomParserAdapterRegistry::CREATOR_FUN fun) { + CustomParserAdapterRegistry::Instance()->Register(framework, fun); + } + ~CustomParserAdapterRegistrar() {} +}; + +/** + * @ingroup domi_omg + * @brief OpParser Registration Macro + * @param [in] framework Framework type + * @param [in] op_type Op type + * @param [in] clazz OpParser implementation class + */ +#define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \ + std::shared_ptr Creator_##framework##_##op_type##_Op_Parser() { \ + std::shared_ptr ptr = ge::MakeShared(); \ + if (ptr == nullptr) { \ + GELOGW("MakeShared failed, result is nullptr."); \ + } \ + return std::shared_ptr(ptr); \ + } \ + ge::OpParserRegisterar g_##framework##_##op_type##_Op_Parser_Creator(framework, op_type, \ + Creator_##framework##_##op_type##_Op_Parser) + +#define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \ + std::shared_ptr Creator_##framework##_##op_type##_Fusion_Op_Parser() { \ + std::shared_ptr ptr = ge::MakeShared(); \ + if (ptr == nullptr) { \ + GELOGW("MakeShared failed, result is nullptr."); \ + } \ + return std::shared_ptr(ptr); \ + 
} \ + OpParserRegisterar g_##framework##_##op_type##_Fusion_Op_Parser_Creator( \ + framework, op_type, Creator_##framework##_##op_type##_Fusion_Op_Parser, true) + +/// @brief xxxCustomParserAdapter Registration Macro +/// @param [in] framework Framework type +/// @param [in] clazz CaffeCustomParserAdapter adaptation class +#define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \ + std::shared_ptr Creator_##framework##_Op_Parser_Adapter() { \ + std::shared_ptr ptr = ge::MakeShared(); \ + if (ptr == nullptr) { \ + GELOGW("MakeShared failed, result is nullptr."); \ + } \ + return std::shared_ptr(ptr); \ + } \ + CustomParserAdapterRegistrar g_##framework##_Op_Parser_Creator(framework, Creator_##framework##_Op_Parser_Adapter) +} // namespace ge +#endif // PARSER_COMMON_OP_PARSER_FACTORY_H_ diff --git a/parser/common/parser_api.cc b/parser/common/parser_api.cc new file mode 100644 index 0000000..d582ed7 --- /dev/null +++ b/parser/common/parser_api.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "framework/omg/parser/parser_api.h" +#include "common/debug/log.h" + +#include "tbe_plugin_loader.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/register_tbe.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "external/ge/ge_api_types.h" + +namespace ge { +static bool parser_initialized = false; +// Initialize PARSER, load custom op plugin +// options will be used later for parser decoupling +Status ParserInitialize(const std::map &options) { + GELOGT(TRACE_INIT, "ParserInitialize start"); + // check init status + if (parser_initialized) { + GELOGW("ParserInitialize is called more than once"); + return SUCCESS; + } + + // load custom op plugin + TBEPluginLoader::Instance().LoadPluginSo(options); + + std::vector registrationDatas = domi::OpRegistry::Instance()->registrationDatas; + GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); + for (OpRegistrationData ®_data : registrationDatas) { + (void)OpRegistrationTbe::Instance()->Finalize(reg_data, true); + } + + auto iter = options.find(ge::OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES); + if (iter != options.end()) { + ge::GetParserContext().enable_scope_fusion_passes = iter->second; + } + + // set init status + if (!parser_initialized) { + // Initialize success, first time calling initialize + parser_initialized = true; + } + + GELOGT(TRACE_STOP, "ParserInitialize finished"); + return SUCCESS; +} + +Status ParserFinalize() { + GELOGT(TRACE_INIT, "ParserFinalize start"); + // check init status + if (!parser_initialized) { + GELOGW("ParserFinalize is called before ParserInitialize"); + return SUCCESS; + } + + GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize()); + if (parser_initialized) { + parser_initialized = false; + } + return SUCCESS; +} +} // namespace ge diff --git a/parser/common/parser_factory.cc b/parser/common/parser_factory.cc new file mode 100644 index 0000000..ce85d3c --- 
/dev/null +++ b/parser/common/parser_factory.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "omg/parser/parser_factory.h" +#include "common/debug/log.h" +#include "framework/common/debug/ge_log.h" + +namespace domi { +FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() { + static WeightsParserFactory instance; + return &instance; +} + +std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const domi::FrameworkType type) { + std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + return iter->second(); + } + + GELOGE(FAILED, "WeightsParserFactory::CreateWeightsParser: Not supported Type: %d", type); + return nullptr; +} + +FMK_FUNC_HOST_VISIBILITY void WeightsParserFactory::RegisterCreator(const domi::FrameworkType type, + WEIGHTS_PARSER_CREATOR_FUN fun) { + std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + GELOGW("WeightsParserFactory::RegisterCreator: %d creator already exist", type); + return; + } + + creator_map_[type] = fun; +} + +WeightsParserFactory::~WeightsParserFactory() { + creator_map_.clear(); +} + +FMK_FUNC_HOST_VISIBILITY ModelParserFactory *ModelParserFactory::Instance() { + static ModelParserFactory instance; + return &instance; +} + +std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::FrameworkType type) { + std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + return iter->second(); + } + + GELOGE(FAILED, "ModelParserFactory::CreateModelParser: Not supported Type: %d", type); + return nullptr; +} + +FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::FrameworkType type, + MODEL_PARSER_CREATOR_FUN fun) { + std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + GELOGW("ModelParserFactory::RegisterCreator: %d creator already exist", type); + return; + } + + creator_map_[type] = fun; +} + +ModelParserFactory::~ModelParserFactory() { + creator_map_.clear(); +} +} // namespace domi diff --git a/parser/common/parser_fp16_t.cc b/parser/common/parser_fp16_t.cc new file mode 100644 index 0000000..044eb5c --- /dev/null +++ b/parser/common/parser_fp16_t.cc @@ -0,0 +1,1270 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/common/parser_fp16_t.h" + +#include "external/register/register_types.h" + +namespace { +constexpr uint16_t kManBitLength = 11; +} +namespace ge { +namespace parser { +/// @ingroup fp16_t global filed +/// @brief round mode of last valid digital +enum TagFp16RoundMode g_round_mode = kRoundToNearest; + +void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { + // 1.Extract + s = static_cast(FP16_EXTRAC_SIGN(val)); + e = static_cast(FP16_EXTRAC_EXP(val)); + m = static_cast(FP16_EXTRAC_MAN(val)); + // Denormal + if (e == 0) { + e = 1; + } +} + +/// @ingroup fp16_t static method +/// @param [in] man truncated mantissa +/// @param [in] shift_out left shift bits based on ten bits +/// @brief judge whether to add one to the result while converting fp16_t to other datatype +/// @return Return true if add one, otherwise false +static bool IsRoundOne(uint64_t man, uint16_t trunc_len) { + uint64_t mask0 = 0x4; + uint64_t mask1 = 0x2; + uint64_t mask2; + uint16_t shift_out = static_cast(trunc_len - kDim2); + mask0 = mask0 << shift_out; + mask1 = mask1 << shift_out; + mask2 = mask1 - 1; + + bool last_bit = ((man & mask0) > 0); + bool trunc_high = false; + bool trunc_left = false; + if (g_round_mode == kRoundToNearest) { + trunc_high = ((man & mask1) > 0); + trunc_left = ((man & mask2) > 0); + } + return (trunc_high && (trunc_left || last_bit)); +} + +/// @ingroup fp16_t public method +/// @param [in] exp exponent of fp16_t value +/// @param [in] man exponent of fp16_t value +/// @brief normalize fp16_t value +/// @return +static void Fp16Normalize(int16_t &exp, uint16_t &man) { + // set to invalid data + if (exp >= kFp16MaxExp) { + exp = static_cast(kFp16MaxExp); + man = static_cast(kFp16MaxMan); + } else if (exp == 0 && man == kFp16ManHideBit) { + exp++; + man = 0; + } +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to float/fp32 +/// @return Return float/fp32 value of fp_val which is the value of fp16_t object +static float Fp16ToFloat(const uint16_t &fp_val) { + uint16_t hf_sign; + uint16_t hf_man; + int16_t hf_exp; + ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); + + while (hf_man && !(hf_man & kFp16ManHideBit)) { + hf_man <<= 1; + hf_exp--; + } + + uint32_t e_ret, m_ret; + uint32_t s_ret = hf_sign; + if (hf_man == 0) { + e_ret = 0; + m_ret = 0; + } else { + e_ret = hf_exp - kFp16ExpBias + kFp32ExpBias; + m_ret = hf_man & kFp16ManMask; + m_ret = m_ret << (kFp32ManLen - kFp16ManLen); + } + uint32_t f_val = FP32_CONSTRUCTOR(s_ret, e_ret, m_ret); + auto p_ret_v = reinterpret_cast(&f_val); + + return *p_ret_v; +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to double/fp64 +/// @return Return double/fp64 value of fp_val which is the value of fp16_t object +static double Fp16ToDouble(const uint16_t &fp_val) { + uint16_t hf_sign; + uint16_t hf_man; + int16_t hf_exp; + ExtractFp16(fp_val, hf_sign, hf_exp, hf_man); + + while (hf_man && !(hf_man & kFp16ManHideBit)) { + hf_man <<= 1; + hf_exp--; + } + + uint64_t e_ret; + uint64_t m_ret; + uint64_t s_ret = hf_sign; + if (!hf_man) { + e_ret = 0; + m_ret = 0; + } else { + e_ret = hf_exp - kFp16ExpBias + kFp64ExpBias; + m_ret = hf_man & kFp16ManMask; + m_ret = m_ret << (kFp64ManLen - kFp16ManLen); + } + uint64_t f_val = (s_ret << kFp64SignIndex) | (e_ret << kFp64ManLen) | (m_ret); + auto p_ret_v = reinterpret_cast(&f_val); + + return 
*p_ret_v; +} + +/// @ingroup fp16_t static method +/// @param [in] s_ret sign of fp16_t value +/// @param [in] long_int_m man uint64_t value of fp16_t object +/// @param [in] shift_out shift offset +/// @brief calculate uint8 value by sign,man and shift offset +/// @return Return uint8 value of fp16_t object +static uint8_t GetUint8ValByMan(uint8_t s_ret, const uint64_t &long_int_m, const uint16_t &shift_out) { + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); + need_round = need_round && ((s_ret == 0 && m_ret < kInt8Max) || (s_ret == 1 && m_ret <= kInt8Max)); + if (need_round) { + m_ret++; + } + if (s_ret) { + m_ret = (~m_ret) + 1; + } + if (m_ret == 0) { + s_ret = 0; + } + return static_cast((s_ret << kBitShift7) | (m_ret)); +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to int8_t +/// @return Return int8_t value of fp_val which is the value of fp16_t object +static int8_t Fp16ToInt8(const uint16_t &fp_val) { + int8_t ret; + uint8_t ret_v; + // 1.get s_ret and shift it to bit0. + uint8_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_DENORM(fp_val)) { // Denormalized number + ret_v = 0; + ret = *(reinterpret_cast(&ret_v)); + return ret; + } + + uint64_t long_int_m = hf_m; + uint8_t overflow_flag = 0; + uint16_t shift_out = 0; + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + overflow_flag = 1; + } else { + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + if (s_ret == 1 && long_int_m >= 0x20000u) { // sign=1,negative number(<0) + long_int_m = 0x20000u; // 10 0000 0000 0000 0000 10(fp16_t-man)+7(int8)=17bit + overflow_flag = 1; + break; + } else if (s_ret != 1 && long_int_m >= 0x1FFFFu) { // sign=0,positive number(>0) + long_int_m = 0x1FFFFu; // 01 1111 1111 1111 1111 10(fp16_t-man)+7(int8) + overflow_flag = 1; + break; + } + } else { + hf_e++; + shift_out++; + } + } + } + if (overflow_flag) { + ret_v = kInt8Max + s_ret; + } else { + // Generate final result + ret_v = GetUint8ValByMan(s_ret, long_int_m, shift_out); + } + + ret = *(reinterpret_cast(&ret_v)); + return ret; +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to uint8_t +/// @return Return uint8_t value of fp_val which is the value of fp16_t object +static uint8_t Fp16ToUInt8(const uint16_t &fp_val) { + uint8_t m_ret = 0; + // 1.get s_ret and shift it to bit0. 
+ uint8_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_DENORM(fp_val)) { // Denormalized number + return 0; + } + + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + m_ret = ~0; + } else { + uint64_t long_int_m = hf_m; + uint8_t overflow_flag = 0; + uint16_t shift_out = 0; + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + if (long_int_m >= 0x40000Lu) { // overflow 0100 0000 0000 0000 0000 + long_int_m = 0x3FFFFLu; // 11 1111 1111 1111 1111 10(fp16_t-man)+8(uint8)=18bit + overflow_flag = 1; + m_ret = ~0; + break; + } + } else { + hf_e++; + shift_out++; + } + } + if (!overflow_flag) { + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen8Max); + if (need_round && m_ret != kBitLen8Max) { + m_ret++; + } + } + } + + if (s_ret == 1) { // Negative number + m_ret = 0; + } + // m_ret equal to final result + return m_ret; +} + +/// @ingroup fp16_t static method +/// @param [in] s_ret sign of fp16_t value +/// @param [in] long_int_m man uint64_t value of fp16_t object +/// @param [in] shift_out shift offset +/// @brief calculate uint16 value by sign,man and shift offset +/// @return Return uint16 value of fp16_t object +static uint16_t GetUint16ValByMan(uint16_t s_ret, const uint64_t &long_int_m, const uint16_t &shift_out) { + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen16Max); + if (need_round && m_ret < kInt16Max) { + m_ret++; + } + if (s_ret) { + m_ret = (~m_ret) + 1; + } + if (m_ret == 0) { + s_ret = 0; + } + return static_cast((s_ret << kBitShift15) | (m_ret)); +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to int16_t +/// @return Return int16_t value of fp_val which is the value of fp16_t object +static int16_t Fp16ToInt16(const uint16_t &fp_val) { + int16_t ret; + uint16_t ret_v; + // 1.get s_ret and shift it to bit0. 
+ uint16_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_DENORM(fp_val)) { // Denormalized number + ret_v = 0; + ret = *(reinterpret_cast(&ret_v)); + return ret; + } + + uint64_t long_int_m = hf_m; + uint8_t overflow_flag = 0; + uint16_t shift_out = 0; + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + overflow_flag = 1; + } else { + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + if (s_ret == 1 && long_int_m > 0x2000000Lu) { // sign=1,negative number(<0) + long_int_m = 0x2000000Lu; // 10(fp16_t-man)+15(int16)=25bit + overflow_flag = 1; + break; + } else if (s_ret != 1 && long_int_m >= 0x1FFFFFFLu) { // sign=0,positive number(>0) Overflow + long_int_m = 0x1FFFFFFLu; // 10(fp16_t-man)+15(int16)=25bit + overflow_flag = 1; + break; + } + } else { + hf_e++; + shift_out++; + } + } + } + if (overflow_flag) { + ret_v = kInt16Max + s_ret; + } else { + // Generate final result + ret_v = GetUint16ValByMan(s_ret, long_int_m, shift_out); + } + ret = *(reinterpret_cast(&ret_v)); + return ret; +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to uint16_t +/// @return Return uint16_t value of fp_val which is the value of fp16_t object +static uint16_t Fp16ToUInt16(const uint16_t &fp_val) { + uint16_t m_ret = 0; + // 1.get s_ret and shift it to bit0. + uint16_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_DENORM(fp_val)) { // Denormalized number + return 0; + } + + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + m_ret = ~0; + } else { + uint64_t long_int_m = hf_m; + uint16_t shift_out = 0; + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + } else { + hf_e++; + shift_out++; + } + } + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen16Max); + if (need_round && m_ret != kBitLen16Max) { + m_ret++; + } + } + + if (s_ret == 1) { // Negative number + m_ret = 0; + } + // m_ret equal to final result + return m_ret; +} + +/// @ingroup fp16_t math convertion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to int32_t +/// @return Return int32_t value of fp_val which is the value of fp16_t object +static int32_t Fp16ToInt32(const uint16_t &fp_val) { + uint32_t ret_v; + // 1.get s_ret and shift it to bit0. 
+ uint32_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + ret_v = kInt32Max + s_ret; + } else { + uint64_t long_int_m = hf_m; + uint16_t shift_out = 0; + + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + } else { + hf_e++; + shift_out++; + } + } + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + auto m_ret = static_cast((long_int_m >> (kFp16ManLen + shift_out)) & kBitLen32Max); + if (need_round && m_ret < kInt32Max) { + m_ret++; + } + + if (s_ret == 1) { + m_ret = (~m_ret) + 1; + } + if (m_ret == 0) { + s_ret = 0; + } + // Generate final result + ret_v = (s_ret << kBitShift31) | (m_ret); + } + + return *(reinterpret_cast(&ret_v)); +} + +/// @ingroup fp16_t math conversion static method +/// @param [in] fp_val uint16_t value of fp16_t object +/// @brief Convert fp16_t to uint32_t +/// @return Return uint32_t value of fp_val which is the value of fp16_t object +static uint32_t Fp16ToUInt32(const uint16_t &fp_val) { + uint32_t m_ret; + // 1.get s_ret and shift it to bit0. + uint32_t s_ret = FP16_EXTRAC_SIGN(fp_val); + // 2.get hf_e and hf_m + uint16_t hf_e = FP16_EXTRAC_EXP(fp_val); + uint16_t hf_m = FP16_EXTRAC_MAN(fp_val); + + if (FP16_IS_DENORM(fp_val)) { // Denormalized number + return 0u; + } + + if (FP16_IS_INVALID(fp_val)) { // Inf or NaN + m_ret = ~0u; + } else { + uint64_t long_int_m = hf_m; + uint16_t shift_out = 0; + while (hf_e != kFp16ExpBias) { + if (hf_e > kFp16ExpBias) { + hf_e--; + long_int_m = long_int_m << 1; + } else { + hf_e++; + shift_out++; + } + } + bool need_round = IsRoundOne(long_int_m, shift_out + kFp16ManLen); + m_ret = static_cast(long_int_m >> (kFp16ManLen + shift_out)) & kBitLen32Max; + if (need_round && m_ret != kBitLen32Max) { + m_ret++; + } + } + + if (s_ret == 1) { // Negative number + m_ret = 0; + } + // m_ret equal to final result + return m_ret; +} + +static uint16_t Fp16AddCalVal(uint16_t &s_ret, int16_t e_ret, uint16_t m_ret, uint32_t m_trunc, uint16_t shift_out) { + uint16_t m_min = kFp16ManHideBit << shift_out; + uint16_t m_max = m_min << 1; + // Denormal + while (m_ret < m_min && e_ret > 0) { // the value of m_ret should not be smaller than 2^23 + m_ret = m_ret << 1; + m_ret += (kFp32SignMask & m_trunc) >> kFp32SignIndex; + m_trunc = m_trunc << 1; + e_ret = e_ret - 1; + } + while (m_ret >= m_max) { // the value of m_ret should be smaller than 2^24 + m_trunc = m_trunc >> 1; + m_trunc = m_trunc | (kFp32SignMask * (m_ret & 1)); + m_ret = m_ret >> 1; + e_ret = e_ret + 1; + } + + bool b_last_bit = ((m_ret & 1) > 0); + bool b_trunc_high = 0; + bool b_trunc_left = 0; + b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); + m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); + while (m_ret >= m_max) { + m_ret = m_ret >> 1; + e_ret = e_ret + 1; + } + + if (e_ret == 0 && m_ret <= m_max) { + m_ret = m_ret >> 1; + } + Fp16Normalize(e_ret, m_ret); + uint16_t ret = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); + return ret; +} + +/// @ingroup fp16_t math operator +/// @param [in] v_1 left operator value of fp16_t object +/// @param [in] v_2 right operator value of fp16_t object +/// @brief Performing fp16_t addition +/// @return Return fp16_t result of adding this and fp +static uint16_t 
Fp16Add(uint16_t v_1, uint16_t v_2) { + uint16_t s_a; + uint16_t s_b; + int16_t e_a; + int16_t e_b; + uint32_t m_a; + uint32_t m_b; + uint16_t m_a_tmp; + uint16_t m_b_tmp; + uint16_t shift_out = 0; + // 1.Extract + ExtractFp16(v_1, s_a, e_a, m_a_tmp); + ExtractFp16(v_2, s_b, e_b, m_b_tmp); + m_a = m_a_tmp; + m_b = m_b_tmp; + + uint16_t sum; + uint16_t s_ret; + if (s_a != s_b) { + ReverseMan(s_a > 0, m_a); + ReverseMan(s_b > 0, m_b); + sum = static_cast(GetManSum(e_a, m_a, e_b, m_b)); + s_ret = (sum & kFp16SignMask) >> kFp16SignIndex; + ReverseMan(s_ret > 0, m_a); + ReverseMan(s_ret > 0, m_b); + } else { + sum = static_cast(GetManSum(e_a, m_a, e_b, m_b)); + s_ret = s_a; + } + + if (sum == 0) { + shift_out = 3; // shift to left 3 bits + m_a = m_a << shift_out; + m_b = m_b << shift_out; + } + + uint32_t m_trunc = 0; + int16_t e_ret = std::max(e_a, e_b); + int16_t e_tmp = std::abs(e_a - e_b); + if (e_a > e_b) { + m_trunc = (m_b << (kBitShift32 - static_cast(e_tmp))); + m_b = RightShift(m_b, e_tmp); + } else if (e_a < e_b) { + m_trunc = (m_a << (kBitShift32 - static_cast(e_tmp))); + m_a = RightShift(m_a, e_tmp); + } + // calculate mantissav + auto m_ret = static_cast(m_a + m_b); + return Fp16AddCalVal(s_ret, e_ret, m_ret, m_trunc, shift_out); +} + +/// @ingroup fp16_t math operator +/// @param [in] v_1 left operator value of fp16_t object +/// @param [in] v_2 right operator value of fp16_t object +/// @brief Performing fp16_t subtraction +/// @return Return fp16_t result of subtraction fp from this +static uint16_t Fp16Sub(uint16_t v_1, uint16_t v_2) { + // Reverse + uint16_t tmp = ((~(v_2)) & kFp16SignMask) | (v_2 & kFp16AbsMax); + return Fp16Add(v_1, tmp); +} + +/// @ingroup fp16_t math operator +/// @param [in] v_1 left operator value of fp16_t object +/// @param [in] v_2 right operator value of fp16_t object +/// @brief Performing fp16_t multiplication +/// @return Return fp16_t result of multiplying this and fp +static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { + uint16_t s_a, s_b; + int16_t e_a, e_b; + uint32_t m_a, m_b; + uint16_t s_ret, m_ret; + int16_t e_ret; + uint32_t mul_m; + uint16_t m_a_tmp, m_b_tmp; + // 1.Extract + ExtractFp16(v_1, s_a, e_a, m_a_tmp); + ExtractFp16(v_2, s_b, e_b, m_b_tmp); + m_a = m_a_tmp; + m_b = m_b_tmp; + + e_ret = e_a + e_b - kFp16ExpBias - kDim10; + mul_m = m_a * m_b; + s_ret = s_a ^ s_b; + + uint32_t m_min = kFp16ManHideBit; + uint32_t m_max = m_min << 1; + uint32_t m_trunc = 0; + // the value of m_ret should not be smaller than 2^23 + while (mul_m < m_min && e_ret > 1) { + mul_m = mul_m << 1; + e_ret = e_ret - 1; + } + while (mul_m >= m_max || e_ret < 1) { + m_trunc = m_trunc >> 1; + m_trunc = m_trunc | (kFp32SignMask * (mul_m & 1)); + mul_m = mul_m >> 1; + e_ret = e_ret + 1; + } + bool b_last_bit = ((mul_m & 1) > 0); + bool b_trunc_high = 0; + bool b_trunc_left = 0; + b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); + mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); + + while (mul_m >= m_max || e_ret < 0) { + mul_m = mul_m >> 1; + e_ret = e_ret + 1; + } + + if (e_ret == 1 && mul_m < kFp16ManHideBit) { + e_ret = 0; + } + m_ret = static_cast(mul_m); + + Fp16Normalize(e_ret, m_ret); + + uint16_t ret = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); + return ret; +} + +/// @ingroup fp16_t math operator divided +/// @param [in] v_1 left operator value of fp16_t object +/// @param [in] v_2 right operator 
value of fp16_t object +/// @brief Performing fp16_t division +/// @return Return fp16_t result of division this by fp +static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { + uint16_t ret; + if (FP16_IS_ZERO(v_2)) { // result is inf + // throw "fp16_t division by zero."; + uint16_t s_a, s_b; + uint16_t s_ret; + s_a = FP16_EXTRAC_SIGN(v_1); + s_b = FP16_EXTRAC_SIGN(v_2); + s_ret = s_a ^ s_b; + ret = FP16_CONSTRUCTOR(s_ret, kFp16MaxExp, 0u); + } else if (FP16_IS_ZERO(v_1)) { + ret = 0u; + } else { + uint16_t s_a, s_b; + int16_t e_a, e_b; + uint64_t m_a, m_b; + float m_div; + uint16_t m_a_tmp, m_b_tmp; + // 1.Extract + ExtractFp16(v_1, s_a, e_a, m_a_tmp); + ExtractFp16(v_2, s_b, e_b, m_b_tmp); + m_a = m_a_tmp; + m_b = m_b_tmp; + + uint64_t m_tmp; + if (e_a > e_b) { + m_tmp = m_a; + uint16_t tmp; + tmp = e_a - e_b; + for (int i = 0; i < tmp; i++) { + m_tmp = m_tmp << 1; + } + m_a = m_tmp; + } else if (e_a < e_b) { + m_tmp = m_b; + uint16_t tmp = e_b - e_a; + for (int i = 0; i < tmp; i++) { + m_tmp = m_tmp << 1; + } + m_b = m_tmp; + } + m_div = static_cast(m_a * 1.0f / m_b); + fp16_t fp_div; + fp_div = m_div; + ret = fp_div.val; + if (s_a != s_b) { + ret |= kFp16SignMask; + } + } + return ret; +} + +// operate +fp16_t fp16_t::operator+(const fp16_t fp) { + uint16_t ret_val = Fp16Add(val, fp.val); + fp16_t ret(ret_val); + return ret; +} + +fp16_t fp16_t::operator-(const fp16_t fp) { + uint16_t ret_val = Fp16Sub(val, fp.val); + fp16_t ret(ret_val); + return ret; +} + +fp16_t fp16_t::operator*(const fp16_t fp) { + uint16_t ret_val = Fp16Mul(val, fp.val); + fp16_t ret(ret_val); + return ret; +} + +fp16_t fp16_t::operator/(const fp16_t fp) { + uint16_t ret_val = Fp16Div(val, fp.val); + fp16_t ret(ret_val); + return ret; +} + +fp16_t fp16_t::operator+=(const fp16_t fp) { + val = Fp16Add(val, fp.val); + return *this; +} + +fp16_t fp16_t::operator-=(const fp16_t fp) { + val = Fp16Sub(val, fp.val); + return *this; +} + +fp16_t fp16_t::operator*=(const fp16_t fp) { + val = Fp16Mul(val, fp.val); + return *this; +} + +fp16_t fp16_t::operator/=(const fp16_t fp) { + val = Fp16Div(val, fp.val); + return *this; +} + +// compare +bool fp16_t::operator==(const fp16_t &fp) const { + bool result = true; + if (FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)) { + result = true; + } else { + result = ((val & kBitLen16Max) == (fp.val & kBitLen16Max)); // bit compare + } + return result; +} + +bool fp16_t::operator!=(const fp16_t &fp) const { + bool result = true; + if (FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)) { + result = false; + } else { + result = ((val & kBitLen16Max) != (fp.val & kBitLen16Max)); // bit compare + } + return result; +} + +bool fp16_t::operator>(const fp16_t &fp) const { + uint16_t s_a, s_b; + uint16_t e_a, e_b; + uint16_t m_a, m_b; + bool result = true; + + // 1.Extract + s_a = FP16_EXTRAC_SIGN(val); + s_b = FP16_EXTRAC_SIGN(fp.val); + e_a = FP16_EXTRAC_EXP(val); + e_b = FP16_EXTRAC_EXP(fp.val); + m_a = FP16_EXTRAC_MAN(val); + m_b = FP16_EXTRAC_MAN(fp.val); + + // Compare + if ((s_a == 0) && (s_b > 0)) { // + - + // -0=0 + result = !(FP16_IS_ZERO(val) && FP16_IS_ZERO(fp.val)); + } else if ((s_a == 0) && (s_b == 0)) { // + + + if (e_a > e_b) { // e_a - e_b >= 1; Va always larger than Vb + result = true; + } else if (e_a == e_b) { + result = m_a > m_b; + } else { + result = false; + } + } else if ((s_a > 0) && (s_b > 0)) { // - - opposite to + + + if (e_a < e_b) { + result = true; + } else if (e_a == e_b) { + result = m_a < m_b; + } else { + result = false; + } + } else { // - + + result = false; + } + 
+ return result; +} + +bool fp16_t::operator>=(const fp16_t &fp) const { + bool result = true; + if ((*this) > fp) { + result = true; + } else if ((*this) == fp) { + result = true; + } else { + result = false; + } + + return result; +} + +bool fp16_t::operator<(const fp16_t &fp) const { + bool result = true; + if ((*this) >= fp) { + result = false; + } else { + result = true; + } + + return result; +} + +bool fp16_t::operator<=(const fp16_t &fp) const { + bool result = true; + if ((*this) > fp) { + result = false; + } else { + result = true; + } + + return result; +} + +// evaluation +fp16_t &fp16_t::operator=(const fp16_t &fp) { + if (&fp == this) { + return *this; + } + val = fp.val; + return *this; +} + +fp16_t &fp16_t::operator=(const float &f_val) { + uint16_t s_ret, m_ret; + int16_t e_ret; + uint32_t e_f, m_f; + const uint32_t ui32_v = *(reinterpret_cast(&f_val)); // 1:8:23bit sign:exp:man + uint32_t m_len_delta; + + s_ret = static_cast((ui32_v & kFp32SignMask) >> kFp32SignIndex); // 4Byte->2Byte + e_f = (ui32_v & kFp32ExpMask) >> kFp32ManLen; // 8 bit exponent + m_f = (ui32_v & kFp32ManMask); // 23 bit mantissa dont't need to care about denormal + m_len_delta = kFp32ManLen - kFp16ManLen; + + bool need_round = false; + // Exponent overflow/NaN converts to signed inf/NaN + if (e_f > 0x8Fu) { // 0x8Fu:142=127+15 + e_ret = kFp16MaxExp - 1; + m_ret = kFp16MaxMan; + } else if (e_f <= 0x70u) { // 0x70u:112=127-15 Exponent underflow converts to denormalized half or signed zero + e_ret = 0; + if (e_f >= 0x67) { // 0x67:103=127-24 Denormal + m_f = (m_f | kFp32ManHideBit); + uint16_t shift_out = kFp32ManLen; + uint64_t m_tmp = (static_cast(m_f)) << (e_f - 0x67); + + need_round = IsRoundOne(m_tmp, shift_out); + m_ret = static_cast(m_tmp >> shift_out); + if (need_round) { + m_ret++; + } + } else if (e_f == 0x66 && m_f > 0) { // 0x66:102 Denormal 0(e_f - 0x70u); + + need_round = IsRoundOne(m_f, static_cast(m_len_delta)); + m_ret = static_cast(m_f >> m_len_delta); + if (need_round) { + m_ret++; + } + if (m_ret & kFp16ManHideBit) { + e_ret++; + } + } + + Fp16Normalize(e_ret, m_ret); + val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); + return *this; +} + +fp16_t &fp16_t::operator=(const int8_t &i_val) { + uint16_t s_ret, e_ret, m_ret; + + s_ret = static_cast(((static_cast(i_val)) & 0x80) >> kDim7); + m_ret = static_cast(((static_cast(i_val)) & kInt8Max)); + + if (m_ret == 0) { + e_ret = 0; + } else { + if (s_ret) { // negative number(<0) + m_ret = static_cast(std::abs(i_val)); // complement + } + + e_ret = kFp16ManLen; + while ((m_ret & kFp16ManHideBit) == 0) { + m_ret = m_ret << 1; + e_ret = e_ret - 1; + } + e_ret = e_ret + kFp16ExpBias; + } + + val = FP16_CONSTRUCTOR(s_ret, e_ret, m_ret); + return *this; +} + +fp16_t &fp16_t::operator=(const uint8_t &ui_val) { + uint16_t s_ret, e_ret, m_ret; + s_ret = 0; + e_ret = 0; + m_ret = ui_val; + if (m_ret) { + e_ret = kFp16ManLen; + while ((m_ret & kFp16ManHideBit) == 0) { + m_ret = m_ret << 1; + e_ret = e_ret - 1; + } + e_ret = e_ret + kFp16ExpBias; + } + + val = FP16_CONSTRUCTOR(s_ret, e_ret, m_ret); + return *this; +} + +static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, uint16_t &ret_val) { + uint32_t m_tmp = (input_val & kFp32AbsMax); + uint16_t m_min = kFp16ManHideBit; + uint16_t m_max = m_min << 1; + uint16_t len = static_cast(GetManBitLength(m_tmp)); + if (m_tmp) { + int16_t e_ret; + if (len > kDim11) { + e_ret = kFp16ExpBias + kFp16ManLen; + uint16_t e_tmp = len - kDim11; + uint32_t trunc_mask = 1; + for (int i 
= 1; i < e_tmp; i++) { + trunc_mask = (trunc_mask << 1) + 1; + } + uint32_t m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); + for (int i = 0; i < e_tmp; i++) { + m_tmp = (m_tmp >> 1); + e_ret = e_ret + 1; + } + bool b_last_bit = ((m_tmp & 1) > 0); + bool b_trunc_high = 0; + bool b_trunc_left = 0; + if (kRoundToNearest == g_round_mode) { // trunc + b_trunc_high = ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); + } + m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); + while (m_tmp >= m_max || e_ret < 0) { + m_tmp = m_tmp >> 1; + e_ret = e_ret + 1; + } + } else { + e_ret = kFp16ExpBias; + m_tmp = m_tmp << (kManBitLength - len); + e_ret = e_ret + (len - 1); + } + auto m_ret = static_cast(m_tmp); + ret_val = FP16_CONSTRUCTOR(sign, static_cast(e_ret), m_ret); + } +} + +fp16_t &fp16_t::operator=(const int16_t &i_val) { + if (i_val == 0) { + val = 0; + } else { + uint16_t ui_val = *(reinterpret_cast(&i_val)); + auto s_ret = static_cast(ui_val >> kBitShift15); + if (s_ret) { + int16_t iValM = -i_val; + ui_val = *(reinterpret_cast(&iValM)); + } + SetValByUint16Val(ui_val, s_ret, val); + } + return *this; +} + +fp16_t &fp16_t::operator=(const uint16_t &ui_val) { + if (ui_val == 0) { + val = 0; + } else { + int16_t e_ret; + uint16_t m_ret = ui_val; + uint16_t m_min = kFp16ManHideBit; + uint16_t m_max = m_min << 1; + uint16_t len = static_cast(GetManBitLength(m_ret)); + if (len > kManBitLength) { + e_ret = kFp16ExpBias + kFp16ManLen; + uint32_t m_trunc; + uint32_t trunc_mask = 1; + uint16_t e_tmp = len - kManBitLength; + for (int i = 1; i < e_tmp; i++) { + trunc_mask = (trunc_mask << 1) + 1; + } + m_trunc = (m_ret & trunc_mask) << (kBitShift32 - e_tmp); + for (int i = 0; i < e_tmp; i++) { + m_ret = (m_ret >> 1); + e_ret = e_ret + 1; + } + bool b_last_bit = ((m_ret & 1) > 0); + bool b_trunc_high = 0; + bool b_trunc_left = 0; + if (kRoundToNearest == g_round_mode) { // trunc + b_trunc_high = ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); + } + m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret); + while (m_ret >= m_max || e_ret < 0) { + m_ret = m_ret >> 1; + e_ret = e_ret + 1; + } + if (FP16_IS_INVALID(val)) { + val = kFp16Max; + } + } else { + e_ret = kFp16ExpBias; + m_ret = m_ret << (kDim11 - len); + e_ret = e_ret + (len - 1); + } + val = FP16_CONSTRUCTOR(0u, static_cast(e_ret), m_ret); + } + return *this; +} + +static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, uint16_t &ret_val) { + int16_t e_ret; + uint32_t m_tmp = (input_val & kFp32AbsMax); + uint32_t m_min = kFp16ManHideBit; + uint32_t m_max = m_min << 1; + uint16_t len = static_cast(GetManBitLength(m_tmp)); + if (len > kDim11) { + e_ret = kFp16ExpBias + kFp16ManLen; + uint32_t m_trunc = 0; + uint32_t trunc_mask = 1; + uint16_t e_tmp = len - kDim11; + for (int i = 1; i < e_tmp; i++) { + trunc_mask = (trunc_mask << 1) + 1; + } + m_trunc = (m_tmp & trunc_mask) << (kBitShift32 - e_tmp); + for (int i = 0; i < e_tmp; i++) { + m_tmp = (m_tmp >> 1); + e_ret = e_ret + 1; + } + bool b_last_bit = ((m_tmp & 1) > 0); + bool b_trunc_high = 0; + bool b_trunc_left = 0; + if (kRoundToNearest == g_round_mode) { // trunc + b_trunc_high = ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); + } + m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); + while (m_tmp >= m_max || e_ret < 0) { + m_tmp = m_tmp >> 1; + e_ret = e_ret + 1; + } + if (e_ret >= kFp16MaxExp) { + e_ret 
= kFp16MaxExp - 1; + m_tmp = kFp16MaxMan; + } + } else { + e_ret = kFp16ExpBias; + m_tmp = m_tmp << (kDim11 - len); + e_ret = e_ret + (len - 1); + } + auto m_ret = static_cast(m_tmp); + ret_val = FP16_CONSTRUCTOR(sign, static_cast(e_ret), m_ret); +} + +fp16_t &fp16_t::operator=(const int32_t &i_val) { + if (i_val == 0) { + val = 0; + } else { + uint32_t ui_val = *(reinterpret_cast(&i_val)); + auto s_ret = static_cast(ui_val >> kBitShift31); + if (s_ret) { + int32_t iValM = -i_val; + ui_val = *(reinterpret_cast(&iValM)); + } + SetValByUint32Val(ui_val, s_ret, val); + } + return *this; +} + +fp16_t &fp16_t::operator=(const uint32_t &ui_val) { + if (ui_val == 0) { + val = 0; + } else { + int16_t e_ret; + uint32_t m_tmp = ui_val; + uint32_t m_min = kFp16ManHideBit; + uint32_t m_max = m_min << 1; + uint16_t len = static_cast(GetManBitLength(m_tmp)); + if (len > kDim11) { + e_ret = kFp16ExpBias + kFp16ManLen; + uint32_t m_trunc = 0; + uint32_t trunc_mask = 1; + uint16_t e_tmp = len - kDim11; + for (int i = 1; i < e_tmp; i++) { + trunc_mask = (trunc_mask << 1) + 1; + } + m_trunc = (m_tmp & trunc_mask) << static_cast(kBitShift32 - e_tmp); + for (uint16_t i = 0; i < e_tmp; i++) { + m_tmp = (m_tmp >> 1); + e_ret = e_ret + 1; + } + bool b_last_bit = ((m_tmp & 1) > 0); + bool b_trunc_high = false; + bool b_trunc_left = false; + if (g_round_mode == kRoundToNearest) { // trunc + b_trunc_high = ((m_trunc & kFp32SignMask) > 0); + b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); + } + m_tmp = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_tmp); + while (m_tmp >= m_max || e_ret < 0) { + m_tmp = m_tmp >> 1; + e_ret = e_ret + 1; + } + if (e_ret >= kFp16MaxExp) { + e_ret = kFp16MaxExp - 1; + m_tmp = kFp16MaxMan; + } + } else { + e_ret = kFp16ExpBias; + m_tmp = m_tmp << (kDim11 - len); + e_ret = e_ret + (len - 1); + } + auto m_ret = static_cast(m_tmp); + val = FP16_CONSTRUCTOR(0u, static_cast(e_ret), m_ret); + } + return *this; +} + +fp16_t &fp16_t::operator=(const double &d_val) { + uint16_t s_ret; + uint16_t m_ret; + int16_t e_ret; + uint64_t e_d; + uint64_t m_d; + uint64_t ui64_v = *(reinterpret_cast(&d_val)); // 1:11:52bit sign:exp:man + uint32_t m_len_delta; + + s_ret = static_cast((ui64_v & kFp64SignMask) >> kFp64SignIndex); // 4Byte + e_d = (ui64_v & kFp64ExpMask) >> kFp64ManLen; // 10 bit exponent + m_d = (ui64_v & kFp64ManMask); // 52 bit mantissa + m_len_delta = kFp64ManLen - kFp16ManLen; + + bool need_round = false; + // Exponent overflow/NaN converts to signed inf/NaN + if (e_d >= 0x410u) { // 0x410:1040=1023+16 + e_ret = kFp16MaxExp - 1; + m_ret = kFp16MaxMan; + val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); + } else if (e_d <= 0x3F0u) { // Exponent underflow converts to denormalized half or signed zero + // 0x3F0:1008=1023-15 + // Signed zeros, denormalized floats, and floats with small + // exponents all convert to signed zero half precision. 
+ e_ret = 0; + if (e_d >= 0x3E7u) { // 0x3E7u:999=1023-24 Denormal + // Underflows to a denormalized value + m_d = (kFp64ManHideBit | m_d); + uint16_t shift_out = kFp64ManLen; + uint64_t m_tmp = (static_cast(m_d)) << (e_d - 0x3E7u); + + need_round = IsRoundOne(m_tmp, shift_out); + m_ret = static_cast(m_tmp >> shift_out); + if (need_round) { + m_ret++; + } + } else if (e_d == 0x3E6u && m_d > 0) { + m_ret = 1; + } else { + m_ret = 0; + } + } else { // Regular case with no overflow or underflow + e_ret = static_cast(e_d - 0x3F0u); + + need_round = IsRoundOne(m_d, m_len_delta); + m_ret = static_cast(m_d >> m_len_delta); + if (need_round) { + m_ret++; + } + if (m_ret & kFp16ManHideBit) { + e_ret++; + } + } + + Fp16Normalize(e_ret, m_ret); + val = FP16_CONSTRUCTOR(s_ret, static_cast(e_ret), m_ret); + return *this; +} + +// convert +fp16_t::operator float() const { return Fp16ToFloat(val); } + +fp16_t::operator double() const { return Fp16ToDouble(val); } + +fp16_t::operator int8_t() const { return Fp16ToInt8(val); } + +fp16_t::operator uint8_t() const { return Fp16ToUInt8(val); } + +fp16_t::operator int16_t() const { return Fp16ToInt16(val); } + +fp16_t::operator uint16_t() const { return Fp16ToUInt16(val); } + +fp16_t::operator int32_t() const { return Fp16ToInt32(val); } + +fp16_t::operator uint32_t() const { return Fp16ToUInt32(val); } + +// Cannot be used, just in order to solve the compile error +fp16_t::operator int64_t() const { return 0; } + +// Cannot be used, just in order to solve the compile error +fp16_t::operator uint64_t() const { return 0; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() { + if ((val & kFp16AbsMax) == kFp16ExpMask) { + if (val & kFp16SignMask) { + return -1; + } else { + return 1; + } + } else { + return 0; + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY float fp16_t::ToFloat() const { return Fp16ToFloat(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY double fp16_t::ToDouble() const { return Fp16ToDouble(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int8_t fp16_t::ToInt8() const { return Fp16ToInt8(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint8_t fp16_t::ToUInt8() const { return Fp16ToUInt8(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int16_t fp16_t::ToInt16() const { return Fp16ToInt16(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint16_t fp16_t::ToUInt16() const { return Fp16ToUInt16(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int32_t fp16_t::ToInt32() const { return Fp16ToInt32(val); } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint32_t fp16_t::ToUInt32() const { return Fp16ToUInt32(val); } +} // namespace parser +} // namespace ge diff --git a/parser/common/parser_fp16_t.h b/parser/common/parser_fp16_t.h new file mode 100644 index 0000000..6c361e8 --- /dev/null +++ b/parser/common/parser_fp16_t.h @@ -0,0 +1,653 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_COMMON_FP16_T_H_ +#define PARSER_COMMON_FP16_T_H_ + +#include +#include +#include + +namespace ge { +namespace parser { +using DimIndex = enum { + kDim0 = 0, + kDim1, + kDim2, + kDim3, + kDim4, + kDim5, + kDim6, + kDim7, + kDim8, + kDim9, + kDim10, + kDim11, + kDim12, + kDim13, + kDim14, + kDim15, + kDim16, +}; + +using BitShift = enum { + kBitShift2 = 2, + kBitShift3 = 3, + kBitShift4 = 4, + kBitShift5 = 5, + kBitShift6 = 6, + kBitShift7 = 7, + kBitShift8 = 8, + kBitShift9 = 9, + kBitShift10 = 10, + kBitShift11 = 11, + kBitShift12 = 12, + kBitShift13 = 13, + kBitShift14 = 14, + kBitShift15 = 15, + kBitShift16 = 16, + kBitShift20 = 20, + kBitShift24 = 24, + kBitShift27 = 27, + kBitShift28 = 28, + kBitShift31 = 31, + kBitShift32 = 32, + kBitShift36 = 36, + kBitShift40 = 40, + kBitShift44 = 44, + kBitShift48 = 48, + kBitShift52 = 52, + kBitShift56 = 56, + kBitShift59 = 59, + kBitShift60 = 60, + kBitShift63 = 63, + kBitShift64 = 64, + kBitShift128 = 128, + kBitShift255 = 255, + kBitShift256 = 256, + kBitShift512 = 512, + kBitShift768 = 768, + kBitShift784 = 784, + kBitShift1020 = 1020, + kBitShift1024 = 1024, + kBitShift3136 = 3136, + kBitShift4096 = 4096, + kBitShift6144 = 6144, + kBitShift10240 = 10240, + kBitShift65536 = 65536 +}; +/// @ingroup fp16 basic parameter +/// @brief fp16 exponent bias +constexpr uint16_t kFp16ExpBias = 15; +/// @ingroup fp16 basic parameter +/// @brief the exponent bit length of fp16 is 5 +constexpr uint16_t kFp16ExpLen = 5; +/// @ingroup fp16 basic parameter +/// @brief the mantissa bit length of fp16 is 10 +constexpr uint16_t kFp16ManLen = 10; +/// @ingroup fp16 basic parameter +/// @brief bit index of sign in fp16 +constexpr uint16_t kFp16SignIndex = 15; +/// @ingroup fp16 basic parameter +/// @brief sign mask of fp16 (1 00000 00000 00000) +constexpr uint16_t kFp16SignMask = 0x8000; +/// @ingroup fp16 basic parameter +/// @brief exponent mask of fp16 ( 11111 00000 00000) +constexpr uint16_t kFp16ExpMask = 0x7C00; +/// @ingroup fp16 basic parameter +/// @brief mantissa mask of fp16 ( 11111 11111) +constexpr uint16_t kFp16ManMask = 0x03FF; +/// @ingroup fp16 basic parameter +/// @brief hide bit of mantissa of fp16( 1 00000 00000) +constexpr uint16_t kFp16ManHideBit = 0x0400; +/// @ingroup fp16 basic parameter +/// @brief maximum value (0111 1011 1111 1111) +constexpr uint16_t kFp16Max = 0x7BFF; +/// @ingroup fp16 basic parameter +/// @brief minimum value (1111 1011 1111 1111) +constexpr uint16_t kFp16Min = 0xFBFF; +/// @ingroup fp16 basic parameter +/// @brief absolute maximum value (0111 1111 1111 1111) +constexpr uint16_t kFp16AbsMax = 0x7FFF; +/// @ingroup fp16 basic parameter +/// @brief maximum exponent value of fp16 is 15(11111) +constexpr uint16_t kFp16MaxExp = 0x001F; +/// @ingroup fp16 basic parameter +/// @brief maximum valid exponent value of fp16 is 14(11110) +constexpr uint16_t kFp16MaxValidExp = 0x001E; +/// @ingroup fp16 basic parameter +/// @brief maximum mantissa value of fp16(11111 11111) +constexpr uint16_t kFp16MaxMan = 0x03FF; +/// @ingroup fp16 basic parameter +/// @brief absolute minimum normal value of fp16 +/// (E=1,M=0 D=2^(-14)=0.00006103515625) +constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); +/// @ingroup fp16 basic operator +/// @brief get sign of fp16 +#define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1) +/// @ingroup fp16 basic operator +/// @brief get exponent of fp16 +#define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp) +/// @ingroup fp16 basic operator +/// @brief get mantissa of fp16 +#define 
FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) +/// @ingroup fp16 basic operator +/// @brief constructor of fp16 from sign exponent and mantissa +#define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan)) +/// @ingroup fp16 special value judgment +/// @brief whether a fp16 is zero +#define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0) +/// @ingroup fp16 special value judgment +/// @brief whether a fp16 is a denormalized value +#define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0)) +/// @ingroup fp16 special value judgment +/// @brief whether a fp16 is infinite +#define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) +/// @ingroup fp16 special value judgment +/// @brief whether a fp16 is NaN +#define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask)) +/// @ingroup fp16 special value judgment +/// @brief whether a fp16 is invalid +#define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask) +/// @ingroup fp32 basic parameter +/// @brief fp32 exponent bias +constexpr uint16_t kFp32ExpBias = 127; +/// @ingroup fp32 basic parameter +/// @brief the exponent bit length of float/fp32 is 8 +constexpr uint16_t kFp32ExpLen = 8; +/// @ingroup fp32 basic parameter +/// @brief the mantissa bit length of float/fp32 is 23 +constexpr uint16_t kFp32ManLen = 23; +/// @ingroup fp32 basic parameter +/// @brief bit index of sign in float/fp32 +constexpr uint16_t kFp32SignIndex = 31; +/// @ingroup fp32 basic parameter +/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) +constexpr uint32_t kFp32SignMask = 0x80000000u; +/// @ingroup fp32 basic parameter +/// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000) +constexpr uint32_t kFp32ExpMask = 0x7F800000u; +/// @ingroup fp32 basic parameter +/// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111) +constexpr uint32_t kFp32ManMask = 0x007FFFFFu; +/// @ingroup fp32 basic parameter +/// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000) +constexpr uint32_t kFp32ManHideBit = 0x00800000u; +/// @ingroup fp32 basic parameter +/// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111) +constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; +/// @ingroup fp32 basic parameter +/// @brief maximum exponent value of fp32 is 255(1111 1111) +constexpr uint32_t kFp32MaxExp = 0xFF; +/// @ingroup fp32 basic parameter +/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) +constexpr uint32_t kFp32MaxMan = 0x7FFFFF; +/// @ingroup fp32 special value judgment +/// @brief whether a fp32 is NaN +#define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask)) +/// @ingroup fp32 special value judgment +/// @brief whether a fp32 is infinite +#define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask))) +/// @ingroup fp32 special value judgment +/// @brief whether a fp32 is a denormalized value +#define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) +/// @ingroup fp32 basic operator +/// @brief get sign of fp32 +#define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1) +/// @ingroup fp32 basic operator +/// @brief get exponent of fp16 +#define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen) +/// @ingroup fp32 basic operator +/// @brief get mantissa of fp16 +#define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 
1 : 0) * kFp32ManHideBit)) +/// @ingroup fp32 basic operator +/// @brief constructor of fp32 from sign exponent and mantissa +#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan)) +/// @ingroup fp64 basic parameter +/// @brief fp64 exponent bias +constexpr uint16_t kFp64ExpBias = 1023; +/// @ingroup fp64 basic parameter +/// @brief the exponent bit length of double/fp64 is 11 +constexpr uint16_t kFp64ExpLen = 11; +/// @ingroup fp64 basic parameter +/// @brief the mantissa bit length of double/fp64 is 52 +constexpr uint16_t kFp64ManLen = 52; +/// @ingroup fp64 basic parameter +/// @brief bit index of sign in double/fp64 is 63 +constexpr uint16_t kFp64SignIndex = 63; +/// @ingroup fp64 basic parameter +/// @brief sign mask of fp64 (1 000 (total 63bits 0)) +constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; +/// @ingroup fp64 basic parameter +/// @brief exponent mask of fp64 (0 1 11111 11111 0000?-?-(total 52bits 0)) +constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu; +/// @ingroup fp64 basic parameter +/// @brief mantissa mask of fp64 ( 1111?-?-(total 52bits 1)) +constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu; +/// @ingroup fp64 basic parameter +/// @brief hide bit of mantissa of fp64 ( 1 0000?-?-(total 52bits 0)) +constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu; +/// @ingroup fp64 basic parameter +/// @brief absolute maximum value (0 111?-?-(total 63bits 1)) +constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu; +/// @ingroup fp64 basic parameter +/// @brief maximum exponent value of fp64 is 2047(1 11111 11111) +constexpr uint64_t kFp64MaxExp = 0x07FF; +/// @ingroup fp64 basic parameter +/// @brief maximum mantissa value of fp64 (111?-?-(total 52bits 1)) +constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; +/// @ingroup fp64 special value judgment +/// @brief whether a fp64 is NaN +#define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask)) +/// @ingroup fp64 special value judgment +/// @brief whether a fp64 is infinite +#define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask))) +/// @ingroup integer special value judgment +/// @brief maximum positive value of int8_t (0111 1111) +constexpr int8_t kInt8Max = 0x7F; +/// @ingroup integer special value judgment +/// @brief maximum value of a data with 8 bits length (1111 111) +constexpr uint8_t kBitLen8Max = 0xFF; +/// @ingroup integer special value judgment +/// @brief maximum positive value of int16_t (0111 1111 1111 1111) +constexpr int16_t kInt16Max = 0x7FFF; +/// @ingroup integer special value judgment +/// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111) +constexpr uint16_t kBitLen16Max = 0xFFFF; +/// @ingroup integer special value judgment +/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) +constexpr int32_t kInt32Max = 0x7FFFFFFFu; +/// @ingroup integer special value judgment +/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) +constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; +/// @ingroup integer special value judgment +/// @brief maximum positive value of int64_t +/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) +constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; +/// @ingroup integer special value judgment +/// @brief maximum value of a data with 64 bits length +/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) +constexpr 
uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu; + +/// @ingroup fp16_t enum +/// @brief round mode of last valid digital +enum TagFp16RoundMode { + kRoundToNearest = 0, // < round to nearest even + kRoundByTruncated, // < round by truncated + kRoundModeReserved, +}; + +/// @ingroup fp16_t +/// @brief Half precision float +/// bit15: 1 bit SIGN +---+-----+------------+ +/// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM| +/// bit0-9: 10bit MAN +---+-----+------------+ +using fp16_t = struct TagFp16 { + uint16_t val; + +public: + /// @ingroup fp16_t constructor + /// @brief Constructor without any param(default constructor) + TagFp16(void) { val = 0x0u; } + + /// @ingroup fp16_t constructor + /// @brief Constructor with an uint16_t value + TagFp16(const uint16_t &ui_val) : val(ui_val) {} + + /// @ingroup fp16_t constructor + /// @brief Constructor with a fp16_t object(copy constructor) + TagFp16(const TagFp16 &fp) : val(fp.val) {} + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be added + /// @brief Override addition operator to performing fp16_t addition + /// @return Return fp16_t result of adding this and fp + TagFp16 operator+(const TagFp16 fp); + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be subtracted + /// @brief Override addition operator to performing fp16_t subtraction + /// @return Return fp16_t result of subtraction fp from this + TagFp16 operator-(const TagFp16 fp); + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be multiplied + /// @brief Override multiplication operator to performing fp16_t multiplication + /// @return Return fp16_t result of multiplying this and fp + TagFp16 operator*(const TagFp16 fp); + + /// @ingroup fp16_t math operator divided + /// @param [in] fp fp16_t object to be divided + /// @brief Override division operator to performing fp16_t division + /// @return Return fp16_t result of division this by fp + TagFp16 operator/(const TagFp16 fp); + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be added + /// @brief Override addition operator to performing fp16_t addition + /// @return Return fp16_t result of adding this and fp + TagFp16 operator+=(const TagFp16 fp); + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be subtracted + /// @brief Override addition operator to performing fp16_t subtraction + /// @return Return fp16_t result of subtraction fp from this + TagFp16 operator-=(const TagFp16 fp); + + /// @ingroup fp16_t math operator + /// @param [in] fp fp16_t object to be multiplied + /// @brief Override multiplication operator to performing fp16_t multiplication + /// @return Return fp16_t result of multiplying this and fp + TagFp16 operator*=(const TagFp16 fp); + + /// @ingroup fp16_t math operator divided + /// @param [in] fp fp16_t object to be divided + /// @brief Override division operator to performing fp16_t division + /// @return Return fp16_t result of division this by fp + TagFp16 operator/=(const TagFp16 fp); + + /// @ingroup fp16_t math compare operator + /// @param [in] fp fp16_t object to be compared + /// @brief Override basic comparison operator to performing fp16_t if-equal comparison + /// @return Return boolean result of if-equal comparison of this and fp. 
+  bool operator==(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math compare operator
+  /// @param [in] fp fp16_t object to be compared
+  /// @brief Override basic comparison operator to perform fp16_t not-equal comparison
+  /// @return Return boolean result of not-equal comparison of this and fp.
+  bool operator!=(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math compare operator
+  /// @param [in] fp fp16_t object to be compared
+  /// @brief Override basic comparison operator to perform fp16_t greater-than comparison
+  /// @return Return boolean result of greater-than comparison of this and fp.
+  bool operator>(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math compare operator
+  /// @param [in] fp fp16_t object to be compared
+  /// @brief Override basic comparison operator to perform fp16_t greater-equal comparison
+  /// @return Return boolean result of greater-equal comparison of this and fp.
+  bool operator>=(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math compare operator
+  /// @param [in] fp fp16_t object to be compared
+  /// @brief Override basic comparison operator to perform fp16_t less-than comparison
+  /// @return Return boolean result of less-than comparison of this and fp.
+  bool operator<(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math compare operator
+  /// @param [in] fp fp16_t object to be compared
+  /// @brief Override basic comparison operator to perform fp16_t less-equal comparison
+  /// @return Return boolean result of less-equal comparison of this and fp.
+  bool operator<=(const TagFp16 &fp) const;
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] fp fp16_t object to be copied to this fp16_t
+  /// @brief Override basic evaluation operator to copy fp16_t to a new fp16_t
+  /// @return Return fp16_t result from fp
+  TagFp16 &operator=(const TagFp16 &fp);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] f_val float object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert float to fp16_t
+  /// @return Return fp16_t result from f_val
+  TagFp16 &operator=(const float &f_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] d_val double object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert double to fp16_t
+  /// @return Return fp16_t result from d_val
+  TagFp16 &operator=(const double &d_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] i_val int8_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert int8_t to fp16_t
+  /// @return Return fp16_t result from i_val
+  TagFp16 &operator=(const int8_t &i_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] ui_val uint8_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert uint8_t to fp16_t
+  /// @return Return fp16_t result from ui_val
+  TagFp16 &operator=(const uint8_t &ui_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] i_val int16_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert int16_t to fp16_t
+  /// @return Return fp16_t result from i_val
+  TagFp16 &operator=(const int16_t &i_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] ui_val uint16_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert uint16_t to fp16_t
+  /// @return Return fp16_t result from ui_val
+  TagFp16 &operator=(const uint16_t &ui_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] i_val int32_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert int32_t to fp16_t
+  /// @return Return fp16_t result from i_val
+  TagFp16 &operator=(const int32_t &i_val);
+
+  /// @ingroup fp16_t math evaluation operator
+  /// @param [in] ui_val uint32_t object to be converted to fp16_t
+  /// @brief Override basic evaluation operator to convert uint32_t to fp16_t
+  /// @return Return fp16_t result from ui_val
+  TagFp16 &operator=(const uint32_t &ui_val);
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to float/fp32
+  /// @return Return float/fp32 value of fp16_t
+  operator float() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to double/fp64
+  /// @return Return double/fp64 value of fp16_t
+  operator double() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to int8_t
+  /// @return Return int8_t value of fp16_t
+  operator int8_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to uint8_t
+  /// @return Return uint8_t value of fp16_t
+  operator uint8_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to int16_t
+  /// @return Return int16_t value of fp16_t
+  operator int16_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to uint16_t
+  /// @return Return uint16_t value of fp16_t
+  operator uint16_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to int32_t
+  /// @return Return int32_t value of fp16_t
+  operator int32_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to uint32_t
+  /// @return Return uint32_t value of fp16_t
+  operator uint32_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to int64_t
+  /// @return Return int64_t value of fp16_t
+  operator int64_t() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Override convert operator to convert fp16_t to uint64_t
+  /// @return Return uint64_t value of fp16_t
+  operator uint64_t() const;
+
+  /// @ingroup fp16_t judgment method
+  /// @param [in] fp fp16_t object to be judged
+  /// @brief Judge whether a fp16_t is infinite
+  /// @return Returns 1:+INF -1:-INF 0:not INF
+  int IsInf();
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to float/fp32
+  /// @return Return float/fp32 value of fp16_t
+  float ToFloat() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to double/fp64
+  /// @return Return double/fp64 value of fp16_t
+  double ToDouble() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to int8_t
+  /// @return Return int8_t value of fp16_t
+  int8_t ToInt8() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to uint8_t
+  /// @return Return uint8_t value of fp16_t
+  uint8_t ToUInt8() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to int16_t
+  /// @return Return int16_t value of fp16_t
+  int16_t ToInt16() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to uint16_t
+  /// @return Return uint16_t value of fp16_t
+  uint16_t ToUInt16() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to int32_t
+  /// @return Return int32_t value of fp16_t
+  int32_t ToInt32() const;
+
+  /// @ingroup fp16_t math conversion
+  /// @brief Convert fp16_t to uint32_t
+  /// @return Return uint32_t value of fp16_t
+  uint32_t ToUInt32() const;
+};
+
+/// @ingroup fp16_t public method
+/// @param [in] val bit representation of a fp16_t value
+/// @param [in|out] s sign of fp16_t object
+/// @param [in|out] e exponent of fp16_t object
+/// @param [in|out] m mantissa of fp16_t object
+/// @brief Extract the sign, exponent and mantissa of a fp16_t object
+void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m);
+
+/// @ingroup fp16_t public method
+/// @param [in] negative sign is negative
+/// @param [in|out] man mantissa to be reversed
+/// @brief Calculate a mantissa's complement (add one to its radix-minus-one complement)
+/// @return Return complement of man
+template <typename T>
+void ReverseMan(bool negative, T &man) {
+  if (negative) {
+    man = (~(man)) + 1;
+  }
+}
+
+/// @ingroup fp16_t public method
+/// @param [in] e_a exponent of one fp16_t/float number
+/// @param [in] m_a mantissa of one fp16_t/float number
+/// @param [in] e_b exponent of another fp16_t/float number
+/// @param [in] m_b mantissa of another fp16_t/float number
+/// @brief Choose the mantissa to be shifted right, i.e. the one whose exponent is less than the other's
+/// @return Return the mantissa whose exponent is less than the other's
+template <typename T>
+T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) {
+  return (e_a > e_b) ? m_b : m_a;
+}
+
+/// @ingroup fp16_t public method
+/// @param [in] man mantissa to be operated on
+/// @param [in] shift right shift bits
+/// @brief Right shift a mantissa
+/// @return Return right-shifted mantissa
+template <typename T>
+T RightShift(T man, int16_t shift) {
+  int bits = sizeof(T) * 8;  // one byte has 8 bits
+  T mask = (((T) 1u) << ((unsigned int) (bits - 1)));
+  for (int i = 0; i < shift; i++) {
+    man = ((man & mask) | (man >> 1));
+  }
+  return man;
+}
+
+/// @ingroup fp16_t public method
+/// @param [in] e_a exponent of one temp fp16_t number
+/// @param [in] m_a mantissa of one temp fp16_t number
+/// @param [in] e_b exponent of another temp fp16_t number
+/// @param [in] m_b mantissa of another temp fp16_t number
+/// @brief Get mantissa sum of two temp fp16_t numbers, T support types: uint16_t/uint32_t/uint64_t
+/// @return Return mantissa sum
+template <typename T>
+T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) {
+  T sum = 0;
+  if (e_a != e_b) {
+    T m_tmp = 0;
+    int16_t e_tmp = std::abs(e_a - e_b);
+    if (e_a > e_b) {
+      m_tmp = m_b;
+      m_tmp = RightShift(m_tmp, e_tmp);
+      sum = m_a + m_tmp;
+    } else {
+      m_tmp = m_a;
+      m_tmp = RightShift(m_tmp, e_tmp);
+      sum = m_tmp + m_b;
+    }
+  } else {
+    sum = m_a + m_b;
+  }
+  return sum;
+}
+
+/// @ingroup fp16_t public method
+/// @param [in] bit0 whether the last preserved bit is 1 before round
+/// @param [in] bit1 whether the highest discarded bit is 1
+/// @param [in] bitLeft whether any discarded bit below the highest one is 1
+/// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t
+/// @param [in] shift number of bits to discard
+/// @brief Round fp16_t or float mantissa to nearest value
+/// @return Return the mantissa shifted right by shift and rounded to nearest
+template <typename T>
+T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) {
+  man = (man >> shift) + ((bit1 && (bitLeft || bit0)) ? 1 : 0);
+  return man;
+}
+
+/// @ingroup fp16_t public method
+/// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t
+/// @brief Get bit length of an unsigned integer
+/// @return Return bit length of man
+template <typename T>
+int16_t GetManBitLength(T man) {
+  int16_t len = 0;
+  while (man) {
+    man >>= 1;
+    len++;
+  }
+  return len;
+}
+} // namespace parser
+} // namespace ge
+#endif // GE_PARSER_COMMON_FP16_T_H_
diff --git a/parser/common/parser_inner_ctx.cc b/parser/common/parser_inner_ctx.cc
new file mode 100644
index 0000000..1e5bc84
--- /dev/null
+++ b/parser/common/parser_inner_ctx.cc
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "framework/omg/parser/parser_inner_ctx.h"
+
+namespace ge {
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserContext &GetParserContext() {
+  static ParserContext context;
+  return context;
+}
+} // namespace ge
diff --git a/parser/common/parser_types.cc b/parser/common/parser_types.cc
new file mode 100644
index 0000000..440e884
--- /dev/null
+++ b/parser/common/parser_types.cc
@@ -0,0 +1,494 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "framework/omg/parser/parser_types.h" + + +namespace ge{ +namespace parser { +const char *DATA = "Data"; +const char *AIPPDATA = "AippData"; +const char *CONVOLUTION = "Convolution"; +const char *CORRELATION = "Correlation"; +const char *CORRELATIONV2 = "Correlation_V2"; +const char *DECONVOLUTION = "Deconvolution"; +const char *POOLING = "Pooling"; +const char *ELTWISE = "Eltwise"; +const char *RELU = "ReLU"; +const char *RELU6 = "ReLU6"; +const char *SIGMOID = "Sigmoid"; +const char *ABSVAL = "AbsVal"; +const char *TANH = "TanH"; +const char *PRELU = "PReLU"; +const char *BATCHNORM = "BatchNorm"; +const char *FUSIONBATCHNORM = "FusionBatchNorm"; +const char *SCALE = "Scale"; +const char *FULL_CONNECTION = "FullConnection"; +const char *SOFTMAX = "Softmax"; +const char *PLUS = "Plus"; +const char *ACTIVATION = "Activation"; +const char *FLATTEN = "Flatten"; +const char *ADD = "Add"; +const char *SUB = "Sub"; +const char *MUL = "Mul"; +const char *MATMUL = "MatMul"; +const char *RSQRT = "Rsqrt"; +const char *BIASADD = "BiasAdd"; +const char *RESHAPE = "Reshape"; +const char *REFORMAT = "ReFormat"; +const char *DEPCONVOLUTION = "ConvolutionDepthwise"; +const char *DROPOUT = "Dropout"; +const char *DROPOUTGENMASK = "DropOutGenMask"; +const char *DROPOUTDOMASK = "DropOutDoMask"; +const char *CONCAT = "Concat"; +const char *ROIPOOLING = "ROIPooling"; +const char *PROPOSAL = "Proposal"; +const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput"; +const char *DETECTIONPOSTPROCESS = "Detectpostprocess"; +const char *LRN = "LRN"; +const char *TRANSDATA = "TransData"; +const char *PERMUTE = "Permute"; +const char *SSDNORMALIZE = "SSDNormalize"; +const char *SSDPRIORBOX = "SSDPriorBox"; +const char *NETOUTPUT = "NetOutput"; +const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput"; +const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput"; +const char *CHANNELAXPY = "ChannelAxpy"; +const char *PSROIPOOLING = "PSROIPooling"; +const char *POWER = "Power"; +const char *POW = "Pow"; +const char *ROIALIGN = "ROIAlign"; +const char *PYTHON = "Python"; +const char *FREESPACEEXTRACT = "FreespaceExtract"; +const char *SPATIALTF = "SpatialTransform"; +const char *SHAPE = "Shape"; +const char *SHAPEN = "ShapeN"; +const char *ARGMAX = "ArgMax"; +const char *GATHERND = "GatherNd"; +const char *GATHER = "Gather"; +const char *REALDIV = "RealDiv"; +const char *PACK = "Pack"; +const char *SLICE = "Slice"; +const char *SLICED = "SliceD"; +const char *FLOORDIV = "FloorDiv"; +const char *SQUEEZE = "Squeeze"; +const char *UNSQUEEZE = "Unsqueeze"; +const char *STRIDEDSLICE = "StridedSlice"; +const char *RANGE = "Range"; +const char *RPNPROPOSALS = "RpnProposals"; +const char *DECODEBBOX = "DecodeBbox"; +const char *PAD = "Pad"; +const char *PADV2 = "PadV2"; +const char *MIRRORPAD = "MirrorPad"; +const char *TILE = "Tile"; +const char *SIZE = "Size"; +const char *CLIPBOXES = "ClipBoxes"; +const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions"; +const char *SPLIT = "Split"; +const char *SPLITV = "SplitV"; +const char *EXPANDDIMS = "ExpandDims"; +const char *EMPTY = "Empty"; +const char *MEAN = "Mean"; +const char *GREATER = "Greater"; +const char *SWITCH = "Switch"; +const char *SWITCHN = "SwitchN"; +const char *MERGE = "Merge"; +const char *SYMBOLICGRADIENT = "SymbolicGradient"; +const char *REMOTECALL = "RemoteCall"; +const char *_IF = "_If"; +const char *STATELESSIF = "StatelessIf"; +const char *IF = "If"; +const char *CASE = "Case"; +const char *_WHILE = "_While"; +const char *WHILE = "While"; +const 
char *STATELESSWHILE = "StatelessWhile"; +const char *FOR = "For"; +const char *PARTITIONEDCALL = "PartitionedCall"; +const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall"; +const char *FAKEPARAM = "FakeParam"; +const char *TRANSPOSE = "Transpose"; +const char *TRANSPOSED = "TransposeD"; +const char *CAST = "Cast"; +const char *REGION = "Region"; +const char *YOLO = "Yolo"; +const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput"; +const char *FILL = "Fill"; +const char *REVERSE = "Reverse"; +const char *UNPACK = "Unpack"; +const char *YOLO2REORG = "Yolo2Reorg"; +const char *REDUCESUM = "ReduceSum"; +const char *SUM = "Sum"; +const char *CONSTANT = "Const"; +const char *RESIZEBILINEAR = "ResizeBilinear"; +const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; +const char *MAXIMUM = "Maximum"; +const char *FRAMEWORKOP = "FrameworkOp"; +const char *ARG = "_Arg"; +const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad"; +const char *LSTM = "LSTM"; +const char *HIGHWAY = "HighWay"; +const char *RNN = "RNN"; +const char *ATTENTIONDECODER = "AttentionDecoder"; +const char *LOGICAL_NOT = "LogicalNot"; +const char *LOGICAL_AND = "LogicalAnd"; +const char *LOGICAL_OR = "LogicalOr"; +const char *EQUAL = "Equal"; +const char *NOTEQUAL = "NotEqual"; +const char *INTERP = "Interp"; +const char *SHUFFLECHANNEL = "ShuffleChannel"; +const char *AIPP = "Aipp"; +const char *MULTISHAPE = "MultiShape"; +const char *RECIPROCAL = "Reciprocal"; +const char *SELU = "Selu"; +const char *ELU = "Elu"; +const char *ACOSH = "Acosh"; +const char *ASINH = "Asinh"; +const char *MINIMUM = "Minimum"; +const char *CLIP = "Clip"; +const char *L2NORMALIZE = "L2Normalize"; +const char *CROPANDRESIZE = "CropAndResize"; +const char *UNUSEDCONST = "UnusedConst"; +const char *SPARSETODENSE = "SparseToDense"; +const char *NONMAXSUPPRESSION = "NonMaxSuppression"; +const char *TOPKV2 = "TopKV2"; +const char *INVERTPERMUTATION = "InvertPermutation"; +const char *MULTINOMIAL = "Multinomial"; +const char *REVERSESEQUENCE = "ReverseSequence"; +const char *REDUCEPROD = "ReduceProd"; +const char *REDUCEMAX = "ReduceMax"; +const char *REDUCEMIN = "ReduceMin"; +const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches"; +const char *SQRT = "Sqrt"; +const char *REDUCEALL = "ReduceAll"; +const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor"; +const char *SPACETOBATCHND = "SpaceToBatchND"; +const char *BATCHTOSPACEND = "BatchToSpaceND"; +const char *ASSERT = "Assert"; +const char *GREATEREQUAL = "GreaterEqual"; +const char *FLOOR = "Floor"; +const char *RANDOMUNIFORM = "RandomUniform"; +const char *BATCHMATMUL = "BatchMatMul"; +const char *SPACETODEPTH = "SpaceToDepth"; +const char *DEPTHTOSPACE = "DepthToSpace"; +const char *RINT = "Rint"; +const char *ATAN = "Atan"; +const char *ATAN2 = "Atan2"; +const char *ATANH = "Atanh"; +const char *ACOS = "Acos"; +const char *ASIN = "Asin"; +const char *NEG = "Neg"; +const char *LOG = "Log"; +const char *TAN = "Tan"; +const char *ROUND = "Round"; +const char *UPSAMPLE = "Upsample"; +const char *FLOORMOD = "FloorMod"; +const char *LESS = "Less"; +const char *LESSEQUAL = "LessEqual"; +const char *ONEHOT = "OneHot"; +const char *REFSWITCH = "RefSwitch"; +const char *REFMERGE = "RefMerge"; +const char *ENTER = "Enter"; +const char *REFENTER = "RefEnter"; +const char *LOOPCOND = "LoopCond"; +const char *NEXTITERATION = "NextIteration"; +const char *REFNEXTITERATION = "RefNextIteration"; +const char *EXIT = "Exit"; +const char *REFEXIT = "RefExit"; +const char *CONTROLTRIGGER = "ControlTrigger"; 
+const char *ZEROSLIKE = "ZerosLike"; +const char *EXP = "Exp"; +const char *WHERE = "Where"; +const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars"; +const char *SOFTPLUS = "Softplus"; +const char *SOFTSIGN = "Softsign"; +const char *COSH = "Cosh"; +const char *SINH = "Sinh"; +const char *SQUAREDDIFFERENCE = "SquaredDifference"; +const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion +const char *SSDPOSTPROCESSOR = "SSDPostProcessor"; +const char *RETINANETBOXES = "RetinanetBoxes"; +const char *RETINAMULTIANCHORS = "RetinaMultiAnchor"; +const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes"; +const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections"; +const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor"; +const char *RETINANETANCHORS = "RetinanetAnchors"; +const char *FASTERRCNNMAP = "FasterRCNNMap"; +const char *FASTERRCNNMAP1 = "FasterRCNNMap1"; +const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor"; +const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling"; +const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor"; +const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator"; +const char *ROIINTERPOOLING = "ROIInterPooling"; +const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow"; +const char *EMBEDLOOKUP = "EmbedLookup"; +const char *HASHLOOKUP = "HashLookup"; +const char *LSH_PROJ = "LshProject"; +const char *SVDF = "SVDF"; +const char *SSDANCHORGENERATOR = "SSDAnchorGenerator"; +const char *IDENTITY = "Identity"; +const char *IDENTITYN = "IdentityN"; +const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault"; +const char *SELECT = "Select"; +const char *GETSPAN = "GetSpan"; +const char *STOPGRADIENT = "StopGradient"; +const char *PREVENTGRADIENT = "PreventGradient"; +const char *GUARANTEECONST = "GuaranteeConst"; +const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs"; +const char *BROADCASTARGS = "BroadcastArgs"; +const char *CONFUSIONMATRIX = "ConfusionMatrix"; +const char *RANK = "Rank"; +const char *PLACEHOLDER = "PlaceHolder"; +const char *END = "End"; +const char *BASICLSTMCELL = "BasicLSTMCell"; +const char *GETNEXT = "GetNext"; +const char *INITDATA = "InitData"; +const char *REFIDENTITY = "RefIdentity"; +const char *BITCAST = "Bitcast"; + +/***************Ann special operator*************************/ +const char *ANN_MEAN = "AnnMean"; +const char *ANN_CONVOLUTION = "AnnConvolution"; +const char *ANN_DEPCONVOLUTION = "AnnDepthConv"; +const char *ANN_FULLCONNECTION = "AnnFullConnection"; +const char *ANN_NETOUTPUT = "AnnNetOutput"; +const char *ANN_DATA = "AnnData"; +const char *ANN_RESHAPE = "AnnReshape"; +const char *ANN_ADD = "AnnAdd"; +const char *ANN_MUL = "AnnMul"; +const char *ANN_SUB = "AnnSub"; +const char *ANN_DIV = "AnnDiv"; +const char *ANN_DEQUANTIZE = "AnnDequant"; +const char *ANN_QUANTIZE = "AnnQuant"; +const char *ANN_PAD = "AnnPad"; +const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear"; + +/***************************************************/ +/******************Training operator*************************/ +const char *GATHERV2 = "GatherV2"; +const char *CONVGRADFILTER = "Conv2DBackpropFilter"; +const char *CONV2D = "Conv2D"; +const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput"; +const char *FUSEDBATCHNORM = "FusedBatchNorm"; +const char *BIASADDGRAD = "BiasAddGrad"; +const char *ACTIVATIONGRAD = "ReluGrad"; +const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax"; 
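(Editorial aside, not part of the patch: parser code typically consumes the type-name constants defined in parser_types.cc by comparing them against an op's type string. The sketch below only mirrors that pattern in standalone form; the NeedsWeightExtraction helper is hypothetical, and the two constants are re-declared locally so the snippet compiles on its own rather than pulling in the real GE headers.)

#include <iostream>
#include <string>

// Stand-ins for two of the constants defined in parser_types.cc above,
// re-declared locally only so this sketch is self-contained.
namespace ge {
namespace parser {
const char *CONSTANT = "Const";
const char *FRAMEWORKOP = "FrameworkOp";
}  // namespace parser
}  // namespace ge

// Hypothetical helper: only Const nodes carry weights the parser would extract.
static bool NeedsWeightExtraction(const std::string &op_type) {
  return op_type == ge::parser::CONSTANT;
}

int main() {
  std::cout << std::boolalpha
            << NeedsWeightExtraction(ge::parser::CONSTANT) << "\n"      // true
            << NeedsWeightExtraction(ge::parser::FRAMEWORKOP) << "\n";  // false
  return 0;
}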
+const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax"; +const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits"; +const char *SNAPSHOT = "Snapshot"; +const char *VAR = "Var"; +const char *MEANGRAD = "MeanGrad"; +const char *TRANSLATE = "Translate"; +const char *ADDN = "AddN"; +const char *L2LOSS = "L2Loss"; +const char *MULTIPLY = "Multiply"; +const char *HUBERLOSSGRAD = "HuberLossGrad"; +const char *HUBERLOSS = "HuberLoss"; +const char *NEGATIVE = "Negative"; +const char *SSDCAST = "SSDCast"; +const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy"; +const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad"; +const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion"; +const char *CONCATFOUR2FIVE = "ConcatFour2Five"; +const char *CONCATFIVE2FOUR = "ConcatFive2Four"; +const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul"; +const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean"; + +const char *VARIABLEV2 = "VariableV2"; +const char *VARHANDLEOP = "VarHandleOp"; +const char *TEMPORARYVARIABLE = "TemporaryVariable"; +const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable"; +const char *VARIABLE = "Variable"; +const char *ASSIGN = "Assign"; +const char *ASSIGNVARIABLEOP = "AssignVariableOp"; +const char *ASSIGNADD = "AssignAdd"; +const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp"; +const char *ASSIGNSUB = "AssignSub"; +const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp"; +const char *APPLYMOMENTUM = "ApplyMomentum"; +const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum"; +const char *SGD = "SGD"; +const char *NOOP = "NoOp"; +const char *READVARIABLEOP = "ReadVariableOp"; +const char *PARALLELCONCATSTART = "_ParallelConcatStart"; +const char *CONSTANTOP = "Constant"; +const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter"; +const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput"; +const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative"; +const char *DROPOUTGRAD = "DropOutGrad"; +const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision"; +const char *APPLYRMSPROP = "ApplyRMSProp"; +const char *RELU6GRAD = "Relu6Grad"; +const char *AVGPOOLGRAD = "AvgPoolGrad"; +const char *CONCATV2 = "ConcatV2"; +const char *CONCATOFFSET = "ConcatOffset"; +const char *LAYERNORMGRAD = "LayerNormGrad"; +const char *LAYERNORM = "LayerNorm"; +const char *LARS = "Lars"; +const char *DYNAMICSTITCH = "DynamicStitch"; + +/***************************************************/ +const char *SQUARE = "Square"; +const char *HCOMBROADCAST = "HcomBroadcast"; +const char *HCOMALLGATHER = "HcomAllGather"; +const char *HCOMALLREDUCE = "HcomAllReduce"; +const char *HCOMREDUCESCATTER = "HcomReduceScatter"; +const char *HCOMSEND = "HcomSend"; +const char *HCOMRECEIVE = "HcomReceive"; +const char *HCOMREMOTEREAD = "HcomRemoteRead"; +const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; + +const char *VARASSIGN = "VarAssign"; +const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; +const char *LogTimeStamp = "LogTimeStamp"; +const char *ISVARIABLEINITIALIZED = "IsVariableInitialized"; +const char *STREAMSWITCH = "StreamSwitch"; +const char *STREAMSWITCHN = "StreamSwitchN"; +const char *STREAMACTIVE = "StreamActive"; +const char *MEMCPYASYNC = "MemcpyAsync"; +const char *MEMCPYADDRASYNC = "MemcpyAddrAsync"; +const char *STREAMMERGE = "StreamMerge"; +const char *ENDGRAPH = "EndGraph"; +const char *SEND = "Send"; +const char *RECV = "Recv"; +const char *ENDOFSEQUENCE = 
"EndOfSequence"; + +const char *LABELSET = "LabelSet"; +const char *LABELGOTO = "LabelGoto"; +const char *LABELGOTOEX = "LabelGotoEx"; +const char *LABELSWITCH = "LabelSwitch"; +const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex"; + +const char *ATOMICADDRCLEAN = "AtomicAddrClean"; + +const char *ABS_GRAD = "AbsGrad"; +const char *ACCUMULATE_N_V2 = "AccumulateNV2"; +const char *ACOS_GRAD = "AcosGrad"; +const char *ACOSH_GRAD = "AcoshGrad"; +const char *ANY = "Any"; +const char *APPROXIMATE_EQUAL = "ApproximateEqual"; +const char *ASIN_GRAD = "AsinGrad"; +const char *ASINH_GRAD = "AsinhGrad"; +const char *ATAN_GRAD = "AtanGrad"; +const char *BROADCAST_TO = "BroadcastTo"; +const char *ELU_GRAD = "EluGrad"; +const char *ADD_V2 = "AddV2"; +const char *DATAFORMATDIMMAP = "DataFormatDimMap"; +const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute"; +const char *BESSELI0E = "BesselI0e"; +const char *BESSELI1E = "BesselI1e"; +const char *APPLYADADELTA = "ApplyAdadelta"; +const char *APPLYADAGRAD = "ApplyAdagrad"; +const char *APPLYADAGRADDA = "ApplyAdagradDA"; +const char *APPLYADAM = "ApplyAdam"; +const char *APPLYADAMAX = "ApplyAdaMax"; +const char *APPLYADDSIGN = "ApplyAddSign"; +const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp"; +const char *APPLYFTRL = "ApplyFtrl"; +const char *APPLYFTRLV2 = "ApplyFtrlV2"; +const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent"; +const char *APPLYPOWERSIGN = "ApplyPowerSign"; +const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad"; +const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent"; +const char *DEQUANTIZE = "Dequantize"; + +const char *FOCAL_LOSS = "FocalLoss"; +const char *FOCAL_LOSS_GRAD = "FocalLossGrad"; +const char *SMOOTHL1_LOSS = "SmoothL1Loss"; +const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad"; +const char *REDUCEMEAN = "ReduceMean"; +const char *CONCAT_V2 = "ConcatV2"; +const char *ONEHOT_V2 = "OneHotV2"; +const char *SLICE_V2 = "SliceV2"; +const char *TILE_V2 = "TileV2"; +const char *SUM_V2 = "SumV2"; +// Common type when the operator has the same name +const char *DETECTIONOUTPUT = "DetectionOutput"; +// Custom operator +const char *CUSTOMOP = "CustomOp"; +const char *CUSTOMOP_NCHW = "CustomOpNchw"; +const char *CUSTOMOP_NHWC = "CustomOpNhwc"; +const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0"; + +// Depthwise 4d_2_6d,6d_2_4d +const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d"; +const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d"; + +const char *SQRTGRAD = "SqrtGrad"; +const char *SIGMOIDGRAD = "SigmoidGrad"; + +const char *TRANSSHAPE = "TransShape"; + +// Horovod operator +const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce"; +const char *HVDCALLBACKALLGATHER = "HorovodAllgather"; +const char *HVDCALLBACKBROADCAST = "HorovodBroadcast"; +const char *HVDWAIT = "HorovodWait"; + +/// +/// @brief Magic number of model file +/// +const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number + +/// +/// @brief Model head length +/// +const uint32_t MODEL_FILE_HEAD_LEN = 256; + +const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// + +/// +/// @ingroup domi_omg +/// @brief alpha default value +/// +const float ALPHA_DEFAULT_VALUE = 1.0; + +/// +/// @ingroup domi_omg +/// @brief beta default value +/// +const float BETA_DEFAULT_VALUE = 0.0; + +/// +/// @ingroup domi_omg +/// @brief Input node type +/// +const std::string INPUT_TYPE = "Input"; +const std::string DUMMY_DATA = "DummyData"; + +// for fusion op plugin +const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = 
"_fusionop_original_type"; + +const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; +const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; + +/// +/// @ingroup domi_omg +/// @brief DATA node type +/// +const std::string DATA_TYPE = "Data"; + +/// +/// @ingroup domi_omg +/// @brief Frame operator type +/// +const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; + +/// +/// @ingroup domi_omg +/// @brief Convolution node type +/// +const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; +} // namespace parser +} // namespace ge diff --git a/parser/common/pass_manager.cc b/parser/common/pass_manager.cc new file mode 100644 index 0000000..0c28572 --- /dev/null +++ b/parser/common/pass_manager.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/common/pass_manager.h" +#include "framework/omg/parser/parser_types.h" +#include "parser/common/acl_graph_parser_util.h" +#include "common/debug/log.h" +#include "graph/utils/node_utils.h" +#include "omg/omg_inner_types.h" + +namespace ge { +namespace parser { +const vector> &PassManager::GraphPasses() const { return names_to_graph_passes_; } + +Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { + GE_CHECK_NOTNULL(pass); + names_to_graph_passes_.emplace_back(pass_name, pass); + return SUCCESS; +} + +Status PassManager::Run(const ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + return Run(graph, names_to_graph_passes_); +} + +Status PassManager::Run(const ComputeGraphPtr &graph, vector> &names_to_passes) { + GE_CHECK_NOTNULL(graph); + bool not_changed = true; + + for (auto &pass_pair : names_to_passes) { + const auto &pass = pass_pair.second; + const auto &pass_name = pass_pair.first; + GE_CHECK_NOTNULL(pass); + + PARSER_TIMESTAMP_START(PassRun); + Status status = pass->Run(graph); + if (status == SUCCESS) { + not_changed = false; + } else if (status != NOT_CHANGED) { + GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str()); + return status; + } + for (const auto &subgraph :graph->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(subgraph); + GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str()); + string subgraph_pass_name = pass_name + "::" + graph->GetName(); + PARSER_TIMESTAMP_START(PassRunSubgraph); + status = pass->Run(subgraph); + PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str()); + if (status == SUCCESS) { + not_changed = false; + } else if (status != NOT_CHANGED) { + GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str()); + return status; + } + } + PARSER_TIMESTAMP_END(PassRun, pass_name.c_str()); + } + + return not_changed ? 
NOT_CHANGED : SUCCESS; +} + +PassManager::~PassManager() { + for (auto &pass_pair : names_to_graph_passes_) { + auto &pass = pass_pair.second; + GE_DELETE_NEW_SINGLE(pass); + } +} +} // namespace parser +} // namespace ge diff --git a/parser/common/pass_manager.h b/parser/common/pass_manager.h new file mode 100644 index 0000000..b260248 --- /dev/null +++ b/parser/common/pass_manager.h @@ -0,0 +1,76 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_PASS_MANAGER_H_ +#define PARSER_COMMON_PASS_MANAGER_H_ + +#include + +#include "inc/graph_pass.h" + +using std::vector; + +namespace ge { +namespace parser { +/// +/// @ingroup domi_omg +/// @brief pass manager +/// @author +/// +class PassManager { +public: + /// + /// get graph passes + /// @author + /// + const vector> &GraphPasses() const; + + /// + /// Add graph pass + /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys. + /// @author + /// + Status AddPass(const string &pass_name, GraphPass *pass); + + /// + /// Optimize graph with added pass + /// @param [inout] graph graph to be optimized + /// @return SUCCESS optimize successfully + /// @return NOT_CHANGED not optimized + /// @return others optimize failed + /// @author + /// + Status Run(const ge::ComputeGraphPtr &graph); + + /// + /// Optimize graph with specified pass + /// @param [inout] graph graph to be optimized + /// @param [in] passes passes to be used + /// @return SUCCESS optimize successfully + /// @return NOT_CHANGED not optimized + /// @return others optimized failed + /// @author + /// + static Status Run(const ge::ComputeGraphPtr &graph, vector> &passes); + + ~PassManager(); + +private: + vector> names_to_graph_passes_; +}; +} // namespace parser +} // namespace ge +#endif // PARSER_COMMON_PASS_MANAGER_H_ diff --git a/parser/common/pre_checker.cc b/parser/common/pre_checker.cc new file mode 100644 index 0000000..91ea192 --- /dev/null +++ b/parser/common/pre_checker.cc @@ -0,0 +1,287 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/common/pre_checker.h" +#include +#include "common/model_saver.h" +#include "common/op_map.h" +#include "common/util.h" +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "omg/omg.h" +#include "parser/common/op_parser_factory.h" +#include "parser/common/model_saver.h" +#include "register/op_registry.h" + +namespace ge { +// Keys in JSON file +namespace { +const char *const kKeyName = "name"; +const char *const kKeyResult = "result"; +const char *const kKeyTotal = "total"; +const char *const kKeyPass = "pass"; +const char *const kKeyFail = "fail"; +const char *const kKeyOp = "op"; +const char *const kKeyOpName = "name"; +const char *const kKeyOpType = "type"; +const char *const kKeyOpResult = "result"; +const char *const kKeyCause = "cause"; +const char *const kKeyCauseCode = "code"; +const char *const kKeyCauseMessage = "message"; + +// Checking result and support warning later +const char *const kResultSuccess = "success"; +const char *const kResultFailed = "failed"; +} // namespace + +PreChecker::PreChecker() : fmk_op_types_(nullptr) { Init(); } + +void PreChecker::Init() { + model_name_.clear(); + op_map_.clear(); + ops_.clear(); + fmk_op_types_ = nullptr; + + // Currently only Caffe and tensorflow are supported + domi::FrameworkType fmk_type = GetParserContext().type; + if (fmk_type == domi::CAFFE) + fmk_op_types_ = &caffe_op_map; + else if (fmk_type == domi::TENSORFLOW) + fmk_op_types_ = &tensorflow_op_map; + else + return; +} + +PreChecker::~PreChecker() {} + +FMK_FUNC_HOST_VISIBILITY PreChecker &PreChecker::Instance() { + static PreChecker instance; + return instance; +} + +FMK_FUNC_HOST_VISIBILITY void PreChecker::SetModelName(const string &name) { model_name_ = name; } + +FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddOp(OpId id, const string &name, const string &type) { + GE_RETURN_WITH_LOG_IF_TRUE(op_map_.find(id) != op_map_.end(), "Id already exists."); + + Info info; + info.id = id; + info.name = name; + info.type = type; + op_map_[id] = info; + ops_.push_back(id); + + return SUCCESS; +} + +Status PreChecker::CheckName(OpId id) { + auto iter = op_map_.find(id); + GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); + + Info &info = iter->second; + for (auto &v : op_map_) { + // If the name is duplicate, an error is logged + if (id != v.first && info.name == v.second.name) { + Cause cause; + cause.code = NAME_REPEATED; + cause.message = "The name is repeated."; + + GELOGI("Name %s repeated.", info.name.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name}); + GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); + GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "Add cause failed."); + break; + } + } + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflow) { + auto iter = op_map_.find(id); + GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); + + Info &info = iter->second; + string type = info.type; + + // If the user explicitly specifies the mapping relationship of the operator type through + // the -- OP_name_map parameter, the type specified by the user is used. 
+ auto op_map_iter = GetParserContext().op_conf_map.find(type); + if (op_map_iter != GetParserContext().op_conf_map.end()) { + type = op_map_iter->second; + } + + // Judge whether the type is supported + GE_RETURN_WITH_LOG_IF_ERROR( + CheckTypeSupported(info.id, type, info.name, is_tensorflow), "Check type supported failed."); + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddCause(OpId id, ErrorCode code, const string &msg) { + Cause cause; + cause.code = code; + cause.message = msg; + return AddCause(id, cause); +} + +FMK_FUNC_HOST_VISIBILITY void PreChecker::RefreshErrorMessageByName(const string &op_name, ErrorCode code, + const string &msg) { + for (const auto &op : op_map_) { + if (op.second.name == op_name) { + AddCause(op.second.id, code, msg); + return; + } + } + GELOGW("Node [%s] not founded in prechecking list.", op_name.c_str()); +} + +Status PreChecker::AddCause(OpId id, const Cause &cause) { + auto iter = op_map_.find(id); + GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); + + Info &info = iter->second; + + // Avoid adding repeatedly + for (Cause &c : info.causes) { + if (c.code == cause.code && c.message == cause.message) { + return SUCCESS; + } + } + + info.causes.push_back(cause); + + return SUCCESS; +} + +void PreChecker::Clear() { Init(); } + +Status PreChecker::Clear(OpId id, const string &message) { + auto iter = op_map_.find(id); + GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); + + Info &info = iter->second; + info.causes.clear(); + + // Set additional information + if (message != "") { + Cause cause; + cause.code = ErrorCode::OK; + cause.message = message; + GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); + } + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() { + for (auto id : ops_) { + if (HasError(id)) { + return true; + } + } + + return false; +} + +Status PreChecker::Save(string file) { + uint32_t fail_num = 0; + for (auto id : ops_) { + if (HasError(id)) { + fail_num++; + } + } + + // Initialization model related JSON information + nlohmann::json model; + model[kKeyName] = model_name_; + model[kKeyResult] = HasError() ? kResultFailed : kResultSuccess; + model[kKeyTotal] = ops_.size(); + model[kKeyPass] = ops_.size() - fail_num; + model[kKeyFail] = fail_num; + + // Constructing JSON information of operators in order of network + for (auto id : ops_) { + auto iter = op_map_.find(id); + GE_CHK_BOOL_RET_STATUS(iter != op_map_.end(), FAILED, "don't find this op."); + Info &info = iter->second; + + // Initialization operator general information + nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}}; + op[kKeyOpResult] = HasError(id) ? 
kResultFailed : kResultSuccess; + + // handle causes + for (const Cause &cause : info.causes) { + nlohmann::json cause_j = {{kKeyCauseCode, cause.code}, {kKeyCauseMessage, cause.message}}; + op[kKeyCause].push_back(cause_j); + } + + model[kKeyOp].push_back(op); + } + + // Save JSON data to a file + GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); + + return SUCCESS; +} + +Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow) { + // Currently only partial framework type checking is supported + if (fmk_op_types_ == nullptr) { + std::string op_type; + if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) { + Cause cause; + cause.code = TYPE_UNSUPPORTED; + cause.message = "The type is not supported."; + GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); + if (!is_tensorflow) { + ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); + } + GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); + } + return SUCCESS; + } + + // Log error if type not found + if (fmk_op_types_->find(type) == fmk_op_types_->end()) { + Cause cause; + cause.code = TYPE_UNSUPPORTED; + cause.message = "The type is not supported."; + + GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str()); + if (!is_tensorflow) { + ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type}); + } + GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed."); + } + + return SUCCESS; +} + +bool PreChecker::HasError(OpId id) { + auto iter = op_map_.find(id); + GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist."); + + Info &info = iter->second; + for (const Cause &cause : info.causes) { + if (cause.code != ErrorCode::OK) { + return true; + } + } + + return false; +} +} // namespace ge diff --git a/parser/common/pre_checker.h b/parser/common/pre_checker.h new file mode 100644 index 0000000..12d3323 --- /dev/null +++ b/parser/common/pre_checker.h @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_COMMON_PRE_CHECKER_H_ +#define PARSER_COMMON_PRE_CHECKER_H_ + +#include +#include +#include "framework/omg/parser/parser_types.h" +#include "omg/omg_inner_types.h" + +namespace ge { +using std::map; +using std::string; +using std::vector; +using Status = domi::Status; +/** + * @ingroup domi_omg + * @brief pre_check + * @author + */ +class PreChecker { + public: + /** + * @ingroup domi_omg + * @brief Operator unique identification + */ + using OpId = const void *; + + /** + * @ingroup domi_omg + * @brief error code, 1~99:Error, 100~199:Waring。 + */ + enum ErrorCode { + // no error + OK = 0, + + // type unsupported + TYPE_UNSUPPORTED = 1, + + // param invalid + PARAM_INVALID = 2, + + // type ambiguous + TYPE_AMBIGUOUS = 8, + + // name repeated + NAME_REPEATED = 9 + }; + + /** + * @ingroup domi_omg + * @brief Operator error description + */ + struct Cause { + // error code + ErrorCode code; + + // error message + string message; + }; + + public: + /** + * @ingroup domi_omg + * @brief instance interface + */ + static PreChecker &Instance(); + + /** + * @ingroup domi_omg + * @brief set model name + */ + void SetModelName(const string &name); + + /** + * @ingroup domi_omg + * @brief add op information + */ + Status AddOp(OpId id, const string &name, const string &type); + + /** + * @ingroup domi_omg + * @brief Judge whether the operator name is duplicate + */ + Status CheckName(OpId id); + + /** + * @ingroup domi_omg + * @brief check operation type + * 1、Check whether the operator type supports according to the global frameworktype + * 2、Check if the operator type is ambiguous + */ + Status CheckType(OpId id, bool is_tensorflow = false); + + void RefreshErrorMessageByName(const string &op_name, ErrorCode code, const string& msg); + + /** + * @ingroup domi_omg + * @brief Add custom error description + */ + Status AddCause(OpId id, ErrorCode code, const string &msg); + + /** + * @ingroup domi_omg + * @brief Add custom error description + */ + Status AddCause(OpId id, const Cause &cause); + + /** + * @ingroup domi_omg + * @brief Clear all operator information + */ + void Clear(); + + /** + * @ingroup domi_omg + * @brief Clear the error information of the specified operator + */ + Status Clear(OpId id, const string &message = ""); + + /** + * @ingroup domi_omg + * @brief Determine if an error has been detected + */ + bool HasError(); + + /** + * @ingroup domi_omg + * @brief Save inspection results(JSON) + */ + Status Save(string file); + + private: + /** + * @ingroup domi_omg + * @brief operation information + */ + struct Info { + // Operator identifier + OpId id; + + // Operator name + string name; + + // Operator type + string type; + + // Error description, which may contain multiple (for example, both name and type are illegal) + vector causes; + }; + + PreChecker(); + ~PreChecker(); + PreChecker(const PreChecker &); + PreChecker &operator=(const PreChecker &); + + // Initialize internal data + void Init(); + + // Judge whether the type is supported + Status CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow); + + // Determine if an error has been detected + bool HasError(OpId id); + + private: + // model name + string model_name_; + + // Save operator check results + map op_map_; + + // Save operator list in original order + vector ops_; + + // save frame related operator types + map *fmk_op_types_; +}; +} // namespace ge +#endif // PARSER_COMMON_PRE_CHECKER_H_ \ No newline at end of file diff --git a/parser/common/proto/ge_ir.proto 
b/parser/common/proto/ge_ir.proto new file mode 100644 index 0000000..e7bfe0c --- /dev/null +++ b/parser/common/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
+message NamedAttrs
+{
+    string name = 1;
+    map<string, AttrDef> attr = 2;
+}
+
+// Shape / dimension description, using row-major order
+message ShapeDef
+{
+    repeated int64 dim = 1;  // Size of each dimension
+}
+
+// Multidimensional data description
+message TensorDescriptor
+{
+    string name = 1;  // Optional parameter, tensor name
+
+    DataType dtype = 2;  // tensor datatype
+    ShapeDef shape = 3;  // Shape / dimension
+    string layout = 4;   // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
+
+    bool has_out_attr = 9;
+    int64 size = 10;
+    int64 weight_size = 11;
+    bool reuse_input = 12;
+    bool output_tensor = 13;
+    string device_type = 14;
+    bool input_tensor = 15;
+    int64 real_dim_cnt = 16;
+    int64 reuse_input_index = 17;
+    int64 data_offset = 18;
+    int64 cmps_size = 19;
+    string cmps_tab = 20;
+    int64 cmps_tab_offset = 21;
+
+    map<string, AttrDef> attr = 5;  // Set of extra parameter fields
+}
+
+// GeTensor definition
+message TensorDef
+{
+    TensorDescriptor desc = 1;  // Tensor description
+    bytes data = 2;             // Tensor data
+}
+
+
+// Operator description
+message OpDef
+{
+    string name = 1;  // name
+    string type = 2;  // type
+
+    repeated string input = 5;  // input original op name + outgoing index. op_name:index
+
+    map<string, AttrDef> attr = 10;  // Set of operator parameter fields
+
+    bool has_out_attr = 20;
+    int64 id = 21;
+    int64 stream_id = 22;
+    repeated string input_name = 23;
+    repeated string src_name = 24;
+    repeated int64 src_index = 25;
+    repeated string dst_name = 26;
+    repeated int64 dst_index = 27;
+    repeated int64 input_i = 28;
+    repeated int64 output_i = 29;
+    repeated int64 workspace = 30;
+    repeated int64 workspace_bytes = 31;
+    repeated bool is_input_const = 32;
+    repeated TensorDescriptor input_desc = 33;
+    repeated TensorDescriptor output_desc = 34;
+    repeated string subgraph_name = 35;
+}
+
+// Graph definition
+message GraphDef
+{
+    string name = 1;  // name
+
+    repeated string input = 4;   // Graph input
+    repeated string output = 5;  // Graph output
+
+    repeated OpDef op = 6;  // List of operators
+
+    map<string, AttrDef> attr = 11;  // Extended field
+}
+
+// model definition
+message ModelDef
+{
+    string name = 1;             // name
+    uint32 version = 2;          // IR Proto version
+    string custom_version = 3;   // User model version number, passed in by user
+
+    repeated GraphDef graph = 7;  // Graph definition; graph[0] represents the main graph of the model
+
+    map<string, AttrDef> attr = 11;  // Extended field
+}
+
diff --git a/parser/common/proto/insert_op.proto b/parser/common/proto/insert_op.proto
new file mode 100644
index 0000000..c635ca1
--- /dev/null
+++ b/parser/common/proto/insert_op.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+package domi;
+
+message InsertNewOps {
+    repeated AippOpParams aipp_op = 1;
+    repeated MultiShapeOpParams multi_shape_op = 2;
+}
+
+message AippOpParams {
+    enum InputFormat {
+        UNDEFINED = 0;
+        YUV420SP_U8 = 1;
+        XRGB8888_U8 = 2;
+        RGB888_U8 = 3;
+        YUV400_U8 = 4;
+        NC1HWC0DI_FP16 = 5;
+        NC1HWC0DI_S8 = 6;
+        ARGB8888_U8 = 7;
+        YUYV_U8 = 8;
+        YUV422SP_U8 = 9;
+        AYUV444_U8 = 10;
+        RAW10 = 11;
+        RAW12 = 12;
+        RAW16 = 13;
+        RAW24 = 14;
+        RGB16 = 15;
+        RGB20 = 16;
+        RGB24 = 17;
+        RGB8_IR = 18;
+        RGB16_IR = 19;
+        RGB24_IR = 20;
+    }
+
+    enum AippMode {
+        undefined = 0;
+        static = 1;
+        dynamic = 2;
+    }
+
+    // AIPP mode, distinguishes static AIPP from dynamic AIPP
+    AippMode aipp_mode = 1;
+
+    // related_input_rank is required, integer, range >= 0 and <= the number of Data operators, default 0.
+    // It identifies which model input AIPP is applied to; e.g. to apply AIPP to the second input, set related_input_rank to 1.
+    uint32 related_input_rank = 2;
+
+    // input_edge_idx is optional, integer, range >= 0.
+    // It selects which output edges of the Data operator get AIPP processing; if not set, AIPP is applied by default
+    // to all outputs of the input specified by related_input_rank.
+    // The configured value must be <= the number of output edges of the Data operator.
+    repeated uint32 input_edge_idx = 3;
+
+    // [Begin] dynamic AIPP parameters, ignored when static AIPP is configured
+    uint32 max_src_image_size = 4;
+
+    // Whether rotation is supported. Not supported by default; enabling rotation costs extra memory and performance
+    bool support_rotation = 5;
+
+    // [End] dynamic AIPP parameters
+
+
+    // [Begin] static AIPP parameters, ignored when dynamic AIPP is configured
+    InputFormat input_format = 51;
+    bool csc_switch = 52;
+    float cpadding_value = 53;
+    bool rbuv_swap_switch = 54;
+    bool ax_swap_switch = 55;
+    bool single_line_mode = 56;
+
+    int32 src_image_size_w = 57;
+    int32 src_image_size_h = 58;
+
+    bool crop = 59;
+    int32 load_start_pos_w = 60;
+    int32 load_start_pos_h = 61;
+    int32 crop_size_w = 62;
+    int32 crop_size_h = 63;
+
+    bool resize = 64;
+    int32 resize_output_w = 65;
+    int32 resize_output_h = 66;
+
+    bool padding = 67;
+    int32 left_padding_size = 68;
+    int32 right_padding_size = 69;
+    int32 top_padding_size = 70;
+    int32 bottom_padding_size = 71;
+
+    int32 mean_chn_0 = 10;
+    int32 mean_chn_1 = 11;
+    int32 mean_chn_2 = 12;
+    int32 mean_chn_3 = 19;
+    float min_chn_0 = 13;
+    float min_chn_1 = 14;
+    float min_chn_2 = 15;
+    float min_chn_3 = 20;
+    repeated float var_reci_chn_0 = 16;
+    repeated float var_reci_chn_1 = 17;
+    repeated float var_reci_chn_2 = 18;
+    repeated float var_reci_chn_3 = 21;
+
+    repeated int32 matrix_r0c0 = 30;
+    repeated int32 matrix_r0c1 = 31;
+    repeated int32 matrix_r0c2 = 32;
+    repeated int32 matrix_r1c0 = 33;
+    repeated int32 matrix_r1c1 = 34;
+    repeated int32 matrix_r1c2 = 35;
+    repeated int32 matrix_r2c0 = 36;
+    repeated int32 matrix_r2c1 = 37;
+    repeated int32 matrix_r2c2 = 38;
+    repeated int32 output_bias_0 = 39;
+    repeated int32 output_bias_1 = 40;
+    repeated int32 output_bias_2 = 41;
+    repeated int32 input_bias_0 = 42;
+    repeated int32 input_bias_1 = 43;
+    repeated int32 input_bias_2 = 44;
+
+    // [End] static AIPP parameters
+
+    // The n number that is used for raw/rgbir data into f16 transformation.
+    // The transformation equation is x/(2^n). If set to 0, no transform is performed.
+    uint32 raw_rgbir_to_f16_n = 45;
+}
+
+message MultiShapeOpParams {
+    enum MultiShapeMode {
+        batch = 0;       // dynamic batch
+        resolution = 1;  // dynamic resolution, reserved for extension
+    }
+
+    MultiShapeMode mode = 1;        // operator mode
+    uint32 related_input_rank = 2;  // which model input the new op is inserted on
+
+
+    repeated uint32 batch_list = 11;  // batch_list values; the number of entries must be between 2 and 8
+}
diff --git a/parser/common/proto/om.proto b/parser/common/proto/om.proto
new file mode 100644
index 0000000..e15e5f8
--- /dev/null
+++ b/parser/common/proto/om.proto
@@ -0,0 +1,396 @@
+/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + 
bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. 
+ CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/parser/common/proto/tensorflow/attr_value.proto b/parser/common/proto/tensorflow/attr_value.proto new file mode 100644 index 0000000..1cc67d6 --- /dev/null +++ b/parser/common/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. 
It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/parser/common/proto/tensorflow/function.proto b/parser/common/proto/tensorflow/function.proto new file mode 100644 index 0000000..075897c --- /dev/null +++ b/parser/common/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. 
Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/parser/common/proto/tensorflow/graph.proto b/parser/common/proto/tensorflow/graph.proto new file mode 100644 index 0000000..d639a7d --- /dev/null +++ b/parser/common/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. 
The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/parser/common/proto/tensorflow/node_def.proto b/parser/common/proto/tensorflow/node_def.proto new file mode 100644 index 0000000..b9bc97e --- /dev/null +++ b/parser/common/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/parser/common/proto/tensorflow/op_def.proto b/parser/common/proto/tensorflow/op_def.proto new file mode 100644 index 0000000..3485d04 --- /dev/null +++ b/parser/common/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. 
+// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. 
+ string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/parser/common/proto/tensorflow/resource_handle.proto b/parser/common/proto/tensorflow/resource_handle.proto new file mode 100644 index 0000000..a345235 --- /dev/null +++ b/parser/common/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. 
+ uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/parser/common/proto/tensorflow/tensor.proto b/parser/common/proto/tensorflow/tensor.proto new file mode 100644 index 0000000..d0a4d02 --- /dev/null +++ b/parser/common/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. 
+ repeated TensorProto tensors = 3; +} diff --git a/parser/common/proto/tensorflow/tensor_shape.proto b/parser/common/proto/tensorflow/tensor_shape.proto new file mode 100644 index 0000000..4225a2e --- /dev/null +++ b/parser/common/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/parser/common/proto/tensorflow/types.proto b/parser/common/proto/tensorflow/types.proto new file mode 100644 index 0000000..ba7a72b --- /dev/null +++ b/parser/common/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). 
+ DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/parser/common/proto/tensorflow/versions.proto b/parser/common/proto/tensorflow/versions.proto new file mode 100644 index 0000000..4806121 --- /dev/null +++ b/parser/common/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/parser/common/proto_file_parser.cc b/parser/common/proto_file_parser.cc new file mode 100644 index 0000000..731ac8c --- /dev/null +++ b/parser/common/proto_file_parser.cc @@ -0,0 +1,528 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +#include "parser/common/proto_file_parser.h" + +#include +#include +#include +#include +#include +#include +#include +#include "common/string_util.h" +#include "common/types.h" +#include "common/util.h" +#include "common/debug/log.h" +#include "parser/common/acl_graph_parser_util.h" +#include "ge/ge_api_types.h" +#include "framework/common/debug/ge_log.h" + +using std::ifstream; +using std::vector; +using std::string; + +namespace { +const char kMinNum = '0'; +const char kMaxNum = '9'; +const int kMinLineWordSize = 3; +const int kMinMessageLineWords = 2; +const int kMaxIdentifier = 536870912; // 2^29 - 1 +const int kTmpFileNameLen = 16; +const int kMinRandomNum = 0; +const int kMaxRandomNum = 9; +const int kDecimalMulti = 10; +const int kOpenRetValue = 0; +const int kMessageNameIndex = 2; +const char *const kTmpPath = "/tmp"; +const char *const kMessage = "message"; +const char *const kLayerParameter = "LayerParameter"; +const char *const kNetParameter = "NetParameter"; +const char *const kStartBrace = "{"; +const char *const kCloseBrace = "}"; +const char *const kOptional = "optional"; +const char *const kRepeated = "repeated"; +const char *const kRequired = "required"; + +bool GetIdentifier(const std::string &line, int &identifier) { + int size = line.size(); + auto pos = line.find("="); + if (pos == std::string::npos) { + return false; + } + for (int i = pos + 1; i < size; i++) { + if (line[i] == ';') { + break; + } + if (line[i] >= kMinNum && line[i] <= kMaxNum) { + identifier = identifier * kDecimalMulti + line[i] - kMinNum; + } + if (identifier > kMaxIdentifier || identifier < 0) { + return false; + } + } + if (identifier == 0) { + return false; + } + return true; +} + +void GetName(const std::string &op_info, string &op_name) { + op_name.assign(op_info); + auto pos = op_name.find("="); + if (pos != string::npos) { + op_name = op_name.substr(0, pos); + } +} + +void GetOpParamInfo(const std::string &line, std::vector &op_param_info) { + std::istringstream string_stream(line); + std::string temp; + while (std::getline(string_stream, temp, ' ')) { + if (temp.empty()) { + continue; + } + op_param_info.emplace_back(std::move(temp)); + } +} + +string GetMessageName(const std::string &line) { + std::vector op_param_info; + GetOpParamInfo(line, op_param_info); + string message_name; + if (op_param_info.size() < kMinMessageLineWords) { + message_name = ""; + return message_name; + } + message_name = op_param_info[1]; + auto pos = message_name.find(kStartBrace); + if (pos != string::npos) { + message_name = message_name.substr(0, pos); + } + return message_name; +} + +string CreatTmpName(int len) { + std::uniform_int_distribution u(kMinRandomNum, kMaxRandomNum); + std::default_random_engine e; + e.seed(time(0)); + string tmp_name = ""; + for (int i = 0; i < len; i++) { + tmp_name += std::to_string(u(e)); + } + return tmp_name; +} + +bool SaveIdentifierOpMapInfo(const string &line, std::map> &identifier_op_map, + std::map> &op_identifier_map) { + std::vector op_param_info; + GetOpParamInfo(line, op_param_info); + int info_size = op_param_info.size(); + if (info_size < kMinLineWordSize) { + GELOGE(ge::FAILED, "Words size of line[%s] is less than kMinLineWordSize[%d].", line.c_str(), kMinLineWordSize); + return false; + } + + if (op_param_info[0] != kOptional && op_param_info[0] != kRepeated && op_param_info[0] != kRequired) { + GELOGE(ge::FAILED, "Split line[%s] failed.", line.c_str()); + return false; + } + + // get identifier + int identifier = 0; + bool ret = GetIdentifier(line, 
identifier); + if (!ret) { + GELOGE(ge::FAILED, "Get identifier of line[%s] failed.", line.c_str()); + return false; + } + + // get op_name + string name; + GetName(op_param_info[kMessageNameIndex], name); + + identifier_op_map[identifier] = std::make_pair(op_param_info[1], name); + op_identifier_map[name] = std::make_pair(identifier, op_param_info[1]); + return true; +} + +bool CheckRealPath(const char *file_path) { + string dest_path = ge::parser::RealPath(file_path); + if (dest_path.empty()) { + GELOGW("Path [%s] is not real existed.", file_path); + return false; + } + return true; +} +} // namespace + +namespace ge { +ProtoFileParser::~ProtoFileParser() { + if (!fusion_proto_path.empty() && CheckRealPath(fusion_proto_path.c_str())) { + (void)remove(fusion_proto_path.c_str()); + } +} + +std::string ProtoFileParser::GetFusionProtoFile() { + return fusion_proto_path; +} + +Status ProtoFileParser::CreatProtoFile() { + if (fusion_proto_path.empty()) { + fusion_proto_path.assign(kTmpPath); + fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen); + } + + int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP); + if (fd < kOpenRetValue) { + GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str()); + return FAILED; + } + close(fd); + return SUCCESS; +} + +Status ProtoFileParser::ParseProtoFile(const string &proto_file, + std::map> &identifier_op_map, + std::map> &op_identifier_map) { + ifstream read_file; + read_file.open(proto_file, std::ios::in); + if (read_file.fail()) { + GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str()); + return FAILED; + } + + std::string line; + bool save_flag = false; + while (std::getline(read_file, line)) { + if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) { + save_flag = true; + continue; + } + + if (save_flag && line.find(kCloseBrace) != std::string::npos) { + save_flag = false; + break; + } + + if (save_flag) { + if (line.find(kRepeated) == std::string::npos && line.find(kOptional) == std::string::npos && + line.find(kRequired) == std::string::npos) { + continue; + } + bool ret = SaveIdentifierOpMapInfo(line, identifier_op_map, op_identifier_map); + if (!ret) { + read_file.close(); + return FAILED; + } + } + } + read_file.close(); + return SUCCESS; +} + +Status ProtoFileParser::AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) { + ifstream read_custom; + read_custom.open(custom_proto_file, std::ios::in); + if (read_custom.fail()) { + GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file); + return FAILED; + } + + std::string line_custom; + bool custom_in_layer = false; + while (std::getline(read_custom, line_custom)) { + if (line_custom.find(kMessage) != std::string::npos && line_custom.find(kLayerParameter) != std::string::npos) { + custom_in_layer = true; + continue; + } + + if (!custom_in_layer) { + continue; + } + + if (line_custom.find(kCloseBrace) != std::string::npos) { + custom_in_layer = false; + break; + } + // exclude remark lines + if (line_custom.find(kRepeated) == std::string::npos && line_custom.find(kOptional) == std::string::npos && + line_custom.find(kRequired) == std::string::npos) { + continue; + } + // exclude repeated lines + if (custom_repeat_line_map_.count(line_custom) == 0) { + write_tmp << line_custom << '\n'; + } + } + read_custom.close(); + return SUCCESS; +} + +Status ProtoFileParser::AddCustomAndConflictMessage(const char 
*custom_proto_file, std::ofstream &write_tmp) { + ifstream read_custom; + read_custom.open(custom_proto_file, std::ios::in); + if (read_custom.fail()) { + GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file); + return FAILED; + } + + std::string line_custom; + bool custom_in_message = false; + while (std::getline(read_custom, line_custom)) { + if (line_custom.find(kMessage) != std::string::npos) { + std::string message_name = GetMessageName(line_custom); + if (message_name != kLayerParameter && message_name != kNetParameter) { + custom_in_message = true; + write_tmp << line_custom << '\n'; + } else { + custom_in_message = false; + } + continue; + } + + // exclude repeated messages + if (custom_in_message) { + write_tmp << line_custom << '\n'; + } + } + read_custom.close(); + return SUCCESS; +} + +Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file, + std::ifstream &read_caffe, + std::ofstream &write_tmp) { + std::string line_caffe; + bool caffe_in_layer = false; + bool caffe_in_unrepeated_message = true; + string tmp_message_name; + while (std::getline(read_caffe, line_caffe)) { + if (line_caffe.find(kMessage) != std::string::npos) { + tmp_message_name.assign(GetMessageName(line_caffe)); + if (custom_repeat_message_map_.count(tmp_message_name) > 0) { + caffe_in_unrepeated_message = false; + } else { + caffe_in_unrepeated_message = true; + if (tmp_message_name == kLayerParameter) { + caffe_in_layer = true; + } + } + } + if (!caffe_in_unrepeated_message) { + continue; + } + if (caffe_in_layer && line_caffe.find(kCloseBrace) != std::string::npos) { + if (AddCustomAndConflictLayer(custom_proto_file, write_tmp) != SUCCESS) { + GELOGE(FAILED, "Add conflict and new layer line from custom proto to dest proto failed."); + return FAILED; + } + caffe_in_layer = false; + } + + // exclude conflict lines + if (caffe_in_layer && caffe_conflict_line_map_.count(line_caffe) > 0) { + GELOGD("pass line: %s", line_caffe.c_str()); + continue; + } + write_tmp << line_caffe << '\n'; + } + return SUCCESS; +} + +Status ProtoFileParser::WriteProtoFile(const char *caffe_proto_file, + const char *custom_proto_file) { + std::ifstream read_caffe; + std::ofstream write_tmp; + read_caffe.open(caffe_proto_file, std::ios::in); + if (read_caffe.fail()) { + GELOGE(FAILED, "ifsream open proto file[%s] failed.", caffe_proto_file); + return FAILED; + } + write_tmp.open(fusion_proto_path, std::ios::out); + if (write_tmp.fail()) { + GELOGE(FAILED, "ofstream open proto file[%s] failed.", fusion_proto_path.c_str()); + read_caffe.close(); + return FAILED; + } + + if (WriteCaffeProtoFile(custom_proto_file, read_caffe, write_tmp) != SUCCESS) { + read_caffe.close(); + write_tmp.close(); + return FAILED; + } + + if (AddCustomAndConflictMessage(custom_proto_file, write_tmp) != SUCCESS) { + GELOGE(FAILED, "Add conflict and new message from custom proto to dest proto failed."); + read_caffe.close(); + write_tmp.close(); + return FAILED; + } + + read_caffe.close(); + write_tmp.close(); + return SUCCESS; +} + +Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier, + std::string &dest_line) { + ifstream read_file; + read_file.open(proto_file, std::ios::in); + if (read_file.fail()) { + GELOGE(FAILED, "open file[%s] failed.", proto_file); + return FAILED; + } + + std::string line; + bool save_flag = false; + while (std::getline(read_file, line)) { + if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) { + save_flag = true; + 
continue; + } + + if (save_flag && line.find(kCloseBrace) != std::string::npos) { + save_flag = false; + break; + } + + int tmp_identifier = 0; + if (save_flag && GetIdentifier(line, tmp_identifier) && tmp_identifier == identifier) { + dest_line.assign(line); + read_file.close(); + return SUCCESS; + } + } + read_file.close(); + GELOGE(FAILED, "find line according to identifier[%d] failed.", identifier); + return FAILED; +} + +void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, + std::map> &caffe_op_identifier_map, + std::map> &custom_op_identifier_map) { + for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) { + if (caffe_op_identifier_map.count(iter->first) > 0) { + string message_name = iter->first; + auto caffe_pair = caffe_op_identifier_map[iter->first]; + auto custom_pair = custom_op_identifier_map[iter->first]; + if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) { + // consider conflict op and name and type; + GELOGD("Find conflict op: caffe_identifier[%d], custom_identifier[%d], op_name[%s].", + caffe_pair.first, custom_pair.first, message_name.c_str()); + std::string caffe_conflict_line; + (void)FindConflictLine(caffe_proto_file, caffe_pair.first, caffe_conflict_line); + GELOGD("conflict: %s", caffe_conflict_line.c_str()); + caffe_conflict_line_map_[caffe_conflict_line]++; + } else { + // consider repeat op and name and type; could be removed + std::string custom_repeat_line; + (void)FindConflictLine(custom_proto_file, caffe_pair.first, custom_repeat_line); + custom_repeat_line_map_[custom_repeat_line]++; + GELOGD("repeat: %s", custom_repeat_line.c_str()); + } + } + } +} + +void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, + std::map> caffe_identifier_op_map, + std::map> custom_identifier_op_map) { + for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) { + if (caffe_identifier_op_map.count(iter->first) > 0) { + int identifier = iter->first; + auto caffe_pair = caffe_identifier_op_map[iter->first]; + auto custom_pair = custom_identifier_op_map[iter->first]; + if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) { + // consider conflict op and name and type; + GELOGD("Find conflict op: caffe_op[%s], custom_op[%s], identifier[%d].", + caffe_pair.first.c_str(), custom_pair.first.c_str(), + identifier); + std::string caffe_conflict_line; + (void)FindConflictLine(caffe_proto_file, identifier, caffe_conflict_line); + GELOGD("conflict: %s", caffe_conflict_line.c_str()); + caffe_conflict_line_map_[caffe_conflict_line]++; + } else { + // consider repeat op and name and type; + std::string custom_repeat_line; + (void)FindConflictLine(custom_proto_file, identifier, custom_repeat_line); + custom_repeat_line_map_[custom_repeat_line]++; + GELOGD("repeat: %s", custom_repeat_line.c_str()); + } + } + } +} + +Status ProtoFileParser::RecordProtoMessage(const string &proto_file) { + ifstream read_file; + read_file.open(proto_file, std::ios::in); + if (read_file.fail()) { + GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str()); + return FAILED; + } + + std::string line; + while (std::getline(read_file, line)) { + if (line.find(kMessage) != std::string::npos) { + std::string message_name = GetMessageName(line); + if (message_name != kLayerParameter && message_name != kNetParameter) { + 
custom_repeat_message_map_[message_name]++; + } + } + } + read_file.close(); + return SUCCESS; +} + +Status ProtoFileParser::CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, + std::string &dest_proto_file) { + GE_CHECK_NOTNULL(caffe_proto_file); + GE_CHECK_NOTNULL(custom_proto_file); + + if (!CheckRealPath(caffe_proto_file) || !CheckRealPath(custom_proto_file)) { + GELOGE(FAILED, "caffe proto[%s] and custom proto[%s] are not all existed.", + caffe_proto_file, custom_proto_file); + return FAILED; + } + + GELOGI("Start fusion custom and caffe proto to file."); + std::map> caffe_identifier_op_map; + std::map> custom_identifier_op_map; + std::map> caffe_op_identifier_map; + std::map> custom_op_identifier_map; + + (void)ParseProtoFile(caffe_proto_file, caffe_identifier_op_map, caffe_op_identifier_map); + (void)ParseProtoFile(custom_proto_file, custom_identifier_op_map, custom_op_identifier_map); + (void)RecordProtoMessage(custom_proto_file); + + // check identifier or op_type is same + CheckConflictIdentifier(caffe_proto_file, custom_proto_file, + caffe_identifier_op_map, custom_identifier_op_map); + CheckConflictOp(caffe_proto_file, custom_proto_file, + caffe_op_identifier_map, custom_op_identifier_map); + + if (CreatProtoFile() != SUCCESS) { + return FAILED; + } + + if (WriteProtoFile(caffe_proto_file, custom_proto_file) != SUCCESS) { + GELOGE(FAILED, "Combine caffe proto and custom proto to dest proto file failed."); + return FAILED; + } + dest_proto_file.assign(fusion_proto_path); + GELOGI("Fusion custom and caffe proto to file[%s] success.", dest_proto_file.c_str()); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/parser/common/proto_file_parser.h b/parser/common/proto_file_parser.h new file mode 100644 index 0000000..5dc46aa --- /dev/null +++ b/parser/common/proto_file_parser.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +#ifndef PROTO_FILE_PARSE_UTIL_ +#define PROTO_FILE_PARSE_UTIL_ + +#include +#include +#include "common/types.h" +#include "ge/ge_api_types.h" + +namespace ge { +class ProtoFileParser { +public: + ProtoFileParser(){}; + ProtoFileParser(const char *dest_path){ + fusion_proto_path = dest_path; + } + ~ProtoFileParser(); + Status CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, + std::string &dest_proto_file); + std::string GetFusionProtoFile(); +private: + Status CreatProtoFile(); + Status ParseProtoFile(const std::string &proto_file, + std::map > &identifier_op_map, + std::map > &op_identifier_map); + Status WriteCaffeProtoFile(const char *custom_proto_file, + std::ifstream &read_caffe, + std::ofstream &write_tmp); + Status WriteProtoFile(const char *caffe_proto_file, const char *custom_proto_file); + Status FindConflictLine(const char *proto_file, int identifier, + std::string &dest_line); + Status AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp); + Status AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp); + void CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, + std::map> &caffe_op_identifier_map, + std::map> &custom_op_identifier_map); + void CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file, + std::map> caffe_identifier_op_map, + std::map> custom_identifier_op_map); + Status RecordProtoMessage(const std::string &proto_file); + std::map caffe_conflict_line_map_; + std::map custom_repeat_line_map_; + std::map custom_repeat_message_map_; + std::string fusion_proto_path; +}; +} // namespace ge + +#endif // PROTO_FILE_PARSE_UTIL_ \ No newline at end of file diff --git a/parser/common/register_tbe.cc b/parser/common/register_tbe.cc new file mode 100644 index 0000000..1b3f098 --- /dev/null +++ b/parser/common/register_tbe.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/common/register_tbe.h" +#include +#include +#include +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/op/ge_op_utils.h" +#include "common/op_map.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/type_utils.h" +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" +#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" + +namespace ge { +using PARSER_CREATOR_FN = std::function(void)>; + +FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() { + static OpRegistrationTbe instance; + return &instance; +} + +bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { + static std::map *> op_map = {{CAFFE, &caffe_op_map}}; + if (is_train) { + op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; + } else { + op_map[domi::TENSORFLOW] = &tensorflow_op_map; + } + + if (op_map.find(reg_data.GetFrameworkType()) != op_map.end()) { + std::map *fmk_op_map = op_map[reg_data.GetFrameworkType()]; + auto ori_optype_set = reg_data.GetOriginOpTypeSet(); + for (auto &tmp : ori_optype_set) { + if ((*fmk_op_map).find(tmp) != (*fmk_op_map).end()) { + GELOGW("Op type does not need to be changed, om_optype:%s, orignal type:%s.", (*fmk_op_map)[tmp].c_str(), + tmp.c_str()); + continue; + } else { + (*fmk_op_map)[tmp] = reg_data.GetOmOptype(); + GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(), + reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); + } + } + } + + bool ret = RegisterParser(reg_data); + return ret; +} + +bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { + if (reg_data.GetFrameworkType() == domi::TENSORFLOW) { + std::shared_ptr factory = OpParserFactory::Instance(domi::TENSORFLOW); + if (factory == nullptr) { + GELOGE(INTERNAL_ERROR, "Get op parser factory for tf failed."); + return false; + } + if (reg_data.GetParseParamFn() != nullptr || reg_data.GetParseParamByOperatorFn() != nullptr) { + bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype()); + if (is_registed) { + GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str()); + return false; + } + std::shared_ptr tf_parser_adapter = + ge::MakeShared(); + if (tf_parser_adapter == nullptr) { + GELOGE(PARAM_INVALID, "Create tf parser adapter failed."); + return false; + } + OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( + domi::TENSORFLOW, reg_data.GetOmOptype(), [=]() -> std::shared_ptr { return tf_parser_adapter; }); + } + if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) { + bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype(), true); + if (is_registed) { + GELOGW("Parse param func has already register for fusion op:%s.", reg_data.GetOmOptype().c_str()); + return false; + } + GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str()); + std::shared_ptr tf_fusion_parser_adapter = + ge::MakeShared(); + if (tf_fusion_parser_adapter == nullptr) { + GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed."); + return false; + } + OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( + domi::TENSORFLOW, reg_data.GetOmOptype(), + [=]() -> std::shared_ptr { return tf_fusion_parser_adapter; }, true); + } + } else { + 
std::shared_ptr factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); + if (factory == nullptr) { + GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.", + TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); + return false; + } + bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype()); + if (is_registed) { + GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str()); + return false; + } + + PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType()); + if (func == nullptr) { + GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.", + TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); + return false; + } + OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func); + GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(), + TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); + } + return true; +} +} // namespace ge diff --git a/parser/common/register_tbe.h b/parser/common/register_tbe.h new file mode 100644 index 0000000..7e2803c --- /dev/null +++ b/parser/common/register_tbe.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_REGISTER_TBE_H_ +#define PARSER_COMMON_REGISTER_TBE_H_ + +#include "register/op_registry.h" + +namespace ge { +class OpRegistrationTbe { + public: + static OpRegistrationTbe *Instance(); + + bool Finalize(const OpRegistrationData ®_data, bool is_train = false); + + private: + bool RegisterParser(const OpRegistrationData ®_data); +}; +} // namespace ge + +#endif // PARSER_COMMON_REGISTER_TBE_H_ \ No newline at end of file diff --git a/parser/common/tbe_plugin_loader.cc b/parser/common/tbe_plugin_loader.cc new file mode 100644 index 0000000..82c06eb --- /dev/null +++ b/parser/common/tbe_plugin_loader.cc @@ -0,0 +1,212 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tbe_plugin_loader.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/util/error_manager/error_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/string_util.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/utils/type_utils.h" +#include "parser/common/acl_graph_parser_util.h" + +namespace ge { +std::map TBEPluginLoader::options_ = {}; + +namespace { +const std::string FRAMEWORK_TYPE = "ge.frameworkType"; +} + +// Get Singleton Instance +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() { + static TBEPluginLoader instance_ptr_; + return instance_ptr_; +} + +Status TBEPluginLoader::ClearHandles_() { + Status ret = SUCCESS; + for (const auto &handle : handles_vec_) { + if (dlclose(handle) != 0) { + ret = FAILED; + GELOGW("Failed to close handle: %s", dlerror()); + } + } + handles_vec_.clear(); + return ret; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() { + Status ret = ClearHandles_(); + return ret; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo( + const std::map &options) { + vector file_list; + string caffe_parser_path; + std::string plugin_path; + + options_ = options; + GetCustomOpPath(plugin_path); + + // Whether there are files in the plugin so path + GetPluginSoFileList(plugin_path, file_list, caffe_parser_path); + + // No file + if (file_list.empty()) { + // Print log + GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str()); + } + + GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); + + // Load other so files except lib_caffe_parser.so in the plugin so path + for (auto elem : file_list) { + StringUtils::Trim(elem); + + void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); + if (handle == nullptr) { + GELOGW("dlopen failed, plugin name:%s. 
Message(%s).", elem.c_str(), dlerror()); + } else if (find(handles_vec_.begin(), handles_vec_.end(), handle) == handles_vec_.end()) { + // Close dl when the program exist, not close here + GELOGI("Plugin load %s success.", elem.c_str()); + handles_vec_.push_back(handle); + } else { + GELOGI("Plugin so has already been loaded, no need to load again."); + } + } +} + +void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { + GELOGI("Enter get custom op path schedule"); + std::string fmk_type; + domi::FrameworkType type = domi::TENSORFLOW; + auto it = options_.find(FRAMEWORK_TYPE); + if (it != options_.end()) { + type = static_cast(std::strtol(it->second.c_str(), nullptr, 10)); + } + fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); + GELOGI("Framework type is %s.", fmk_type.c_str()); + + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env != nullptr) { + std::string path = path_env; + customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); + GELOGI("Get custom so path from env : %s", path_env); + return; + } + std::string path_base = GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); +} + +string TBEPluginLoader::GetPath() { + Dl_info dl_info; + if (dladdr(reinterpret_cast(&TBEPluginLoader::GetPath), &dl_info) == 0) { + GELOGW("Failed to read so path!"); + return string(); + } else { + string so_path = dl_info.dli_fname; + char path[PATH_MAX] = {0}; + if (so_path.length() >= PATH_MAX) { + GELOGW("File path is too long!"); + return string(); + } + if (realpath(so_path.c_str(), path) == nullptr) { + GELOGW("Failed to get realpath of %s", so_path.c_str()); + return string(); + } + + so_path = path; + so_path = so_path.substr(0, so_path.rfind('/') + 1); + return so_path; + } +} + +void TBEPluginLoader::GetPluginSoFileList(const string &path, vector &file_list, string &caffe_parser_path) { + // Support to split multiple so directories by ":" + vector v_path = StringUtils::Split(path, ':'); + for (size_t i = 0; i < v_path.size(); ++i) { + FindParserSo(v_path[i], file_list, caffe_parser_path); + GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); + } +} + +void TBEPluginLoader::FindParserSo(const string &path, vector &file_list, string &caffe_parser_path) { + // Path, change to absolute path + string real_path = ge::parser::RealPath(path.c_str()); + // Plugin path does not exist + if (real_path.empty()) { + GELOGW("RealPath is empty."); + return; + } + struct stat stat_buf; + if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) { + GELOGW("%s is not a dir.", real_path.c_str()); + return; + } + struct dirent *dent(0); + DIR *dir = opendir(real_path.c_str()); + // Plugin path does not exist + if (dir == nullptr) { + GELOGW("Open directory %s failed.", real_path.c_str()); + return; + } + + while ((dent = readdir(dir)) != nullptr) { + if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; + string name = dent->d_name; + string full_name = real_path + "/" + name; + const string so_suff = ".so"; + const string caffe_parser_so_suff = "lib_caffe_parser.so"; + const string aicpu_so_suff = "_aicpu.so"; + const string aicpu_host_so_suff = "_online.so"; + if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), 
so_suff.size(), so_suff) == 0) { + ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, + aicpu_host_so_suff); + } else { + FindParserSo(full_name, file_list, caffe_parser_path); + } + } + closedir(dir); +} + +void TBEPluginLoader::ProcessSoFullName(vector &file_list, string &caffe_parser_path, string &full_name, + const string &caffe_parser_so_suff, const string &aicpu_so_suff, + const string &aicpu_host_so_suff) { + if (full_name.size() >= caffe_parser_so_suff.size() && + full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), + caffe_parser_so_suff) == 0) { + caffe_parser_path = full_name; + } else { + // Save parser so path into file_list vector + file_list.push_back(full_name); + } +} +} // namespace ge diff --git a/parser/common/tbe_plugin_loader.h b/parser/common/tbe_plugin_loader.h new file mode 100644 index 0000000..1cd6f6b --- /dev/null +++ b/parser/common/tbe_plugin_loader.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ +#define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "external/ge/ge_api_error_codes.h" +#include "external/register/register.h" + +namespace ge { +using SoHandlesVec = std::vector; +class TBEPluginLoader { +public: + Status Finalize(); + + // Get TBEPluginManager singleton instance + static TBEPluginLoader& Instance(); + + void LoadPluginSo(const std::map &options); + + static string GetPath(); + +private: + TBEPluginLoader() = default; + ~TBEPluginLoader() = default; + Status ClearHandles_(); + static void ProcessSoFullName(vector &file_list, string &caffe_parser_path, string &full_name, + const string &caffe_parser_so_suff, const string &aicpu_so_suff, + const string &aicpu_host_so_suff); + static void GetCustomOpPath(std::string &customop_path); + static void GetPluginSoFileList(const string &path, vector &file_list, string &caffe_parser_path); + static void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path); + + SoHandlesVec handles_vec_; + static std::map options_; +}; +} // namespace ge + +#endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_ diff --git a/parser/common/thread_pool.cc b/parser/common/thread_pool.cc new file mode 100644 index 0000000..dead012 --- /dev/null +++ b/parser/common/thread_pool.cc @@ -0,0 +1,78 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
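A minimal sketch of driving the loader declared above. The option key is the one read by GetCustomOpPath; the ASCEND_OPP_PATH value in the comment is only an example, and domi::TENSORFLOW is assumed to be visible through the register headers:

#include <map>
#include <string>
#include "parser/common/tbe_plugin_loader.h"

void LoadCustomOpPlugins() {
  std::map<std::string, std::string> options;
  // "ge.frameworkType" is parsed with strtol into a domi::FrameworkType in GetCustomOpPath().
  options["ge.frameworkType"] = std::to_string(static_cast<int>(domi::TENSORFLOW));

  // With ASCEND_OPP_PATH=/usr/local/Ascend/opp, the loader scans framework/custom/ and
  // framework/built-in/ under that directory, dlopen()s every *.so except
  // lib_caffe_parser.so, and keeps the handles it obtained.
  ge::TBEPluginLoader::Instance().LoadPluginSo(options);

  // ... parse models using the op parsers registered by the plugins ...

  (void)ge::TBEPluginLoader::Instance().Finalize();  // dlclose()s the collected handles
}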
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/thread_pool.h" + +#include +#include +#include +#include +#include +#include + +#include "register/register_types.h" + +namespace ge { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { + idle_thrd_num_ = size < 1 ? 1 : size; + + for (uint32_t i = 0; i < idle_thrd_num_; ++i) { + pool_.emplace_back(ThreadFunc, this); + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { + is_stoped_.store(true); + { + std::unique_lock lock{m_lock_}; + cond_var_.notify_all(); + } + + for (std::thread &thd : pool_) { + if (thd.joinable()) { + try { + thd.join(); + } catch (const std::system_error &) { + GELOGW("system_error"); + } catch (...) { + GELOGW("exception"); + } + } + } +} + +void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { + if (thread_pool == nullptr) { + return; + } + while (!thread_pool->is_stoped_) { + std::function task; + { + std::unique_lock lock{thread_pool->m_lock_}; + thread_pool->cond_var_.wait( + lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); + if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { + return; + } + task = std::move(thread_pool->tasks_.front()); + thread_pool->tasks_.pop(); + } + --thread_pool->idle_thrd_num_; + task(); + ++thread_pool->idle_thrd_num_; + } +} +} // namespace ge diff --git a/parser/common/thread_pool.h b/parser/common/thread_pool.h new file mode 100644 index 0000000..08f47e2 --- /dev/null +++ b/parser/common/thread_pool.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_COMMON_THREAD_POOL_H_ +#define PARSER_COMMON_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "external/ge/ge_api_error_codes.h" +#include "graph/types.h" +#include "common/ge/ge_util.h" + +namespace ge { +using ThreadTask = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { + public: + explicit ThreadPool(uint32_t size = 4); + ~ThreadPool(); + + template + auto commit(Func &&func, Args &&... 
args) -> std::future { + GELOGD("commit run task enter."); + using retType = decltype(func(args...)); + std::future fail_future; + if (is_stoped_.load()) { + GELOGE(ge::FAILED, "thread pool has been stopped."); + return fail_future; + } + + auto bindFunc = std::bind(std::forward(func), std::forward(args)...); + auto task = ge::MakeShared>(bindFunc); + if (task == nullptr) { + GELOGE(ge::FAILED, "Make shared failed."); + return fail_future; + } + std::future future = task->get_future(); + { + std::lock_guard lock{m_lock_}; + tasks_.emplace([task]() { (*task)(); }); + } + cond_var_.notify_one(); + GELOGD("commit run task end"); + return future; + } + + static void ThreadFunc(ThreadPool *thread_pool); + + private: + std::vector pool_; + std::queue tasks_; + std::mutex m_lock_; + std::condition_variable cond_var_; + std::atomic is_stoped_; + std::atomic idle_thrd_num_; +}; +} // namespace ge + +#endif // PARSER_COMMON_THREAD_POOL_H_ diff --git a/parser/common/tuple.h b/parser/common/tuple.h new file mode 100644 index 0000000..425f215 --- /dev/null +++ b/parser/common/tuple.h @@ -0,0 +1,307 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_TUPLE_H_ +#define GE_COMMON_TUPLE_H_ + +#include +#include +#include +#include +#include +#include +#include "framework/common/debug/log.h" + +namespace ge { +template +class Tuple { + public: + Tuple() = default; + inline ~Tuple() { + delete[] data_heap_; + data_heap_ = nullptr; + } + /// + /// @brief copy constructor from another tuple + /// @param s the source tuple + /// + inline Tuple(const Tuple &s) { this->assign(s.begin(), s.end()); } + /// + /// @brief constructor from initializer list + /// @param init the initializer_list + /// + inline Tuple(const std::initializer_list &init) { this->assign(init.begin(), init.end()); } + /// + /// @brief constructor from vector + /// @param init the vector + /// + inline Tuple(const std::vector &init) { // NOLINT(runtime/explicit) + this->assign(init.begin(), init.end()); + } + /// + /// @brief move constructor from Tuple + /// @param src the source shape + /// + inline Tuple(Tuple &&src) { // NOLINT(runtime/explicit) + this->swap(src); + } + /// + /// @brief construct the Tuple from content of iterator + /// @param begin the beginning of iterator + /// @param end end the end of the iterator + /// @tparam RandomAccessIterator iterator type + /// + template + inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) { + this->assign(begin, end); + } + /// + /// @brief Assign content to tuple from iterator. 
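A short usage sketch for the commit() template above; the task and pool size are arbitrary:

#include <future>
#include <iostream>
#include "common/thread_pool.h"

int main() {
  ge::ThreadPool pool(2);  // two worker threads

  // commit() wraps the callable in a std::packaged_task and hands back its future.
  std::future<int> result = pool.commit([](int a, int b) { return a + b; }, 40, 2);

  if (result.valid()) {  // an invalid future is returned once the pool has been stopped
    std::cout << result.get() << std::endl;  // prints 42
  }
  return 0;
}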
+ /// @param begin the beginning of iterator + /// @param end end the end of the iterator + /// @tparam RandomAccessIterator iterator type + /// + template + inline void assign(const RandomAccessIterator &begin, const RandomAccessIterator &end) { + this->SetDim(end - begin); + (void)std::copy(begin, end, this->begin()); + } + /// + /// @brief Swap current object with other + /// @param other another object to be swapped. + /// + inline void swap(Tuple &other) { // NOLINT(*) + std::swap(ndim_, other.ndim_); + std::swap(num_heap_allocated_, other.num_heap_allocated_); + std::swap(data_stack_, other.data_stack_); + std::swap(data_heap_, other.data_heap_); + } + /// + /// @brief assignment from another tuple. + /// @param src source tuple + /// @return reference of self + /// + inline Tuple &operator=(const Tuple &src) { + if (&src != this) { + this->assign(src.begin(), src.end()); + } + return *this; + } + /// + /// @brief assignment from rvalue of another tuple. + /// @param src source tuple + /// @return reference of self + /// + inline Tuple &operator=(Tuple &&src) { + if (&src != this) { + Tuple(std::move(src)).swap(*this); + } + return *this; + } + /// + /// @brief assignment from initializer list + /// @param init the source initializer list + /// @return reference of self + /// + inline Tuple &operator=(std::initializer_list init) { + this->assign(init.begin(), init.end()); + return *this; + } + /// + /// @return whether two tuple equals + /// @param s the tuple to compare against + /// + inline bool operator==(const Tuple &s) const { + if (ndim_ != s.ndim_) return false; + return std::equal(begin(), end(), s.begin()); + } + /// + /// @return whether two tuple not equal + /// @param s the tuple to compare against + /// + inline bool operator!=(const Tuple &s) const { return !(*this == s); } + /// + /// @return the begin data pointer to content of the tuple + /// + inline const ValueType *begin() const { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; } + /// + /// @return the begin data pointer to content of the tuple + /// + inline ValueType *begin() { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; } + /// + /// @return the data pointer to end of the tuple + /// + inline const ValueType *end() const { + return ndim_ <= STACK_CACHE_NUM ? (data_stack_ + ndim_) : (data_heap_ + ndim_); + } + /// + /// @return the data pointer to end the tuple + /// + inline ValueType *end() { return ndim_ <= STACK_CACHE_NUM ? 
(data_stack_ + ndim_) : (data_heap_ + ndim_); } + /// + /// @return number of dimension of the tuple + /// + inline uint32_t ndim() const { return ndim_; } + /// + /// @brief get corresponding index + /// @param i dimension index + /// @return the corresponding dimension size + /// + inline ValueType &operator[](size_t i) { return begin()[i]; } + /// + /// @brief get corresponding index + /// @param i dimension index + /// @return the corresponding dimension size + /// + inline const ValueType &operator[](size_t i) const { return begin()[i]; } + /// + /// @brief allow output string of tuple to ostream + /// @param os the output stream + /// @param t the tuple + /// @return the ostream + /// + friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + os << '['; + const ValueType *begin = t.begin(); + const ValueType *end = t.end(); + for (const ValueType *it = begin; it != end; ++it) { + if (it != begin) os << ','; + os << *it; + } + os << ']'; + return os; + } + /// + /// @brief read tuple from the istream + /// @param is the input stream + /// @param t The tuple + /// @return the istream + /// + friend std::istream &operator>>(std::istream &is, Tuple &t) { + // get ( + if (!HandleLeftBracket(is, t)) { + return is; + } + + // Handle empty tuple + while (isspace(is.peek())) { + (void)is.get(); + } + if (IsRightBracket(is.peek())) { + (void)is.get(); + return is; + } + // Handle non-empty tuple + ValueType idx; + std::vector tmp; + while (is >> idx) { + tmp.push_back(idx); + char ch; + do { + ch = static_cast(is.get()); + } while (isspace(ch)); + if (std::is_integral::value && ch == 'L') { + ch = static_cast(is.get()); + } + if (ch == ',') { + while (true) { + ch = static_cast(is.peek()); + if (isspace(ch)) { + (void)is.get(); + continue; + } + if (IsRightBracket(ch)) { + (void)is.get(); + break; + } + break; + } + if (IsRightBracket(ch)) break; + } else if (IsRightBracket(ch)) { + break; + } else { + is.setstate(std::ios::failbit); + return is; + } + } + t.assign(tmp.begin(), tmp.end()); + return is; + } + + // stack cache size + static const uint32_t STACK_CACHE_NUM = 4; + // in stack space used to store shape when it is small + ValueType data_stack_[STACK_CACHE_NUM]; + // space to store shape when dimension is big + ValueType *data_heap_{nullptr}; + uint32_t ndim_{0}; + + protected: + // number of cells allocated in data_heap_ + uint32_t num_heap_allocated_{0}; + + // internal function to change the dimension + inline void SetDim(uint32_t ndim) { + if (ndim > STACK_CACHE_NUM && ndim > num_heap_allocated_) { + if (data_heap_ != nullptr) { + delete[] data_heap_; + data_heap_ = nullptr; + } + data_heap_ = new (std::nothrow) ValueType[ndim](); + if (data_heap_ == nullptr) { + GELOGW("data_heap_ is nullptr."); + } + num_heap_allocated_ = ndim; + } + ndim_ = ndim; + } + static inline bool IsLeftBracket(char ch) { return ch == '(' || ch == '['; } + + static inline bool IsRightBracket(char ch) { return ch == ')' || ch == ']'; } + + friend bool HandleLeftBracket(std::istream &is, Tuple &t) { + while (true) { + char ch = is.peek(); + if (isdigit(ch) || (ch == '-')) { + ValueType idx; + if (is >> idx) { + t.assign(&idx, &idx + 1); + } + return false; + } + (void)is.get(); + if (IsLeftBracket(ch)) { + break; + } + + if (!isspace(ch)) { + is.setstate(std::ios::failbit); + return false; + } + } + + return true; + } +}; + +using UintTuple = Tuple; +using IntTuple = Tuple; +using FloatTuple = Tuple; +using BoolTuple = Tuple; +using StringTuple = Tuple; +} // namespace ge + +#endif // 
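A small usage sketch for the Tuple template above; the include path is assumed to be parser/common/tuple.h and the values are arbitrary:

#include <iostream>
#include <sstream>
#include "parser/common/tuple.h"

int main() {
  ge::IntTuple dims = {2, 3, 4};            // small tuples stay in the on-stack cache
  std::cout << dims << std::endl;           // prints [2,3,4]

  std::istringstream in("(8, 16, 32)");
  ge::IntTuple parsed;
  in >> parsed;                             // accepts "(...)" as well as "[...]"
  std::cout << parsed.ndim() << std::endl;  // prints 3
  return 0;
}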
GE_COMMON_TUPLE_H_ diff --git a/parser/common/types_map.h b/parser/common/types_map.h new file mode 100644 index 0000000..082d607 --- /dev/null +++ b/parser/common/types_map.h @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_TYPES_MAP_H +#define GE_TYPES_MAP_H + +#include "external/graph/types.h" +#include "proto/tensorflow/graph.pb.h" + +namespace ge { +// Correspondence between data_type in GE and tensorflow +static map GE_TENSORFLOW_DATA_TYPE_MAP = { + {ge::DataType::DT_UNDEFINED, domi::tensorflow::DT_INVALID}, + {ge::DataType::DT_FLOAT, domi::tensorflow::DT_FLOAT}, + {ge::DataType::DT_FLOAT16, domi::tensorflow::DT_HALF}, + {ge::DataType::DT_INT8, domi::tensorflow::DT_INT8}, + {ge::DataType::DT_INT16, domi::tensorflow::DT_INT16}, + {ge::DataType::DT_UINT16, domi::tensorflow::DT_UINT16}, + {ge::DataType::DT_UINT8, domi::tensorflow::DT_UINT8}, + {ge::DataType::DT_INT32, domi::tensorflow::DT_INT32}, + {ge::DataType::DT_INT64, domi::tensorflow::DT_INT64}, + {ge::DataType::DT_UINT32, domi::tensorflow::DT_UINT32}, + {ge::DataType::DT_UINT64, domi::tensorflow::DT_UINT64}, + {ge::DataType::DT_STRING, domi::tensorflow::DT_STRING}, + {ge::DataType::DT_RESOURCE, domi::tensorflow::DT_RESOURCE}, + {ge::DataType::DT_BOOL, domi::tensorflow::DT_BOOL}, + {ge::DataType::DT_DOUBLE, domi::tensorflow::DT_DOUBLE}, + {ge::DataType::DT_COMPLEX64, domi::tensorflow::DT_COMPLEX64}, + {ge::DataType::DT_COMPLEX128, domi::tensorflow::DT_COMPLEX128}, + {ge::DataType::DT_QINT8, domi::tensorflow::DT_QINT8}, + {ge::DataType::DT_QINT16, domi::tensorflow::DT_QINT16}, + {ge::DataType::DT_QINT32, domi::tensorflow::DT_QINT32}, + {ge::DataType::DT_QUINT8, domi::tensorflow::DT_QUINT8}, + {ge::DataType::DT_QUINT16, domi::tensorflow::DT_QUINT16}, + {ge::DataType::DT_DUAL, domi::tensorflow::DT_INVALID}, + {ge::DataType::DT_DUAL_SUB_INT8, domi::tensorflow::DT_INVALID}, + {ge::DataType::DT_DUAL_SUB_UINT8, domi::tensorflow::DT_INVALID}, +}; +} // namespace ge +#endif // GE_TYPES_MAP_H diff --git a/parser/func_to_graph/CMakeLists.txt b/parser/func_to_graph/CMakeLists.txt new file mode 100644 index 0000000..ee83229 --- /dev/null +++ b/parser/func_to_graph/CMakeLists.txt @@ -0,0 +1,32 @@ +set(PROTO_LIST + "${TOP_DIR}/inc/register/proto/tensorflow/graph.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/node_def.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/tensor_shape.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/attr_value.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/function.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/op_def.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/resource_handle.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/tensor.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/types.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/versions.proto" + "${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto" +) + +protobuf_generate_py(ge PROTO_SRCS ${PROTO_LIST}) + 
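Returning to types_map.h above, a hedged sketch of how the GE-to-TensorFlow dtype table might be consulted; the helper function is not part of this diff and the include path is an assumption:

#include "parser/common/types_map.h"  // include path assumed

// Returns the TensorFlow dtype for a GE dtype, or DT_INVALID when unmapped.
inline domi::tensorflow::DataType ToTensorFlowType(ge::DataType ge_type) {
  const auto it = ge::GE_TENSORFLOW_DATA_TYPE_MAP.find(ge_type);
  return (it == ge::GE_TENSORFLOW_DATA_TYPE_MAP.end()) ? domi::tensorflow::DT_INVALID : it->second;
}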
+include_directories(${CMAKE_CURRENT_LIST_DIR}) + +############ func2graph/util ############ +add_custom_target(util ALL + DEPENDS ${PROTO_SRCS} + COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/util + && cp -r ${PROTO_SRCS} ${CMAKE_CURRENT_BINARY_DIR}/util + ) + +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/util OPTIONAL + DESTINATION ${INSTALL_LIBRARY_DIR}/func2graph +) + diff --git a/parser/func_to_graph/func2graph.py b/parser/func_to_graph/func2graph.py new file mode 100644 index 0000000..fed7f19 --- /dev/null +++ b/parser/func_to_graph/func2graph.py @@ -0,0 +1,279 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +#!/usr/bin/env python +# -*- coding:utf-8 -*- + +import os +import sys +import getopt + +from google.protobuf import text_format +import tensorflow as tf +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework.errors_impl import NotFoundError +from tensorflow.python.platform import gfile + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework import versions_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import versions + +sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util")) + +import graph_library_pb2 + + +def _get_num_args(arg_def, node_def): + if arg_def.number_attr: + return node_def.attr[arg_def.number_attr].i + elif arg_def.type_list_attr: + return len(node_def.attr[arg_def.type_list_attr].list.type) + elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: + return 1 + else: + raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) + + +def is_function(fname): + """Checks for a function definition with `fname` in the current context.""" + if context.executing_eagerly(): + return context.context().has_function(fname) + else: + return ops.get_default_graph()._is_function(fname) + +def create_arg_for_input_nodes(fdef, graph_def, input_shapes): + for i, arg_def in enumerate(fdef.signature.input_arg): + node_def = graph_def.node.add() + node_def.name = arg_def.name + node_def.op = "_Arg" + node_def.attr["T"].type = arg_def.type + node_def.attr["index"].i = i + if input_shapes and input_shapes[i] is not None: + input_shape = input_shapes[i] + if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto): + input_shape = input_shape.as_proto() + node_def.attr["shape"].shape.CopyFrom(input_shape) + arg_attrs = fdef.arg_attr[i].attr + for k in arg_attrs: + # Only copy internal attributes. Normal attributes for nodes cannot be + # applied to these Arg nodes. 
+ if k.startswith("_"): + node_def.attr[k].CopyFrom(arg_attrs[k]) + return + +def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name): + for i, arg_def in enumerate(fdef.signature.output_arg): + node_def = graph_def.node.add() + node_def.name = '{}_Retval'.format(arg_def.name) + node_def.op = "_Retval" + node_def.attr["T"].type = arg_def.type + node_def.attr["index"].i = i + node_def.attr["op_def"].s = ops.get_default_graph()._get_op_def(node_def.op).SerializeToString() + + ret_name = fdef.ret[arg_def.name] + node_def.input.append(nested_to_flat_tensor_name[ret_name]) + return + +def updat_input_index(node_def, op_def, nested_to_flat_tensor_name): + flattened_index = 0 + for arg_def in op_def.output_arg: + num_args = _get_num_args(arg_def, node_def) + for i in range(num_args): + # Map tensor names from "node_name:output_arg_name:index" to + # "node_name:flattened_index". + nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) + if flattened_index == 0: + flat_name = node_def.name + else: + flat_name = "{}:{}".format(node_def.name, flattened_index) + nested_to_flat_tensor_name[nested_name] = flat_name + flattened_index += 1 + control_name = "^" + node_def.name + nested_to_flat_tensor_name[control_name] = control_name + return + +def build_tensor_name(fdef, default_graph): + nested_to_flat_tensor_name = {} + for arg_def in fdef.signature.input_arg: + nested_to_flat_tensor_name[arg_def.name] = arg_def.name + control_name = '^{}'.format(arg_def.name) + nested_to_flat_tensor_name[control_name] = control_name + + global op_def + for node_def in fdef.node_def: + f = default_graph._functions.get(node_def.op, None) + if f is not None and hasattr(f, "signature"): + op_def = f.signature + if node_def.op not in copied_functions: + # Since this function is referenced as an op type, we have no choice but + # to copy it into the GraphDef if we want downstream tools to process + # it. + graph_def.library.function.add().CopyFrom(f.definition) + copied_functions.add(node_def.op) + else: + op_def = ops.get_default_graph()._get_op_def(node_def.op) + + for attr in op_def.attr: + if attr.type == "func": + fname = node_def.attr[attr.name].func.name + if not is_function(fname): + raise ValueError("%s function not found." % fname) + elif attr.type == "list(func)": + for fn in node_def.attr[attr.name].list.func: + fname = fn.name + if not is_function(fname): + raise ValueError("%s function not found." % fname) + + # Iterate over output_args in op_def to build the map. + # Index of the output tensor in the flattened list of *all* output + # tensors of the op. + updat_input_index(node_def, op_def, nested_to_flat_tensor_name) + return nested_to_flat_tensor_name + +def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True): + graph_def = graph_pb2.GraphDef() + graph_def.versions.CopyFrom( + versions_pb2.VersionDef( + producer=versions.GRAPH_DEF_VERSION, + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) + + default_graph = ops.get_default_graph() + + copied_functions = set() + + # Copy *all* functions from outer graph to `graph_def` so that both direct + # and indirect references are safely handled. + if copy_functions: + default_graph._copy_functions_to_graph_def(graph_def, 0) + for function_name in default_graph._functions.keys(): + copied_functions.add(function_name) + + if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): + raise ValueError("Length of input_shapes must match the number of " + + "input_args. 
len(input_shapes): {} len(input_arg): {}". + format(len(input_shapes), len(fdef.signature.input_arg))) + + # 1. Create _Arg for input nodes. + create_arg_for_input_nodes(fdef, graph_def, input_shapes) + + # 2. Copy all body NodeDefs to the GraphDef. + graph_def.node.extend(fdef.node_def) + + # 3. Perform the renaming. + + # Build the tensor name mapping then flatten the tensor names. + # See comment on `FunctionDef.node_def` on how the tensor naming in + # FunctionDefs is different from GraphDefs. + nested_to_flat_tensor_name = build_tensor_name(fdef, default_graph) + + # Update inputs of all nodes in graph. + for node_def in graph_def.node: + for i in range(len(node_def.input)): + node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] + + # Create _Retval for output nodes. + create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name) + + return graph_def, nested_to_flat_tensor_name + + +def convert_graphs(filename): + try: + with tf.io.gfile.GFile(filename, 'rb') as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + if len(graph_def.library.function) == 0: + print("INFO: The input model does not contain a functionDef and does not require conversion.") + return + try: + convert_subgraphs(graph_def, filename) + except Exception as e: + print("ERROR: Convert subgraphs failed.", e) + return + print("INFO: Convert to subgraphs successfully.") + except NotFoundError: + print('ERROR: model file {} does not exist'.format(filename)) + return + + +def convert_subgraphs(graph_def, filename): + graph_def_library = graph_library_pb2.GraphDefLibrary() + for i, fdef in enumerate(graph_def.library.function): + sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False) + print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name)) + sub_graph_name = '{}.pb'.format(fdef.signature.name) + result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename))) + tf.io.write_graph(sub_graph, result_path, sub_graph_name, as_text=False) + data = sub_graph.SerializeToString() + ge_graph_def = graph_library_pb2.GeGraphDef() + ge_graph_def.name = fdef.signature.name + ge_graph_def.graph.ParseFromString(data) + graph_def_library.graph_def.append(ge_graph_def) + print(graph_def_library.graph_def[i]) + + # Write to prototxt + try: + graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename))) + print("graph_def_file: ", graph_def_file) + with open(graph_def_file, "w") as f: + print(graph_def_library, file=f) + except IOError: + print("Could not open file. Creating a new one.") + + +def usage(): + print( + ''' + Based on tensorflow 1.15 or later, Python 3 + + Convert the tensorflow functionDefs in the input model file to single GraphDefs, + and save the result to the "results" directory and graph_def_library.pbtxt in + the input file directory. + The name of the sub graph is same as the name of the corresponding functionDef. + + Usage: func2grpah.py + + Available commands: + model (-m) Input model file. + version (-v) Prints the version of this software. + help (-h) Prints help for commands. 
+ ''' + ) + + +if __name__ == '__main__': + model = '' + try: + opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model=']) + for opt_name, opt_value in opts: + if opt_name in ('-m', '--model'): + model = opt_value + print("INFO: Input model file is", model) + convert_graphs(model) + elif opt_name in ('-h', '--help'): + usage() + break + elif opt_name in ('-v', '--version'): + print("version 1.0.0") + break + except getopt.GetoptError: + print("ERROR: Input parameters is invalid, use '--help' to view the help.") + if (len(sys.argv) == 1): + print("INFO: Please specify the input parameters, and use '--help' to view the help.") diff --git a/parser/func_to_graph/module.mk b/parser/func_to_graph/module.mk new file mode 100644 index 0000000..6c89342 --- /dev/null +++ b/parser/func_to_graph/module.mk @@ -0,0 +1,9 @@ +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +LOCAL_MODULE := func2graph/util + +LOCAL_MODULE_CLASS := FOLDER + +include $(LOCAL_PATH)/proto_python_rule.mk \ No newline at end of file diff --git a/parser/func_to_graph/proto/attr_value.proto b/parser/func_to_graph/proto/attr_value.proto new file mode 100644 index 0000000..1cc67d6 --- /dev/null +++ b/parser/func_to_graph/proto/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
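For completeness, a sketch of reading back the graph_def_library.pbtxt written by the tool above from C++, using the classes generated from graph_library.proto; the generated header path is an assumption:

#include <fstream>
#include <iostream>
#include <sstream>
#include "google/protobuf/text_format.h"
#include "proto/tensorflow/graph_library.pb.h"  // generated header, path assumed

int main() {
  std::ifstream in("graph_def_library.pbtxt");
  std::stringstream buffer;
  buffer << in.rdbuf();

  domi::tensorflow::GraphDefLibrary library;
  if (!google::protobuf::TextFormat::ParseFromString(buffer.str(), &library)) {
    std::cerr << "Failed to parse graph_def_library.pbtxt" << std::endl;
    return -1;
  }
  for (const auto &ge_graph : library.graph_def()) {
    // Each entry pairs the FunctionDef name with the converted GraphDef.
    std::cout << ge_graph.name() << ": " << ge_graph.graph().node_size() << " nodes" << std::endl;
  }
  return 0;
}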
+message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/parser/func_to_graph/proto/function.proto b/parser/func_to_graph/proto/function.proto new file mode 100644 index 0000000..075897c --- /dev/null +++ b/parser/func_to_graph/proto/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. + reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. 
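To make the "node:out:index" naming rules described above concrete, a small sketch using the generated C++ classes; the op and argument names are invented and the generated header path is an assumption:

#include "proto/tensorflow/function.pb.h"  // generated header, path assumed

// Builds a tiny FunctionDef body to illustrate FunctionDef-style tensor names.
domi::tensorflow::FunctionDef MakeExample() {
  domi::tensorflow::FunctionDef fdef;
  fdef.mutable_signature()->set_name("Example");

  domi::tensorflow::NodeDef *split = fdef.add_node_def();
  split->set_name("split");
  split->set_op("SplitV");
  split->add_input("x");      // "x" names a function input arg from the signature
  split->add_input("^init");  // control dependency on a node named "init"

  // Return the second output of node "split": "split:output:1" in FunctionDef naming;
  // after flattening (see func2graph.py above) the same tensor is written "split:1".
  (*fdef.mutable_ret())["y"] = "split:output:1";
  return fdef;
}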
+// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. +} diff --git a/parser/func_to_graph/proto/graph.proto b/parser/func_to_graph/proto/graph.proto new file mode 100644 index 0000000..d639a7d --- /dev/null +++ b/parser/func_to_graph/proto/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. 
+ FunctionDefLibrary library = 2; +}; diff --git a/parser/func_to_graph/proto/graph_library.proto b/parser/func_to_graph/proto/graph_library.proto new file mode 100644 index 0000000..e393d38 --- /dev/null +++ b/parser/func_to_graph/proto/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/parser/func_to_graph/proto/node_def.proto b/parser/func_to_graph/proto/node_def.proto new file mode 100644 index 0000000..b9bc97e --- /dev/null +++ b/parser/func_to_graph/proto/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. 
+ map attr = 5; +}; diff --git a/parser/func_to_graph/proto/op_def.proto b/parser/func_to_graph/proto/op_def.proto new file mode 100644 index 0000000..3485d04 --- /dev/null +++ b/parser/func_to_graph/proto/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). 
+ // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. 
+ string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/parser/func_to_graph/proto/resource_handle.proto b/parser/func_to_graph/proto/resource_handle.proto new file mode 100644 index 0000000..a345235 --- /dev/null +++ b/parser/func_to_graph/proto/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/parser/func_to_graph/proto/tensor.proto b/parser/func_to_graph/proto/tensor.proto new file mode 100644 index 0000000..d0a4d02 --- /dev/null +++ b/parser/func_to_graph/proto/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. 
scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/parser/func_to_graph/proto/tensor_shape.proto b/parser/func_to_graph/proto/tensor_shape.proto new file mode 100644 index 0000000..4225a2e --- /dev/null +++ b/parser/func_to_graph/proto/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/parser/func_to_graph/proto/types.proto b/parser/func_to_graph/proto/types.proto new file mode 100644 index 0000000..ba7a72b --- /dev/null +++ b/parser/func_to_graph/proto/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. 
+ DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). + DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/parser/func_to_graph/proto/versions.proto b/parser/func_to_graph/proto/versions.proto new file mode 100644 index 0000000..4806121 --- /dev/null +++ b/parser/func_to_graph/proto/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). 
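+  // Illustrative note added by the editor (values are hypothetical): with
+  //   producer = 24, min_consumer = 12, bad_consumers = [19]
+  // a consumer at version 19 must reject this data even though 19 >= 12,
+  // while consumers at other versions >= 12 may accept it (provided their
+  // own min_producer check against producer = 24 also passes).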
+ repeated int32 bad_consumers = 3; +}; diff --git a/parser/func_to_graph/proto_python_rule.mk b/parser/func_to_graph/proto_python_rule.mk new file mode 100644 index 0000000..0749574 --- /dev/null +++ b/parser/func_to_graph/proto_python_rule.mk @@ -0,0 +1,17 @@ +include $(BUILD_SYSTEM)/base_rules.mk + +FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp + +PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto +PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto + +$(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) +$(warning protobuf_lib_dir is $(protobuf_lib_dir)) + +$(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC) + mkdir -p $(PY_PROTO_BUILD_DIR) + LD_LIBRARY_PATH=$(protobuf_lib_dir):$$LD_LIBRARY_PATH $(PRIVATE_PROTOC) -I=$(PROTO_SRC_DIR) --python_out=$(PY_PROTO_BUILD_DIR) $(PROTO_SRC_DIR)/*.proto + +$(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) + mkdir -p $@ + cp -rf $(PY_PROTO_BUILD_DIR)/* $@ \ No newline at end of file diff --git a/parser/module.mk b/parser/module.mk new file mode 100644 index 0000000..319e7a6 --- /dev/null +++ b/parser/module.mk @@ -0,0 +1,143 @@ + +LOCAL_PATH := $(call my-dir) +include $(LOCAL_PATH)/../stub/Makefile +COMMON_LOCAL_C_INCLUDES := \ + proto/om.proto \ + proto/insert_op.proto \ + proto/ge_ir.proto \ + proto/task.proto \ + proto/tensorflow/graph.proto \ + proto/tensorflow/node_def.proto \ + proto/tensorflow/tensor_shape.proto \ + proto/tensorflow/attr_value.proto \ + proto/tensorflow/function.proto \ + proto/tensorflow/op_def.proto \ + proto/tensorflow/resource_handle.proto \ + proto/tensorflow/tensor.proto \ + proto/tensorflow/types.proto \ + proto/tensorflow/versions.proto \ + proto/tensorflow/graph_library.proto \ + proto/caffe/caffe.proto \ + tensorflow/proto/tensorflow/graph.proto \ + tensorflow/proto/tensorflow/node_def.proto \ + tensorflow/proto/tensorflow/tensor_shape.proto \ + tensorflow/proto/tensorflow/attr_value.proto \ + tensorflow/proto/tensorflow/function.proto \ + tensorflow/proto/tensorflow/op_def.proto \ + tensorflow/proto/tensorflow/resource_handle.proto \ + tensorflow/proto/tensorflow/tensor.proto \ + tensorflow/proto/tensorflow/types.proto \ + tensorflow/proto/tensorflow/versions.proto \ + tensorflow/proto/tensorflow/graph_library.proto \ + caffe/proto/caffe/caffe.proto \ + $(LOCAL_PATH) \ + $(LOCAL_PATH)/tensorflow \ + $(LOCAL_PATH)/caffe \ + $(LOCAL_PATH)/../ \ + $(TOPDIR)inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/external/parser \ + $(TOPDIR)inc/framework \ + $(TOPDIR)framework/domi/parser \ + libc_sec/include \ + third_party/protobuf/include \ + third_party/json/include \ + third_party/openssl/include/x86/include \ + +include $(CLEAR_VARS) + +LOCAL_MODULE := libfmk_parser + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 +LOCAL_CFLAGS += -Werror +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +PARSER_TENSORFLOW_SRC_FILES := \ + tensorflow/tensorflow_arg_parser.cc \ + tensorflow/tensorflow_auto_mapping_parser_adapter.cc \ + tensorflow/tensorflow_constant_parser.cc \ + tensorflow/tensorflow_data_parser.cc \ + tensorflow/tensorflow_enter_parser.cc \ + tensorflow/tensorflow_fill_parser.cc \ + tensorflow/tensorflow_frameworkop_parser.cc \ + tensorflow/tensorflow_fusionop_util.cc \ + tensorflow/tensorflow_identity_parser.cc \ + tensorflow/tensorflow_merge_parser.cc \ + tensorflow/tensorflow_no_op_parser.cc \ + tensorflow/tensorflow_parser.cc \ + tensorflow/tensorflow_ref_switch_parser.cc \ + tensorflow/tensorflow_reshape_parser.cc \ + 
tensorflow/tensorflow_shape_n_parser.cc \ + tensorflow/tensorflow_squeeze_parser.cc \ + tensorflow/tensorflow_var_is_initialized_op_parser.cc \ + tensorflow/tensorflow_variable_v2_parser.cc \ + tensorflow/proto/tensorflow/graph_library.proto \ + caffe/caffe_parser.cc \ + caffe/caffe_data_parser.cc \ + caffe/caffe_reshape_parser.cc \ + caffe/caffe_custom_parser_adapter.cc \ + caffe/caffe_op_parser.cc \ + +PARSER_SCOPE_SRC_FILES := \ + tensorflow/scope/scope_pass_manager.cc \ + +FMK_COMMON_SRC_FILES := \ + tensorflow/graph_functiondef.cc \ + tensorflow/graph_optimizer.cc \ + tensorflow/iterator_fusion_pass.cc \ + common/op_def/arg_op.cc \ + common/op_def/constant_op.cc \ + common/op_def/fill_op.cc \ + common/op_def/frameworkop_op.cc \ + common/op_def/no_op_op.cc \ + common/op_def/ref_switch_op.cc \ + common/op_def/shape_n_op.cc \ + common/op_def/var_is_initialized_op_op.cc \ + common/op_def/variable_op.cc \ + +LOCAL_SRC_FILES := $(PARSER_TENSORFLOW_SRC_FILES) +LOCAL_SRC_FILES += $(PARSER_SCOPE_SRC_FILES) +LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) + +LOCAL_SHARED_LIBRARIES := \ + libprotobuf \ + libslog \ + libmmpa \ + libc_sec \ + liberror_manager \ + libparser_common \ + libgraph \ + libregister \ + lib_caffe_parser \ + +LOCAL_LDFLAGS := -lrt + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host parser +include $(CLEAR_VARS) + +LOCAL_MODULE := stub/libfmk_parser + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 +LOCAL_CFLAGS += -DFMK_HOST_INFER -DFMK_SUPPORT_DUMP +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) + +LOCAL_SRC_FILES := ../../../out/ge/lib64/stub/tensorflow_parser.cc +LOCAL_SRC_FILES += ../../../out/ge/lib64/stub/caffe_parser.cc + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/parser/onnx/CMakeLists.txt b/parser/onnx/CMakeLists.txt new file mode 100644 index 0000000..db85e08 --- /dev/null +++ b/parser/onnx/CMakeLists.txt @@ -0,0 +1,61 @@ +set(PROTO_LIST + "${TOP_DIR}/inc/register/proto/onnx/ge_onnx.proto" + "${TOP_DIR}/inc/common/proto/om.proto" +) + +set(SRC_LIST + "onnx_custom_parser_adapter.cc" + "onnx_parser.cc" + "onnx_data_parser.cc" + "onnx_util.cc" + "onnx_constant_parser.cc" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +############ libfmk_onnx_parser.so ############ +add_library(fmk_onnx_parser SHARED ${SRC_LIST} ${PROTO_HDRS}) + +target_compile_options(fmk_onnx_parser PRIVATE + -Werror +) + +target_compile_definitions(fmk_onnx_parser PRIVATE + PROTOBUF_INLINE_NOT_IN_HEADERS=0 +) + +target_include_directories(fmk_onnx_parser PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${TOP_DIR}/framework/domi/parser + ${TOP_DIR}/framework/domi + ${TOP_DIR}/inc + ${TOP_DIR}/inc/external + ${TOP_DIR}/inc/external/graph + ${TOP_DIR}/inc/framework + ${CMAKE_BINARY_DIR} + ${CMAKE_BINARY_DIR}/proto/ge +) + +target_link_libraries(fmk_onnx_parser PRIVATE + $ + -Wl,--no-as-needed + protobuf + ge_common + register + c_sec + parser_common + graph + slog + mmpa + -Wl,--as-needed + json + -lrt +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS fmk_onnx_parser OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/parser/onnx/module.mk b/parser/onnx/module.mk new file mode 100644 index 0000000..aee731f --- /dev/null +++ b/parser/onnx/module.mk @@ -0,0 +1,50 @@ + +LOCAL_PATH := $(call my-dir) + +include 
$(CLEAR_VARS) + +LOCAL_MODULE := libfmk_onnx_parser + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 +LOCAL_CFLAGS += -Werror +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +PARSER_ONNX_SRC_FILES := \ + onnx_custom_parser_adapter.cc \ + onnx_parser.cc \ + onnx_data_parser.cc \ + onnx_util.cc \ + onnx_constant_parser.cc \ + proto/onnx/ge_onnx.proto \ + proto/om.proto \ + +LOCAL_SRC_FILES := $(PARSER_ONNX_SRC_FILES) + +LOCAL_C_INCLUDES := \ + $(LOCAL_PATH) \ + $(LOCAL_PATH)/../../ \ + $(TOPDIR)inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/framework \ + $(TOPDIR)framework/domi/parser \ + libc_sec/include \ + third_party/protobuf/include \ + third_party/json/include \ + third_party/openssl/include/x86/include \ + +LOCAL_SHARED_LIBRARIES := \ + libprotobuf \ + libslog \ + libmmpa \ + libc_sec \ + libparser_common \ + libgraph \ + libregister \ + libge_common \ + +LOCAL_LDFLAGS := -lrt + +include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/parser/onnx/onnx_constant_parser.cc b/parser/onnx/onnx_constant_parser.cc new file mode 100644 index 0000000..55393aa --- /dev/null +++ b/parser/onnx/onnx_constant_parser.cc @@ -0,0 +1,213 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "onnx_constant_parser.h" +#include +#include +#include "common/ge/ge_util.h" +#include "common/util.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/ge_tensor.h" +#include "graph/utils/tensor_adapter.h" +#include "parser/common/op_parser_factory.h" +#include "parser/onnx/onnx_util.h" + +using ge::onnx::NodeProto; +using ge::onnx::TensorProto; +using domi::ONNX; +using GeShape = ge::GeShape; +using GeTensorDesc = ge::GeTensorDesc; +using namespace ge::parser; + +namespace ge { +Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { + int64_t data_type = tensor_proto.data_type(); + if (ge::OnnxUtil::ConvertOnnxDataType(data_type) == ge::DataType::DT_UNDEFINED) { + GELOGE(FAILED, "data_type %ld not support.", data_type); + return FAILED; + } + + if (count == 0) { + GELOGI("At least one dim equals zero, result in the count equal to zero."); + return SUCCESS; + } + + std::map datatype_val_size_map = { + {OnnxDataType::INT32, tensor_proto.int32_data_size()}, + {OnnxDataType::INT64, tensor_proto.int64_data_size()}, + {OnnxDataType::STRING, tensor_proto.string_data_size()}, + {OnnxDataType::FLOAT, tensor_proto.float_data_size()}, + {OnnxDataType::DOUBLE, tensor_proto.double_data_size()}, + {OnnxDataType::UINT64, tensor_proto.uint64_data_size()}, + {OnnxDataType::UINT8, 0}, + {OnnxDataType::INT8, 0}, + {OnnxDataType::UINT16, 0}, + {OnnxDataType::INT16, 0}, + {OnnxDataType::BOOL, 0}, + {OnnxDataType::FLOAT16, 0}, + {OnnxDataType::UINT32, 0}, + {OnnxDataType::COMPLEX64, 0}, + {OnnxDataType::COMPLEX128, 0}, + {OnnxDataType::BFLOAT16, 0}, + }; + + int32_t datatype_val_size = 0; + auto iter = datatype_val_size_map.find(data_type); + if (iter != datatype_val_size_map.end()) { + datatype_val_size = iter->second; + } else { + GELOGE(domi::PARAM_INVALID, "data_type %ld not support.", data_type); + return FAILED; + } + + // find raw data + if (datatype_val_size == 0) { + if (tensor_proto.raw_data().empty()) { + GELOGE(domi::PARAM_INVALID, "tensor_proto has no data() elements or raw_data()"); + return FAILED; + } + + if (data_type == OnnxDataType::STRING) { + tensor.SetData(tensor_proto.raw_data()); + } else { + tensor.SetData(reinterpret_cast(tensor_proto.raw_data().c_str()), + tensor_proto.raw_data().size()); + } + GELOGD("Raw data size is : %zu", tensor_proto.raw_data().size()); + return SUCCESS; + } + + // find _data() elements + ParseConvertDataElements(tensor_proto, tensor, count, data_type); + return SUCCESS; +} + +void OnnxConstantParser::ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, + int count, int64_t data_type) { + switch (data_type) { + case OnnxDataType::INT32: + (void)SetTensorData(tensor_proto.int32_data_size(), tensor_proto.int32_data(), count, tensor); + break; + case OnnxDataType::INT64: + (void)SetTensorData(tensor_proto.int64_data_size(), tensor_proto.int64_data(), count, tensor); + break; + case OnnxDataType::STRING: { + std::vector data; + for (auto str_data : tensor_proto.string_data()) { + data.emplace_back(str_data); + } + tensor.SetData(data); + break; + } + case OnnxDataType::FLOAT: + (void)SetTensorData(tensor_proto.float_data_size(), tensor_proto.float_data(), count, tensor); + break; + case OnnxDataType::DOUBLE: + (void)SetTensorData(tensor_proto.double_data_size(), tensor_proto.double_data(), count, tensor); + break; + case OnnxDataType::UINT64: + (void)SetTensorData(tensor_proto.uint64_data_size(), tensor_proto.uint64_data(), count, 
tensor); + break; + default: + break; + } +} + +Status OnnxConstantParser::ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) { + // convert shape and format + std::vector tmp_shape; + int count = 1; + for (int i = 0; i < tensor_proto.dims_size(); i++) { + tmp_shape.push_back(tensor_proto.dims(i)); + int64_t dim = tmp_shape[i]; + // support weights shape [0],have no weights + if (dim < 0 || (count != 0 && (dim >= INT64_MAX / count))) { + GELOGE(FAILED, "Dim size is invalid, dim is less than zero or dim size exceeds INT64_MAX."); + return FAILED; + } + count *= dim; + }; + TensorDesc tensor_desc = tensor.GetTensorDesc(); + tensor_desc.SetShape(ge::Shape(tmp_shape)); + tensor_desc.SetFormat(static_cast(GetParserContext().format)); + tensor.SetTensorDesc(tensor_desc); + + // set data + if (ParseConvertData(tensor_proto, tensor, count) != SUCCESS) { + GELOGE(FAILED, "Convert ge tensor data and format failed."); + return FAILED; + } + return SUCCESS; +} + +Status OnnxConstantParser::ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor) { + int64_t data_type = tensor_proto.data_type(); + ge::DataType type = ge::OnnxUtil::ConvertOnnxDataType(data_type); + if (type == ge::DataType::DT_UNDEFINED) { + GELOGE(domi::PARAM_INVALID, "tensor_proto date type %ld is undefined.", data_type); + return FAILED; + } + + TensorDesc tensor_desc = tensor.GetTensorDesc(); + tensor_desc.SetDataType(ge::DataType(type)); + tensor.SetTensorDesc(tensor_desc); + return SUCCESS; +} + +Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def) { + GE_CHECK_NOTNULL(op_src); + const NodeProto *node = reinterpret_cast(op_src); + + // Get const Tensor from node + Tensor tensor; + for (auto it : node->attribute()) { + if (it.name() != ge::kAttrNameValue) { + continue; + } + const ::ge::onnx::TensorProto it_tensor = it.t(); + if (ParseConvertDataType(it_tensor, tensor) != SUCCESS) { + GELOGE(FAILED, "Convert ge tensor date type failed, attribute name is %s.", it.name().c_str()); + return FAILED; + } + + if (ParseConvertTensor(it_tensor, tensor) != SUCCESS) { + GELOGE(FAILED, "Convert ge tensor shape and format failed, attribute name is %s.", it.name().c_str()); + return FAILED; + } + } + + op_def.SetAttr(ge::kAttrNameValue, tensor); + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def); + op_def.UpdateOutputDesc(op_desc->GetOutputNameByIndex(0), tensor.GetTensorDesc()); + + return SUCCESS; +} + +Status OnnxConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { + GE_CHECK_NOTNULL(op_src); + const ge::onnx::NodeProto *node = reinterpret_cast(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); + + if (ParseConstFromInput(node, op_def) != SUCCESS) { + GELOGE(FAILED, "Parse constant node %s failed", node->name().c_str()); + return FAILED; + } + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(ONNX, CONSTANT, OnnxConstantParser); +} // namespace ge diff --git a/parser/onnx/onnx_constant_parser.h b/parser/onnx/onnx_constant_parser.h new file mode 100644 index 0000000..cfaafac --- /dev/null +++ b/parser/onnx/onnx_constant_parser.h @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_ONNX_CONSTANT_PARSER_H_ +#define GE_PARSER_ONNX_ONNX_CONSTANT_PARSER_H_ + +#include +#include "parser/common/data_op_parser.h" +#include "parser/onnx/onnx_op_parser.h" + +using ge::onnx::NodeProto; + +namespace ge { +class OnnxConstantParser : public OnnxOpParser { + public: + Status ParseParams(const Message *op_src, ge::Operator &op_def) override; + + private: + Status ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def); + Status ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); + Status ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count); + void ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, + int64_t data_type); + Status ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); + + template + static Status SetTensorData(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, + Tensor &tensor) { + bool zeros_like = (count != val_size && val_size == 1); + T *addr = new (std::nothrow) T[count](); + GE_CHECK_NOTNULL(addr); + int minCount = (count > val_size) ? val_size : count; + if (!zeros_like) { + for (int32_t i = 0; i < minCount; i++) { + *(addr + i) = val_vector.Get(i); + } + for (int32_t i = minCount; i < count; i++) { + *(addr + i) = val_vector.Get(minCount - 1); + } + } else { + for (int32_t i = 0; i < count; i++) { + *(addr + i) = val_vector.Get(0); + } + } + tensor.SetData(reinterpret_cast(addr), count * sizeof(T)); + GE_DELETE_NEW_ARRAY(addr); + return SUCCESS; + } +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_CONSTANT_PARSER_H_ diff --git a/parser/onnx/onnx_custom_parser_adapter.cc b/parser/onnx/onnx_custom_parser_adapter.cc new file mode 100644 index 0000000..fa7d702 --- /dev/null +++ b/parser/onnx/onnx_custom_parser_adapter.cc @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/onnx/onnx_custom_parser_adapter.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_parser_factory.h" +#include "register/op_registry.h" + +using domi::ParseParamFunc; +using domi::ONNX; + +namespace ge{ +Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { + GE_CHECK_NOTNULL(op_src); + const ge::onnx::NodeProto *node_src = reinterpret_cast(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); + + ParseParamFunc + custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); + GE_CHECK_NOTNULL(custom_op_parser); + if (custom_op_parser(op_src, op_dest) != SUCCESS) { + GELOGE(FAILED, "Custom parser params failed."); + return FAILED; + } + return SUCCESS; +} + +REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); +} // namespace ge diff --git a/parser/onnx/onnx_custom_parser_adapter.h b/parser/onnx/onnx_custom_parser_adapter.h new file mode 100644 index 0000000..fbbdb2f --- /dev/null +++ b/parser/onnx/onnx_custom_parser_adapter.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_ONNX_ONNX_CUSTOM_PARSER_ADAPTER_H_ +#define PARSER_ONNX_ONNX_CUSTOM_PARSER_ADAPTER_H_ + +#include "parser/onnx/onnx_op_parser.h" + +namespace ge { +class OnnxCustomParserAdapter : public OnnxOpParser { + public: + /// @brief Parsing model file information + /// @param [in] op_src model data to be parsed + /// @param [out] op_dest model data after parsing + /// @return SUCCESS parse successfully + /// @return FAILED parse failed + Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; +}; +} // namespace ge + +#endif // PARSER_ONNX_ONNX_CUSTOM_PARSER_ADAPTER_H_ diff --git a/parser/onnx/onnx_data_parser.cc b/parser/onnx/onnx_data_parser.cc new file mode 100644 index 0000000..7b396b7 --- /dev/null +++ b/parser/onnx/onnx_data_parser.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "onnx_data_parser.h" +#include +#include "common/util.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "parser/onnx/onnx_util.h" + +using domi::ONNX; +using namespace ge::parser; + +namespace ge { +Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { + GE_CHECK_NOTNULL(op_src); + const ge::onnx::NodeProto *node_src = reinterpret_cast(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); + if (ParseInputFromModel(op_src, op_def) != SUCCESS) { + GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str()); + return FAILED; + } + + if (ParseInputFromUser(op_def) != SUCCESS) { + GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str()); + return FAILED; + } + + ge::TensorDesc tensor_desc; + tensor_desc.SetFormat(static_cast(GetParserContext().format)); + tensor_desc.SetShape(ge::Shape(user_input_dims_v_)); + int64_t type = 1; + (void)op_def.GetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, type); + tensor_desc.SetDataType(static_cast(type)); + + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_def); + op_def.UpdateInputDesc(op_desc->GetInputNameByIndex(0), tensor_desc); + op_def.UpdateOutputDesc(op_desc->GetOutputNameByIndex(0), tensor_desc); + + return SUCCESS; +} + +int64_t OnnxDataParser::ParseInputTensor(const ge::onnx::AttributeProto &attribute) { + const ::ge::onnx::TensorProto it_tensor = attribute.t(); + int64_t data_type = it_tensor.data_type(); + GELOGI("Attr name: %s, data type: %ld ", attribute.name().c_str(), data_type); + for (auto dim : it_tensor.dims()) { + model_input_dims_v_.push_back(dim); + } + return data_type; +} + +Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &op_def) { + GE_CHECK_NOTNULL(op_src); + const ge::onnx::NodeProto *node = reinterpret_cast(op_src); + GE_CHECK_NOTNULL(node); + + // Get attr t:'input_tensor' form NodeProto + int64_t data_type = 1; + int64_t index = 0; + for (auto it : node->attribute()) { + if (it.name() == ge::kAttrNameInput) { + data_type = ParseInputTensor(it); + } else if (it.name() == ge::kAttrNameIndex) { + index = it.i(); + GELOGI("The node has attribute with index: %ld", index); + } + } + + // Trans onnx type to ge type + DataType type = OnnxUtil::ConvertOnnxDataType(data_type); + if (type == ge::DataType::DT_UNDEFINED) { + GELOGE(domi::PARAM_INVALID, "tensor_proto date type %ld is undefined.", data_type); + return FAILED; + } + op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast(type)); + op_def.SetAttr(ge::ATTR_NAME_INDEX, index); + + return SUCCESS; +} + +Status OnnxDataParser::ParseInputFromUser(const ge::Operator &op_def) { + std::unordered_map> input_dims = GetParserContext().input_dims; + // User not designate the input_shape + std::string name = op_def.GetName(); + if (input_dims.count(name) == 0) { + GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); + for (uint32_t i = 0; i < model_input_dims_v_.size(); i++) { + user_input_dims_v_.push_back(model_input_dims_v_[i]); + } + return SUCCESS; + } + + /// User designate the input_shape by passing '--input_shape=xxx:x,x,x,x' + /// Two cases below both OK: + /// 1. the input_shape not defined in the model(dimension is 0). + /// 2. the input_shape defined in the model(dimension greater than 0), and the dimension matches with user + /// designate_dim. 
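+  /// Illustrative note added by the editor (node name and dims are
+  /// hypothetical): --input_shape=data:1,3,224,224 would map "data" to
+  /// {1, 3, 224, 224} in GetParserContext().input_dims, and those values
+  /// replace the shape declared by the model for that Data node below.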
+ std::vector designated_dims = input_dims.at(name); + size_t input_dim_size = designated_dims.size(); + if (!(model_input_dims_v_.empty() || input_dim_size == model_input_dims_v_.size())) { + GELOGD("user designated input_dim_num %zu does match input_dim_num %zu defined by model", + input_dim_size, model_input_dims_v_.size()); + return domi::PARAM_INVALID; + } + + // replace with the user designated_dims + user_input_dims_v_.swap(designated_dims); + + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(ONNX, DATA, OnnxDataParser); +} // namespace ge diff --git a/parser/onnx/onnx_data_parser.h b/parser/onnx/onnx_data_parser.h new file mode 100644 index 0000000..ca68b1d --- /dev/null +++ b/parser/onnx/onnx_data_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_ONNX_DATA_PARSER_H_ +#define GE_PARSER_ONNX_ONNX_DATA_PARSER_H_ + +#include +#include +#include "parser/common/data_op_parser.h" +#include "parser/onnx/onnx_op_parser.h" + +namespace ge { +class OnnxDataParser : public OnnxOpParser { + public: + Status ParseParams(const Message *op_src, ge::Operator &op_def) override; + + private: + Status ParseInputFromModel(const Message *op_src, ge::Operator &op_def); + + Status ParseInputFromUser(const ge::Operator &op_def); + + int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute); + + std::vector model_input_dims_v_; + + std::vector user_input_dims_v_; +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_DATA_PARSER_H_ diff --git a/parser/onnx/onnx_op_parser.h b/parser/onnx/onnx_op_parser.h new file mode 100644 index 0000000..a986f5d --- /dev/null +++ b/parser/onnx/onnx_op_parser.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_PARSER_ONNX_ONNX_OP_PARSER_H_ +#define GE_PARSER_ONNX_ONNX_OP_PARSER_H_ + +#include +#include +#include "framework/common/op/attr_value_util.h" +#include "framework/omg/parser/op_parser.h" +#include "graph/ge_tensor.h" +#include "graph/node.h" +#include "proto/onnx/ge_onnx.pb.h" + +using Status = domi::Status; + +namespace ge { +class OnnxOpParser : public OpParser { + public: + /// @brief parse params + /// @param [in] op_src op to be parsed + /// @param [out] op_dest the parsed op + /// @return SUCCESS parse success + /// @return FAILED Parse failed + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { + return domi::SUCCESS; + } + + /// @brief parse params + /// @param [in] op_src op to be parsed + /// @param [out] op_dest the parsed op + /// @return SUCCESS parse success + /// @return FAILED Parse failed + Status ParseParams(const Message *op_src, ge::Operator &op_dest) override { + return domi::SUCCESS; + } + + /// @brief parsie weight + /// @param [in] op_src op to be parsed + /// @param [out] op_dest the parsed op + /// @return SUCCESS parsing success + /// @return FAILED parsing failed + Status ParseWeights(const Message *op_src, ge::NodePtr &node) override { + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_ONNX_OP_PARSER_H_ diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc new file mode 100644 index 0000000..25e6f8c --- /dev/null +++ b/parser/onnx/onnx_parser.cc @@ -0,0 +1,572 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "onnx_parser.h" +#include +#include +#include "common/convert/pb2json.h" +#include "common/util.h" +#include "external/graph/operator_factory.h" +#include "external/register/register_error_codes.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "framework/omg/parser/parser_types.h" +#include "omg/parser/parser_factory.h" +#include "onnx_op_parser.h" +#include "onnx_util.h" +#include "parser/common/op_parser_factory.h" +#include "parser/common/pre_checker.h" +#include "parser/common/acl_graph_parser_util.h" +#include "parser/common/model_saver.h" +#include "parser/onnx/onnx_util.h" +#include "register/op_registry.h" + +namespace ge { +namespace { +std::map kOnnxOpMap = { + {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, +}; +} + +Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, + std::map &initializer_name_tensor) { + if (onnx_graph.input_size() == 0) { + GELOGE(FAILED, "Onnx graph has zero input"); + return FAILED; + } + + // get input value info map + std::map input_name_tensor; + for (int i = 0; i < onnx_graph.input_size(); i++) { + ge::onnx::ValueInfoProto value_info = onnx_graph.input(i); + GELOGI("The index of %d input name : %s.", i, value_info.name().c_str()); + + // The input are possibly initialized by a default value found in ‘initializer.’ + auto initializer_iter = initializer_name_tensor.find(value_info.name()); + if (initializer_iter != initializer_name_tensor.end()) { + input_name_tensor[value_info.name()] = initializer_iter->second; + initializer_name_tensor.erase(initializer_iter); + continue; + } + + ge::onnx::TensorProto tensor_tmp; + if (value_info.has_type()) { + const ge::onnx::TypeProto type = value_info.type(); + if (type.has_tensor_type()) { + const ge::onnx::TypeProto_Tensor type_proto_tensor = type.tensor_type(); + int32_t elem_type = type_proto_tensor.elem_type(); + tensor_tmp.set_data_type(elem_type); + if (type_proto_tensor.has_shape()) { + const ge::onnx::TensorShapeProto tensor_shape = type_proto_tensor.shape(); + for (int j = 0; j < tensor_shape.dim_size(); j++) { + const ge::onnx::TensorShapeProto_Dimension dimension = tensor_shape.dim(j); + int64_t dim_value = dimension.dim_value(); + tensor_tmp.add_dims(dim_value); + GELOGI("elem_type: %d, dim_value: %ld", elem_type, dim_value); + } + } + } + } + input_name_tensor[value_info.name()] = tensor_tmp; + } + + // Construct node for input + int64_t index = 0; + for (auto it : input_name_tensor) { + ge::onnx::NodeProto *input_node = onnx_graph.add_node(); + input_node->set_name(it.first); + input_node->set_op_type(ge::kOpTypeInput); + input_node->add_output(it.first); + // add tensor + ge::onnx::AttributeProto *attribute = input_node->add_attribute(); + attribute->set_name(ge::kAttrNameInput); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + *attribute_tensor = it.second; + // add index + ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); + attribute_index->set_name(ge::kAttrNameIndex); + attribute_index->set_i(index++); + + input_node_names_.emplace_back(it.first); + } + return SUCCESS; +} + +Status OnnxModelParser::ParseOutput(const ge::onnx::GraphProto &onnx_graph) { + if (onnx_graph.output_size() == 0) { + GELOGE(FAILED, "Onnx graph has zero output"); + return FAILED; + } + + for (int i = 0; i < onnx_graph.output_size(); i++) { + ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); + GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); + + auto it = 
outputs_map_.find(value_info.name()); + if (it != outputs_map_.end()) { + std::string node_name = it->second[0].first; + output_node_names_.emplace_back(node_name); + GELOGI("Output node name: %s", node_name.c_str()); + } + } + + return SUCCESS; +} + +Status OnnxModelParser::ParseInitializer(ge::onnx::GraphProto &onnx_graph, + std::map &initializer_name_tensor) { + // Construct const node for weight + int index = 0; + for (auto it : initializer_name_tensor) { + ge::onnx::NodeProto *const_node = onnx_graph.add_node(); + std::string output_name = it.first + "_" + to_string(index++); + const_node->set_name(output_name); + const_node->set_op_type(ge::kOpTypeConstant); + const_node->add_output(it.first); + ge::onnx::AttributeProto *attribute = const_node->add_attribute(); + attribute->set_name(ge::kAttrNameValue); + ge::onnx::TensorProto *attribute_t = attribute->mutable_t(); + *attribute_t = it.second; + } + + return SUCCESS; +} + +Status OnnxModelParser::UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph) { + int index = 0; + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + if (node->name().empty()) { + std::string node_name = node->op_type() + "_" + to_string(index++); + node->set_name(node_name); + } + } + + return SUCCESS; +} + +Status OnnxModelParser::ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type) { + GE_CHECK_NOTNULL(node_proto); + + ori_type = node_proto->op_type(); + if (kOnnxOpMap.find(ori_type) != kOnnxOpMap.end()) { + return SUCCESS; + } + + std::string domain = node_proto->domain(); + int64_t version = 0; + if (!domain.empty()) { + auto it = domain_verseion_.find(domain); + if (it != domain_verseion_.end()) { + version = it->second; + } else { + GELOGE(PARAM_INVALID, "The opset of domain[%s] has no responding version.", domain.c_str()); + return PARAM_INVALID; + } + } else { + if (domain_verseion_.size() == 1){ + domain = domain_verseion_.begin()->first; + version = domain_verseion_.begin()->second; + } else { + GELOGE(PARAM_INVALID, "The opset size %zu is bigger than one.", domain_verseion_.size()); + return PARAM_INVALID; + } + } + + if (domain.empty()) { + domain = "ai.onnx"; + } + + ori_type = domain + "::" + to_string(version) + "::" + ori_type; + return SUCCESS; +} + +Status OnnxModelParser::AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, + std::string &op_type) { + GE_CHECK_NOTNULL(node_proto); + ori_type = node_proto->op_type(); + + auto map_it = kOnnxOpMap.find(ori_type); + if (map_it != kOnnxOpMap.end()) { + op_type = map_it->second; + ori_to_om_type_[ori_type] = op_type; + return SUCCESS; + } + + Status ret = ConstructOriType(node_proto, ori_type); + if (ret != SUCCESS) { + GELOGE(ret, "Construct ori type for [%s] failed.", ori_type.c_str()); + return ret; + } + + if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(ori_type, op_type)) { + GELOGE(PARAM_INVALID, "Get omType according ori_type : %s failed.", ori_type.c_str()); + return PARAM_INVALID; + } + + ori_to_om_type_[ori_type] = op_type; + return SUCCESS; +} + +Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, + const string &op_type) { + GE_CHECK_NOTNULL(node_proto); + string node_name = node_proto->name(); + op = ge::OperatorFactory::CreateOperator(node_name, op_type); + if (op.GetName() != node_name) { + GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + return INTERNAL_ERROR; + } + 
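+  // Editor's descriptive comment: the name check above relies on
+  // OperatorFactory::CreateOperator returning an operator whose name does not
+  // match the requested node_name when op_type has no registered IR, so a
+  // mismatch is reported as "not registered" and aborts the parse.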
+ GELOGI("After create operator, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), + op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); + return SUCCESS; +} + +Status OnnxModelParser::ConstructInputOutputContext(const ge::onnx::NodeProto *node_proto) { + GE_CHECK_NOTNULL(node_proto); + + std::string node_name = node_proto->name(); + for (int i = 0; i < node_proto->input_size(); i++) { + std::string input_name = node_proto->input(i); + inputs_map_[input_name].emplace_back(node_name, i); + } + + for (int i = 0; i < node_proto->output_size(); i++) { + std::string output_name = node_proto->output(i); + outputs_map_[output_name].emplace_back(node_name, i); + } + + return SUCCESS; +} + +Status OnnxModelParser::SetOperatorInputs() { + for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { + auto out_iter = outputs_map_.find(in_iter->first); + if (out_iter == outputs_map_.end()) { + GELOGE(INTERNAL_ERROR, "Unknown input: %s:%d in node: %s", in_iter->first.c_str(), in_iter->second[0].second, + in_iter->second[0].first.c_str()); + return INTERNAL_ERROR; + } + + std::vector> &input_node_indexs = in_iter->second; + std::vector> &output_node_indexs = out_iter->second; + for (auto input_node_index : input_node_indexs) { + for (auto out_node_index : output_node_indexs) { + auto input_op_iter = name_operator_.find(input_node_index.first); + if (input_op_iter == name_operator_.end()) { + GELOGE(INTERNAL_ERROR, "Node: %s can not find in name_operator map.", input_node_index.first.c_str()); + return INTERNAL_ERROR; + } + auto output_op_iter = name_operator_.find(out_node_index.first); + if (output_op_iter == name_operator_.end()) { + GELOGE(INTERNAL_ERROR, "Node: %s can not find in name_operator map.", out_node_index.first.c_str()); + return INTERNAL_ERROR; + } + + auto dst_op = input_op_iter->second; + auto src_op = output_op_iter->second; + int dst_index = input_node_index.second; + int src_index = out_node_index.second; + GELOGI("Start add output:%d of op:%s as input:%d of op:%s.", src_index, src_op.GetName().c_str(), dst_index, + dst_op.GetName().c_str()); + auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); + GE_CHECK_NOTNULL(dst_op_desc); + auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); + GE_CHECK_NOTNULL(src_op_desc); + dst_op.SetInput(dst_op_desc->GetInputNameByIndex(dst_index), src_op, + src_op_desc->GetOutputNameByIndex(src_index)); + } + } + } + return SUCCESS; +} + +Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { + ge::PreChecker::Instance().Clear(); + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + std::string ori_type; + Status ret = ConstructOriType(node, ori_type); + if (ret != SUCCESS) { + GELOGE(ret, "Construct ori type for [%s] failed.", ori_type.c_str()); + return ret; + } + GELOGI("Construct ori type : %s ", ori_type.c_str()); + if (ge::PreChecker::Instance().AddOp(node, node->name(), ori_type) != SUCCESS) { + GELOGE(FAILED, "Add node_def to PreChecker failed, node name: %s.", node->name().c_str()); + return FAILED; + } + if (ge::PreChecker::Instance().CheckName(node) != SUCCESS) { + GELOGE(FAILED, "Check node_def name failed, node name: %s.", node->name().c_str()); + return FAILED; + } + if (kOnnxOpMap.find(ori_type) == kOnnxOpMap.end()) { + if (ge::PreChecker::Instance().CheckType(node) != SUCCESS) { + GELOGE(FAILED, "Check node_def type failed, node name: %s.", node->name().c_str()); + return 
FAILED; + } + } + } + return SUCCESS; +} + +void OnnxModelParser::UpdateFormat(ge::Graph &graph) { + std::vector vec_op_name; + graph.GetAllOpName(vec_op_name); + ge::Format format = ge::FORMAT_NCHW; + for (string name: vec_op_name) { + ge::Operator op; + graph.FindOpByName(name, op); + auto op_dsc = ge::OpDescUtils::GetOpDescFromOperator(op); + auto input_size = op_dsc->GetAllInputsSize(); + for (size_t i = 0; i < input_size; i++) { + auto input = op_dsc->MutableInputDesc(static_cast(i)); + if (input == nullptr) { + continue; + } + input->SetFormat(format); + input->SetOriginFormat(format); + } + + auto output_size = op_dsc->GetOutputsSize(); + for (size_t i = 0; i < output_size; i++) { + auto output = op_dsc->GetOutputDesc(i); + output.SetFormat(format); + output.SetOriginFormat(format); + op_dsc->UpdateOutputDesc(i, output); + } + } +} + +Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); + std::string node_name = node_proto->name(); + std::string ori_type = node_proto->op_type(); + GELOGI("Start parse node which name is %s, type is %s", node_name.c_str(), ori_type.c_str()); + + std::string op_type; + Status status = AdapterOpType(node_proto, ori_type, op_type); + if (status != SUCCESS) { + GELOGE(status, "Adapter op type for ori type %s failed.", ori_type.c_str()); + return status; + } + node_proto->set_op_type(ori_type); + + GELOGI("Trans original type:%s to op type:%s", ori_type.c_str(), op_type.c_str()); + + ge::Operator op; + status = TransNodeToOperator(node_proto, op, op_type); + if (status != SUCCESS) { + GELOGE(status, "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); + return status; + } + + // 7. op parser + std::shared_ptr factory = ge::OpParserFactory::Instance(domi::ONNX); + GE_CHECK_NOTNULL(factory); + std::shared_ptr op_parser = factory->CreateOpParser(op_type); + GE_CHECK_NOTNULL(op_parser); + std::shared_ptr onnx_op_parser = std::static_pointer_cast(op_parser); + GE_CHECK_NOTNULL(onnx_op_parser); + status = onnx_op_parser->ParseParams(node_proto, op); + if (status != SUCCESS) { + GELOGE(status, "Parse params for %s:%s failed.", node_name.c_str(), op_type.c_str()); + return status; + } + + ge::graphStatus graph_status = graph.AddOp(op); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str()); + return FAILED; + } + name_operator_[op.GetName()] = op; + + // 8. 
Construct input output relation of every node + status = ConstructInputOutputContext(node_proto); + if (status != SUCCESS) { + GELOGE(status, "Construct input output relation map failed."); + return status; + } + } + GELOGI("Parse all node proto success."); + return SUCCESS; +} + +Status OnnxModelParser::GetGraphInputsOutputs(std::vector &input_ops, + std::vector>> output_indexs) { + for (auto in_name : input_node_names_) { + auto in_op = name_operator_.find(in_name); + if (in_op == name_operator_.end()) { + GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", + in_name.c_str()); + return PARAM_INVALID; + } + input_ops.emplace_back(in_op->second); + GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); + } + + for (auto it : output_node_names_) { + auto out_op = name_operator_.find(it); + if (out_op == name_operator_.end()) { + GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", + it.c_str()); + return PARAM_INVALID; + } + output_indexs.emplace_back(out_op->second, std::vector{}); + GELOGI("Model assigned output node name: %s", out_op->second.GetName().c_str()); + } + return SUCCESS; +} + +Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { + GE_CHECK_NOTNULL(file); + GELOGI("File path is %s.", file); + + // 1. Get graph from onnx model file. + ge::onnx::ModelProto onnx_model; + if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) { + GELOGE(PARAM_INVALID, "Read onnx model file failed."); + return FAILED; + } + if (!onnx_model.has_graph()) { + GELOGE(PARAM_INVALID, "Onnx model do not has graph."); + return FAILED; + } + ge::onnx::GraphProto onnx_graph = onnx_model.graph(); + + auto opset_import = onnx_model.opset_import(); + for (auto it : opset_import) { + domain_verseion_[it.domain()] = it.version(); + GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); + } + + // 2. Get all inializer. + std::map initializer_name_tensor; + for (int i = 0; i < onnx_graph.initializer_size(); i++) { + ge::onnx::TensorProto initializer_tensor = onnx_graph.initializer(i); + if (!initializer_tensor.name().empty()) { + initializer_name_tensor[initializer_tensor.name()] = initializer_tensor; + GELOGI("Initializer name: %s .", initializer_tensor.name().c_str()); + } + } + + // 3. Parse Input from graph. + GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); + Status ret = ParseInput(onnx_graph, initializer_name_tensor); + if (ret != SUCCESS) { + GELOGE(ret, "Parse input for onnx failed."); + return ret; + } + GELOGI("The size of initializer_name_tensor is %zu after ParseInput", initializer_name_tensor.size()); + + // 4. Parse Constant from graph. + ret = ParseInitializer(onnx_graph, initializer_name_tensor); + if (ret != SUCCESS) { + GELOGE(ret, "Parse initializer for onnx failed."); + return ret; + } + + // 5. Update node name for node do not has name. + ret = UpdateAllNodeName(onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Update all node name for onnx failed."); + return ret; + } + + // 6 Precheck. + ret = Prechecker(onnx_graph); + bool is_precheck_failed = (ret != SUCCESS) || (ge::PreChecker::Instance().HasError()); + if (is_precheck_failed) { + GELOGE(FAILED, "Prechecker failed."); + return FAILED; + } + + if (ge::GetParserContext().run_mode == ge::ONLY_PRE_CHECK) { + GELOGI("Only prechecker."); + return SUCCESS; + } + + // 7. Construct all operator and input output tensor relation. 
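+  //    (Editor's descriptive comment) For every NodeProto this maps the onnx
+  //    type to an om op type, creates the ge::Operator, runs the registered
+  //    ONNX op parser on it, adds it to the graph, and records it in
+  //    name_operator_ / inputs_map_ / outputs_map_ for the wiring in step 9.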
+ ret = ParseAllNodeProto(onnx_graph, graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parse all node proto failed."); + return ret; + } + + // 8. Parse output from graph. + ret = ParseOutput(onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parse output failed."); + return ret; + } + + // 9. Set all operator input. + ret = SetOperatorInputs(); + if (ret != SUCCESS) { + GELOGE(ret, "Set operator input failed."); + return ret; + } + + std::vector op_names; + graph.GetAllOpName(op_names); + GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); + + // 10. Construct graph. + std::vector input_ops; + std::vector>> output_indexs; + ret = GetGraphInputsOutputs(input_ops, output_indexs); + if (ret != SUCCESS) { + GELOGE(ret, "Get graph inputs and outputs failed."); + return ret; + } + graph.SetInputs(input_ops).SetOutputs(output_indexs); + + UpdateFormat(graph); + + GELOGI("Onnx model parser success."); + return SUCCESS; +} + +Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { + if (model_file == nullptr) { + GELOGE(FAILED, "Model file is nullptr."); + return FAILED; + } + if (json_file == nullptr) { + GELOGE(FAILED, "Json file is nullptr."); + return FAILED; + } + + ge::onnx::ModelProto onnx_model; + GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &onnx_model), + "ReadProtoFromBinaryFile failed, file:%s.", model_file); + ge::onnx::GraphProto graph_proto = onnx_model.graph(); + nlohmann::json j; + ge::Pb2Json::Message2Json(graph_proto, std::set(), j, true); + return ge::parser::ModelSaver::SaveJsonToFile(json_file, j); +} + +ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { + return ge::OnnxUtil::ConvertOnnxDataType(type); +} + +} // namespace domi + +namespace domi { + REGISTER_MODEL_PARSER_CREATOR(ONNX, ge::OnnxModelParser); + REGISTER_WEIGHTS_PARSER_CREATOR(ONNX, ge::OnnxWeightsParser); +} \ No newline at end of file diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h new file mode 100644 index 0000000..bdb93c5 --- /dev/null +++ b/parser/onnx/onnx_parser.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef PARSER_ONNX_ONNX_PARSER_H_ +#define PARSER_ONNX_ONNX_PARSER_H_ + +#include +#include +#include +#include "external/register/register_error_codes.h" +#include "omg/parser/model_parser.h" +#include "omg/parser/op_parser.h" +#include "omg/parser/weights_parser.h" +#include "proto/onnx/ge_onnx.pb.h" + +namespace ge { +class OnnxModelParser : public domi::ModelParser { + public: + OnnxModelParser() {} + virtual ~OnnxModelParser() {} + + Status Parse(const char *file, ge::Graph &graph) override; + + Status ToJson(const char *model_file, const char *json_file) override; + + ge::DataType ConvertToGeDataType(const uint32_t type) override; + + Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; } + + Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override { + return domi::SUCCESS; + } + + Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, domi::GetGraphCallback callback, + ge::ComputeGraphPtr &graph) override { + return domi::SUCCESS; + } + + Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override { + return domi::SUCCESS; + } + + private: + Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); + + Status ParseInput(ge::onnx::GraphProto &onnx_graph, + std::map &initializer_name_tensor); + + Status ParseOutput(const ge::onnx::GraphProto &onnx_graph); + + Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, + std::map &initializer_name_tensor); + + Status UpdateAllNodeName(ge::onnx::GraphProto &onnx_graph); + + Status ConstructOriType(const ge::onnx::NodeProto *node_proto, std::string &ori_type); + + Status AdapterOpType(const ge::onnx::NodeProto *node_proto, std::string &ori_type, std::string &om_type); + + Status TransNodeToOperator(const ge::onnx::NodeProto *node_proto, ge::Operator &op, const string &op_type); + + Status ConstructInputOutputContext(const ge::onnx::NodeProto *node_proto); + + Status SetOperatorInputs(); + + Status GetGraphInputsOutputs(std::vector &input_ops, + std::vector>> output_indexs); + + Status Prechecker(ge::onnx::GraphProto &onnx_graph); + + void UpdateFormat(ge::Graph &graph); + + std::map ori_to_om_type_; + + std::map domain_verseion_; + + std::map name_operator_; + + std::vector input_node_names_; + + std::vector output_node_names_; + + std::unordered_map>> inputs_map_; + + std::unordered_map>> outputs_map_; +}; + +class OnnxWeightsParser : public domi::WeightsParser { + public: + Status Parse(const char *file, ge::Graph &graph) override { return domi::SUCCESS; } + + Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override { return domi::SUCCESS; } +}; +} // namespace domi +#endif // PARSER_ONNX_ONNX_PARSER_H_ diff --git a/parser/onnx/onnx_util.cc b/parser/onnx/onnx_util.cc new file mode 100644 index 0000000..d42ab39 --- /dev/null +++ b/parser/onnx/onnx_util.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "onnx_util.h"
+#include <map>
+
+namespace {
+const std::map<int64_t, ge::DataType> onnx_data_type_map = {
+    {OnnxDataType::UNDEFINED, ge::DataType::DT_UNDEFINED}, {OnnxDataType::FLOAT, ge::DataType::DT_FLOAT},
+    {OnnxDataType::UINT8, ge::DataType::DT_UINT8},         {OnnxDataType::INT8, ge::DataType::DT_INT8},
+    {OnnxDataType::UINT16, ge::DataType::DT_UINT16},       {OnnxDataType::INT16, ge::DataType::DT_INT16},
+    {OnnxDataType::INT32, ge::DataType::DT_INT32},         {OnnxDataType::INT64, ge::DataType::DT_INT64},
+    {OnnxDataType::STRING, ge::DataType::DT_STRING},       {OnnxDataType::BOOL, ge::DataType::DT_BOOL},
+    {OnnxDataType::FLOAT16, ge::DataType::DT_FLOAT16},     {OnnxDataType::DOUBLE, ge::DataType::DT_DOUBLE},
+    {OnnxDataType::UINT32, ge::DataType::DT_UINT32},       {OnnxDataType::UINT64, ge::DataType::DT_UINT64},
+    {OnnxDataType::COMPLEX64, ge::DataType::DT_COMPLEX64}, {OnnxDataType::COMPLEX128, ge::DataType::DT_COMPLEX128},
+    {OnnxDataType::BFLOAT16, ge::DataType::DT_UNDEFINED},
+};
+
+const std::map<int64_t, int64_t> onnx_data_type_size_map = {
+    {OnnxDataType::FLOAT, sizeof(float)},     {OnnxDataType::UINT8, sizeof(uint8_t)},
+    {OnnxDataType::INT8, sizeof(int8_t)},     {OnnxDataType::UINT16, sizeof(uint16_t)},
+    {OnnxDataType::INT16, sizeof(int16_t)},   {OnnxDataType::INT32, sizeof(int32_t)},
+    {OnnxDataType::INT64, sizeof(int64_t)},   {OnnxDataType::STRING, sizeof(std::string)},
+    {OnnxDataType::BOOL, sizeof(bool)},       {OnnxDataType::FLOAT16, 2},
+    {OnnxDataType::DOUBLE, sizeof(double)},   {OnnxDataType::UINT32, sizeof(uint32_t)},
+    {OnnxDataType::UINT64, sizeof(uint64_t)}, {OnnxDataType::COMPLEX64, 8},
+    {OnnxDataType::COMPLEX128, 16},           {OnnxDataType::BFLOAT16, 2},
+};
+}
+
+namespace ge {
+ge::DataType OnnxUtil::ConvertOnnxDataType(int64_t onnx_data_type) {
+  auto search = onnx_data_type_map.find(onnx_data_type);
+  if (search != onnx_data_type_map.end()) {
+    return search->second;
+  } else {
+    return ge::DataType::DT_UNDEFINED;
+  }
+}
+
+int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) {
+  auto search = onnx_data_type_size_map.find(onnx_data_type);
+  if (search != onnx_data_type_size_map.end()) {
+    return search->second;
+  } else {
+    return ge::DataType::DT_UNDEFINED;
+  }
+}
+}  // namespace ge
diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h
new file mode 100644
index 0000000..259ed42
--- /dev/null
+++ b/parser/onnx/onnx_util.h
@@ -0,0 +1,58 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#ifndef PARSER_ONNX_ONNX_UTIL_PARSER_H_ +#define PARSER_ONNX_ONNX_UTIL_PARSER_H_ + +#include "external/graph/types.h" + +namespace OnnxDataType { +enum OnnxDataType { + UNDEFINED = 0, + FLOAT = 1, + UINT8 = 2, + INT8 = 3, + UINT16 = 4, + INT16 = 5, + INT32 = 6, + INT64 = 7, + STRING = 8, + BOOL = 9, + FLOAT16 = 10, + DOUBLE = 11, + UINT32 = 12, + UINT64 = 13, + COMPLEX64 = 14, + COMPLEX128 = 15, + BFLOAT16 = 16, +}; +} + +namespace ge { +const char *const kAttrNameValue = "value"; +const char *const kAttrNameInput = "input_tensor"; +const char *const kAttrNameIndex = "index"; +const char *const kOpTypeConstant = "Constant"; +const char *const kOpTypeInput = "Input"; + +class OnnxUtil { + public: + static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); + static int64_t CaculateDataSize(int64_t onnx_data_type); +}; +} // namespace ge + +#endif //PARSER_ONNX_ONNX_UTIL_PARSER_H_ diff --git a/parser/onnx/proto/om.proto b/parser/onnx/proto/om.proto new file mode 100644 index 0000000..e15e5f8 --- /dev/null +++ b/parser/onnx/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + 
HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + 
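The ModelDef/OpDef messages defined above are ordinary protobuf, so they can be built and serialized with the protoc-generated C++ API. The following is a minimal sketch, assuming om.proto is compiled with protoc and that the generated header is named om.pb.h; the header name and the attribute key used below are illustrative assumptions, not part of this patch, while the accessor names (set_name, add_op, mutable_attr, SerializeToOstream) follow standard protobuf C++ code generation.

// Sketch: build a one-op offline model in memory and write it to disk.
#include <fstream>
#include <string>
#include "om.pb.h"  // assumed protoc output for om.proto

int BuildTinyModel(const std::string &path) {
  domi::ModelDef model;
  model.set_name("tiny_model");
  model.set_version(1);

  domi::OpDef *op = model.add_op();            // repeated OpDef op = 20;
  op->set_name("conv1");
  op->set_type("Convolution");
  op->add_input_name("data");                  // repeated string input_name = 5;

  domi::AttrDef attr;
  attr.set_i(64);                              // oneof value -> int64 i = 3;
  (*op->mutable_attr())["num_output"] = attr;  // map<string, AttrDef> attr = 30;

  std::ofstream out(path, std::ios::binary);
  return model.SerializeToOstream(&out) ? 0 : -1;
}

Free-form attributes go through the map<string, AttrDef> field, with the AttrDef oneof selecting the concrete value type, while strongly typed layer parameters use the op_params oneof shown above.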
+message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. + CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/parser/onnx/proto/onnx/ge_onnx.proto b/parser/onnx/proto/onnx/ge_onnx.proto new file mode 100644 index 0000000..4cd77f3 --- /dev/null +++ b/parser/onnx/proto/onnx/ge_onnx.proto @@ -0,0 +1,563 @@ +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto3"; + +package ge.onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. 
While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION = 0x0000000000000006; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. 
+ // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. 
+ // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value= 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. 
For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. 
+ string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. 
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. 
+ string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} diff --git a/parser/ops/op_imp.cpp b/parser/ops/op_imp.cpp new file mode 100644 index 0000000..4f8e1de --- /dev/null +++ b/parser/ops/op_imp.cpp @@ -0,0 +1,63 @@ +#include +#include +#include +#include "debug/ge_log.h" +#include "debug/ge_util.h" + +using namespace std; + +namespace ge { + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +BroadCastInfer(const function()>& get_in1_shape, const function()>& get_in2_shape, + const function& outShape)>& set_out_shape) { + auto x1_shape = get_in1_shape(); + auto x2_shape = get_in2_shape(); + vector y_shape; + + if (x1_shape.empty()) { + y_shape = x2_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + if (x2_shape.empty()) { + y_shape = x1_shape; + set_out_shape(y_shape); + return GRAPH_SUCCESS; + } + + int len_diff = static_cast(x1_shape.size() - x2_shape.size()); + if (len_diff >= 0) { + for (int i = 0; i < len_diff; i++) { + y_shape.push_back(x1_shape[i]); + } + int x2_shape_size = static_cast(x2_shape.size()); + for (int i = 0; i < x2_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); + } + } else { + for (int i = 0; i < -len_diff; i++) { + y_shape.push_back(x2_shape[i]); + } + int x1_shape_size = static_cast(x1_shape.size()); + for (int i = 0; i < x1_shape_size; i++) { + bool shapeFlag = + ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); + if (shapeFlag) { + GE_LOGE("operands could not be broadcast together"); + return GRAPH_FAILED; + } + y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); + } + } + set_out_shape(y_shape); + return GRAPH_SUCCESS; +} + +} // namespace ge diff --git a/parser/proto/caffe/CMakeLists.txt b/parser/proto/caffe/CMakeLists.txt new file mode 100644 index 0000000..6d3999a --- /dev/null +++ b/parser/proto/caffe/CMakeLists.txt @@ -0,0 +1,27 @@ +set(PROTO_LIST + "${TOP_DIR}/inc/register/proto/caffe/caffe.proto" +) + +protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) + +############ lib_caffe_parser.so ############ +add_library(_caffe_parser SHARED ${PROTO_SRCS}) + +target_include_directories(_caffe_parser PRIVATE + ${CMAKE_CURRENT_LIST_DIR} +) + +target_link_libraries(_caffe_parser PRIVATE + $ + -Wl,--no-as-needed + protobuf + -Wl,--as-needed +) + +############ install ############ +set(INSTALL_BASE_DIR "") +set(INSTALL_LIBRARY_DIR lib) + +install(TARGETS _caffe_parser OPTIONAL + LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} +) diff --git a/parser/proto/caffe/caffe.proto b/parser/proto/caffe/caffe.proto new file mode 100644 index 0000000..3f45aae --- /dev/null +++ b/parser/proto/caffe/caffe.proto @@ -0,0 +1,1821 @@ +syntax = "proto2"; + +package domi.caffe; + +// Specifies the shape (dimensions) of a Blob. 
+message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; + optional bytes int8_data = 10; + repeated int32 int32_data = 11 [packed = true]; + repeated uint64 uint64_data = 12 [packed = true]; + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; + // Normalize the filler variance by fan_in, fan_out, or their average. + // Applies to 'xavier' and 'msra' fillers. + enum VarianceNorm { + FAN_IN = 0; + FAN_OUT = 1; + AVERAGE = 2; + } + optional VarianceNorm variance_norm = 8 [default = FAN_IN]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // DEPRECATED. See InputParameter. The input blobs to the network. + repeated string input = 3; + // DEPRECATED. See InputParameter. The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "input_shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. + optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. 
Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 42 (last added: layer_wise_reduce) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + // accumulate gradients over `iter_size` x `batch_size` instances + optional int32 iter_size = 36 [default = 1]; + + // The learning rate decay policy. The currently implemented learning rate + // policies are as follows: + // - fixed: always return base_lr. 
+ // - step: return base_lr * gamma ^ (floor(iter / step)) + // - exp: return base_lr * gamma ^ iter + // - inv: return base_lr * (1 + gamma * iter) ^ (- power) + // - multistep: similar to step but it allows non uniform steps defined by + // stepvalue + // - poly: the effective learning rate follows a polynomial decay, to be + // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) + // - sigmoid: the effective learning rate follows a sigmod decay + // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) + // + // where base_lr, max_iter, gamma, step, stepvalue and power are defined + // in the solver parameter protocol buffer, and iter is the current iteration. + optional string lr_policy = 8; + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. + // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + enum SnapshotFormat { + HDF5 = 0; + BINARYPROTO = 1; + } + optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // type of the solver + optional string type = 40 [default = "SGD"]; + + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam + optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; + + // RMSProp decay value + // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) + optional float rms_decay = 38 [default = 0.99]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. 
+ optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; + + // Overlap compute and communication for data parallel training + optional bool layer_wise_reduce = 41 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. + optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). 
+ repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Specifies whether to backpropagate to each bottom. If unspecified, + // Caffe will automatically infer whether each input needs backpropagation + // to compute parameter gradients. If set to true for some inputs, + // backpropagation to those inputs is forced; if set false for some inputs, + // backpropagation to those inputs is skipped. + // + // The size must be either 0 or equal to the number of bottoms. + repeated bool propagate_down = 11; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. + optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional BatchNormParameter batch_norm_param = 139; + optional BiasParameter bias_param = 141; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional CropParameter crop_param = 144; + optional DataParameter data_param = 107; + optional DetectionOutputParameter detection_output_param = 150; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ELUParameter elu_param = 140; + optional EmbedParameter embed_param = 137; + optional ExpParameter exp_param = 111; + optional FlattenParameter flatten_param = 135; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional InputParameter input_param = 143; + optional LogParameter log_param = 134; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional ParameterParameter parameter_param = 145; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional RecurrentParameter recurrent_param = 146; + optional ReductionParameter reduction_param = 136; + optional ReLUParameter relu_param = 123; + optional ReshapeParameter reshape_param = 133; + optional ScaleParameter scale_param = 142; + optional SigmoidParameter sigmoid_param = 124; + optional SmoothL1LossParameter smooth_l1_loss_param = 
148; + optional SoftmaxParameter softmax_param = 125; + optional SPPParameter spp_param = 132; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional TileParameter tile_param = 138; + optional WindowDataParameter window_data_param = 129; + optional PermuteParameter permute_param = 202; + optional PriorBoxParameter prior_box_param = 203; + optional NormalizeParameter norm_param = 206; + optional PSROIPoolingParameter psroi_pooling_param = 207; + optional FreespaceExtractParameter freespace_extract_param = 151; + optional PostprocessParameter postprocess_param = 152; + optional SpatialTransformParameter spatial_transform_param = 153; + optional ROIAlignParameter roi_align_param = 154; + optional ReorgParameter reorg_param = 155; + optional RegionParameter region_param = 156; + optional ReverseParameter reverse_param = 157; + optional InterpParameter interp_param = 158; + optional ShuffleChannelParameter shuffle_channel_param = 159; + optional UpsampleParameter upsample_param = 160; + optional ROIPoolingParameter roi_pooling_param = 161; + optional YoloParameter yolo_param = 199; + optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; + optional ProposalParameter proposal_param = 201; + optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; + optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; + optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; + optional QuantParameter quant_param = 208; + optional CondTakeParameter condtake_param = 233; + optional MatrixInverseParameter matrix_inverse_param = 210; + optional WarpPerspectiveParameter warp_perspective_param = 234; + optional BatchMatMulParameter batch_matmul_param = 235; + optional SpatialTransformerParameter st_param = 5000; + optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channels. + optional bool force_gray = 7 [default = false]; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // How to normalize the loss for loss layers that aggregate across batches, + // spatial dimensions, or other dimensions. Currently only implemented in + // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. 
+ enum NormalizationMode { + // Divide by the number of examples in the batch times spatial dimensions. + // Outputs that receive the ignore label will NOT be ignored in computing + // the normalization factor. + FULL = 0; + // Divide by the total number of output locations that do not take the + // ignore_label. If ignore_label is not set, this behaves like FULL. + VALID = 1; + // Divide by the batch size. + BATCH_SIZE = 2; + // Do not normalize the loss. + NONE = 3; + } + // For historical reasons, the default normalization for + // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. + optional NormalizationMode normalization = 3 [default = VALID]; + // Deprecated. Ignored if normalization is specified. If normalization + // is not specified, then setting this to false will be equivalent to + // normalization = BATCH_SIZE to be consistent with previous behavior. + optional bool normalize = 2; +} + +// Messages that store parameters used by individual layer types follow, in +// alphabetical order. + +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. + optional int32 ignore_label = 3; +} + +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; + // The axis along which to maximise -- may be negative to index from the + // end (e.g., -1 for the last axis). + // By default ArgMaxLayer maximizes over the flattened trailing dimensions + // for each index of the first / num dimension. + optional int32 axis = 3; +} + +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +message BatchNormParameter { + // If false, normalization is performed over the current mini-batch + // and global statistics are accumulated (but not yet used) by a moving + // average. + // If true, those accumulated mean and variance values are used for the + // normalization. + // By default, it is set to false when the network is in the training + // phase and true when the network is in the testing phase. + optional bool use_global_stats = 1; + // What fraction of the moving average remains each iteration? + // Smaller values make the moving average decay faster, giving more + // weight to the recent values. + // Each iteration updates the moving average @f$S_{t-1}@f$ with the + // current mean @f$ Y_t @f$ by + // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ + // is the moving_average_fraction parameter. 
+ optional float moving_average_fraction = 2 [default = .999]; + // Small value to add to the variance estimate so that we don't divide by + // zero. + optional float eps = 3 [default = 1e-5]; +} + +message BiasParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar bias. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the bias + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to add a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the bias is + // a learned parameter of the layer.) + // The initialization for the learned bias parameter. + // Default is the zero (0) initialization, resulting in the BiasLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + optional bool bias_from_blob = 4 [default = true]; +} + +message ContrastiveLossParameter { + // margin for dissimilar pair + optional float margin = 1 [default = 1.0]; + // The first implementation of this cost did not exactly match the cost of + // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. + // legacy_version = false (the default) uses (margin - d)^2 as proposed in the + // Hadsell paper. New models should probably use this version. + // legacy_version = true uses (margin - d^2). This is kept to support / + // reproduce existing models and results + optional bool legacy_version = 2 [default = false]; +} + +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in all spatial dimensions, or once per spatial dimension. + repeated uint32 pad = 3; // The padding size; defaults to 0 + repeated uint32 kernel_size = 4; // The kernel size + repeated uint32 stride = 6; // The stride; defaults to 1 + // Factor used to dilate the kernel, (implicitly) zero-filling the resulting + // holes. (Kernel dilation is sometimes referred to by its use in the + // algorithme à trous from Holschneider et al. 1987.) + repeated uint32 dilation = 18; // The dilation; defaults to 1 + + // For 2D convolution only, the *_h and *_w versions may also be used to + // specify both spatial dimensions. 
+ optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) + optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) + optional uint32 kernel_h = 11; // The kernel height (2D only) + optional uint32 kernel_w = 12; // The kernel width (2D only) + optional uint32 stride_h = 13; // The stride height (2D only) + optional uint32 stride_w = 14; // The stride width (2D only) + + optional uint32 group = 5 [default = 1]; // The group size for group conv + + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; + + // The axis to interpret as "channels" when performing convolution. + // Preceding dimensions are treated as independent inputs; + // succeeding dimensions are treated as "spatial". + // With (N, C, H, W) inputs, and axis == 1 (the default), we perform + // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for + // groups g>1) filters across the spatial axes (H, W) of the input. + // With (N, C, D, H, W) inputs, and axis == 1, we perform + // N independent 3D convolutions, sliding (C/g)-channels + // filters across the spatial axes (D, H, W) of the input. + optional int32 axis = 16 [default = 1]; + + // Whether to force use of the general ND convolution, even if a specific + // implementation for blobs of the appropriate number of spatial dimensions + // is available. (Currently, there is only a 2D-specific convolution + // implementation; for input blobs with num_axes != 2, this option is + // ignored and the ND implementation will be used.) + optional bool force_nd_im2col = 17 [default = false]; +} + +message CropParameter { + // To crop, elements of the first bottom are selected to fit the dimensions + // of the second, reference bottom. The crop is configured by + // - the crop `axis` to pick the dimensions for cropping + // - the crop `offset` to set the shift for all/each dimension + // to align the cropped bottom with the reference bottom. + // All dimensions up to but excluding `axis` are preserved, while + // the dimensions including and trailing `axis` are cropped. + // If only one `offset` is set, then all dimensions are offset by this amount. + // Otherwise, the number of offsets must equal the number of cropped axes to + // shift the crop in each dimension accordingly. + // Note: standard dimensions are N,C,H,W so the default is a spatial crop, + // and `axis` may be negative to index from the end (e.g., -1 for the last + // axis). + optional int32 axis = 1 [default = 2]; + repeated uint32 offset = 2; +} + +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. 
Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Increase if data feeding bandwidth varies, within the + // limit of device memory for GPU training) + optional uint32 prefetch = 10 [default = 4]; +} + +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio + optional bool scale_train = 2 [default = true]; // scale train or test phase +} + +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ELULayer +message ELUParameter { + // Described in: + // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate + // Deep Network Learning by Exponential Linear Units (ELUs). arXiv + optional float alpha = 1 [default = 1]; +} + +// Message that stores parameters used by EmbedLayer +message EmbedParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + // The input is given as integers to be interpreted as one-hot + // vector indices with dimension num_input. Hence num_input should be + // 1 greater than the maximum possible input value. + optional uint32 input_dim = 2; + + optional bool bias_term = 3 [default = true]; // Whether to use a bias term + optional FillerParameter weight_filler = 4; // The filler for the weight + optional FillerParameter bias_filler = 5; // The filler for the bias + +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). 
+ optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +/// Message that stores parameters used by FlattenLayer +message FlattenParameter { + // The first axis to flatten: all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 1 [default = 1]; + + // The last axis to flatten: all following axes are retained in the output. + // May be negative to index from the end (e.g., the default -1 for the last + // axis). + optional int32 end_axis = 2 [default = -1]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4 [default = 1]; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +message InfogainLossParameter { + // Specify the infogain matrix source. 
+ optional string source = 1; + optional int32 axis = 2 [default = 1]; // axis of prob +} + +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; + // Specify whether to transpose the weight matrix or not. + // If transpose == true, any operations will be performed on the transpose + // of the weight matrix. The weight matrix itself is not going to be transposed + // but rather the transfer flag of operations will be toggled accordingly. + optional bool transpose = 6 [default = false]; +} + +message InputParameter { + // This layer produces N >= 1 top blob(s) to be assigned manually. + // Define N shapes to set a shape for each top. + // Define 1 shape to set the same shape for every top. + // Define no shape to defer to reshaping manually. + repeated BlobShape shape = 1; +} + +// Message that stores parameters used by LogLayer +message LogParameter { + // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = ln(shift + scale * x) = log_e(shift + scale * x) + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; + + // Epsilon for not dividing by zero while normalizing variance + optional float eps = 3 [default = 1e-9]; +} + +message ParameterParameter { + optional BlobShape shape = 1; +} + +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
+ optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; + optional bool ceil_mode = 13 [default = true]; + // How to calculate the output size - using ceil (default) or floor rounding. + enum RoundMode { + CEIL = 0; + FLOOR = 1; + } + optional RoundMode round_mode = 14 [default = CEIL]; +} + +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +message PythonParameter { + optional string module = 1; + optional string layer = 2; + // This value is set to the attribute `param_str` of the `PythonLayer` object + // in Python before calling the `setup()` method. This could be a number, + // string, dictionary in Python dict format, JSON, etc. You may parse this + // string in `setup` method and use it in `forward` and `backward`. + optional string param_str = 3 [default = '']; + // Whether this PythonLayer is shared among worker solvers during data parallelism. + // If true, each worker solver sequentially run forward from this layer. + // This value should be set true if you are using it as a data layer. + optional bool share_in_parallel = 4 [default = false]; +} + +// Message that stores parameters used by RecurrentLayer +message RecurrentParameter { + // The dimension of the output (and usually hidden state) representation -- + // must be explicitly set to non-zero. + optional uint32 num_output = 1 [default = 0]; + + optional FillerParameter weight_filler = 2; // The filler for the weight + optional FillerParameter bias_filler = 3; // The filler for the bias + + // Whether to enable displaying debug_info in the unrolled recurrent net. + optional bool debug_info = 4 [default = false]; + + // Whether to add as additional inputs (bottoms) the initial hidden state + // blobs, and add as additional outputs (tops) the final timestep hidden state + // blobs. The number of additional bottom/top blobs required depends on the + // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. + optional bool expose_hidden = 5 [default = false]; +} + +// Message that stores parameters used by ReductionLayer +message ReductionParameter { + enum ReductionOp { + SUM = 1; + ASUM = 2; + SUMSQ = 3; + MEAN = 4; + } + + optional ReductionOp operation = 1 [default = SUM]; // reduction operation + + // The first axis to reduce to a scalar -- may be negative to index from the + // end (e.g., -1 for the last axis). + // (Currently, only reduction along ALL "tail" axes is supported; reduction + // of axis M through N, where N < num_axes - 1, is unsupported.) 
+ // Suppose we have an n-axis bottom Blob with shape: + // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). + // If axis == m, the output Blob will have shape + // (d0, d1, d2, ..., d(m-1)), + // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) + // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. + // If axis == 0 (the default), the output Blob always has the empty shape + // (count 1), performing reduction across the entire input -- + // often useful for creating new loss functions. + optional int32 axis = 2 [default = 0]; + + optional float coeff = 3 [default = 1.0]; // coefficient for output +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +message ReshapeParameter { + // Specify the output dimensions. If some of the dimensions are set to 0, + // the corresponding dimension from the bottom layer is used (unchanged). + // Exactly one dimension may be set to -1, in which case its value is + // inferred from the count of the bottom blob and the remaining dimensions. + // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: + // + // layer { + // type: "Reshape" bottom: "input" top: "output" + // reshape_param { ... } + // } + // + // If "input" is 2D with shape 2 x 8, then the following reshape_param + // specifications are all equivalent, producing a 3D blob "output" with shape + // 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } + // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } + // + optional BlobShape shape = 1; + + // axis and num_axes control the portion of the bottom blob's shape that are + // replaced by (included in) the reshape. By default (axis == 0 and + // num_axes == -1), the entire bottom blob shape is included in the reshape, + // and hence the shape field must specify the entire output shape. + // + // axis may be non-zero to retain some portion of the beginning of the input + // shape (and may be negative to index from the end; e.g., -1 to begin the + // reshape after the last axis, including nothing in the reshape, + // -2 to include only the last axis, etc.). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. + // Then the following ReshapeLayer specifications are all equivalent, + // producing a blob "output" with shape 2 x 2 x 4: + // + // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } + // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } + // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } + // + // num_axes specifies the extent of the reshape. + // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on + // input axes in the range [axis, axis+num_axes]. + // num_axes may also be -1, the default, to include all remaining axes + // (starting from axis). + // + // For example, suppose "input" is a 2D blob with shape 2 x 8. 
+ // Then the following ReshapeLayer specifications are equivalent, + // producing a blob "output" with shape 1 x 2 x 8. + // + // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } + // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } + // reshape_param { shape { dim: 1 } num_axes: 0 } + // + // On the other hand, these would produce output blob shape 2 x 1 x 8: + // + // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } + // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } + // + optional int32 axis = 2 [default = 0]; + optional int32 num_axes = 3 [default = -1]; +} + + +message ScaleParameter { + // The first axis of bottom[0] (the first input Blob) along which to apply + // bottom[1] (the second input Blob). May be negative to index from the end + // (e.g., -1 for the last axis). + // + // For example, if bottom[0] is 4D with shape 100x3x40x60, the output + // top[0] will have the same shape, and bottom[1] may have any of the + // following shapes (for the given value of axis): + // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 + // (axis == 1 == -3) 3; 3x40; 3x40x60 + // (axis == 2 == -2) 40; 40x60 + // (axis == 3 == -1) 60 + // Furthermore, bottom[1] may have the empty shape (regardless of the value of + // "axis") -- a scalar multiplier. + optional int32 axis = 1 [default = 1]; + + // (num_axes is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer. Otherwise, num_axes is determined by the + // number of axes by the second bottom.) + // The number of axes of the input (bottom[0]) covered by the scale + // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. + // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. + optional int32 num_axes = 2 [default = 1]; + + // (filler is ignored unless just one bottom is given and the scale is + // a learned parameter of the layer.) + // The initialization for the learned scale parameter. + // Default is the unit (1) initialization, resulting in the ScaleLayer + // initially performing the identity operation. + optional FillerParameter filler = 3; + + // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but + // may be more efficient). Initialized with bias_filler (defaults to 0). + optional bool bias_term = 4 [default = false]; + optional FillerParameter bias_filler = 5; + optional bool scale_from_blob = 6 [default = true]; +} + +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +message SmoothL1LossParameter { + // SmoothL1Loss(x) = + // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma + // |x| - 0.5 / sigma / sigma -- otherwise + optional float sigma = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). 
+ // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by TileLayer +message TileParameter { + // The index of the axis to tile. + optional int32 axis = 1 [default = 1]; + + // The number of copies (tiles) of the blob to output. + optional int32 tiles = 2; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +message WindowDataParameter { + // Specify the data source. + optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +message SPPParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional uint32 pyramid_height = 1; + optional PoolMethod pool = 2 [default = MAX]; // The pooling method + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; +} + +// DEPRECATED: use LayerParameter. 
+message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + QUANT = 208; + DEQUANT = 209; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; + optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. 
+ optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. 
+  // It will also resize images if new_height or new_width are not zero.
+  optional bool shuffle_images = 64 [default = false];
+
+  // For ConcatLayer, one needs to specify the dimension for concatenation, and
+  // the other dimensions must be the same for all the bottom blobs.
+  // By default it will concatenate blobs along the channels dimension.
+  optional uint32 concat_dim = 65 [default = 1];
+
+  optional HDF5OutputParameter hdf5_output_param = 1001;
+}
+
+message PReLUParameter {
+  // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers:
+  // Surpassing Human-Level Performance on ImageNet Classification, 2015.
+
+  // Initial value of a_i. Default is a_i=0.25 for all i.
+  optional FillerParameter filler = 1;
+  // Whether or not slope parameters are shared across channels.
+  optional bool channel_shared = 2 [default = false];
+}
+
+// Message that stores parameters used by DetectionOutputLayer
+//message DetectionOutputParameter {
+//  optional int32 num_classes = 1 [default = 21];
+//  optional float nms_threshold = 2 [default = 0.3];
+//  optional int32 top_k = 3;
+//  optional float confidence_threshold = 4 [default = 0.8];
+//}
+
+// Message that stores parameters used by PriorBoxLayer
+message PriorBoxParameter {
+  // Encode/decode type.
+  enum CodeType {
+    CORNER = 1;
+    CENTER_SIZE = 2;
+    CORNER_SIZE = 3;
+  }
+  // Minimum box size (in pixels). Required!
+  repeated float min_size = 1;
+  // Maximum box size (in pixels). Required!
+  repeated float max_size = 2;
+  // Various aspect ratios. Duplicate ratios will be ignored.
+  // If none is provided, we use default ratio 1.
+  repeated float aspect_ratio = 3;
+  // If true, will flip each aspect ratio.
+  // For example, if there is aspect ratio "r",
+  // we will generate aspect ratio "1.0/r" as well.
+  optional bool flip = 4 [default = true];
+  // If true, will clip the prior so that it is within [0, 1]
+  optional bool clip = 5 [default = false];
+  // Variance for adjusting the prior bboxes.
+  repeated float variance = 6;
+  // By default, we calculate img_height, img_width, step_x, step_y based on
+  // bottom[0] (feat) and bottom[1] (img), unless these values are explicitly
+  // provided.
+  // Explicitly provide the img_size.
+  optional uint32 img_size = 7;
+  // Either img_size or img_h/img_w should be specified; not both.
+  optional uint32 img_h = 8;
+  optional uint32 img_w = 9;
+
+  // Explicitly provide the step size.
+  optional float step = 10;
+  // Either step or step_h/step_w should be specified; not both.
+  optional float step_h = 11;
+  optional float step_w = 12;
+
+  // Offset to the top left corner of each cell.
+  optional float offset = 13 [default = 0.5];
+}
+
+// Message that stores parameters used by PermuteLayer
+message PermuteParameter {
+  // The new order of the axes of data. Note it should be within
+  // the same range as the input data, and it starts from 0.
+  // Do not provide repeated order.
+  repeated uint32 order = 1;
+}
+
+message NormalizeParameter {
+  optional bool across_spatial = 1 [default = true];
+  // Initial value of scale. Default is 1.0 for all
+  optional FillerParameter scale_filler = 2;
+  // Whether or not scale parameters are shared across channels.
+  optional bool channel_shared = 3 [default = true];
+  // Epsilon for not dividing by zero while normalizing variance
+  optional float eps = 4 [default = 1e-10];
+}
+
+// Needed by SSD.
+message SaveOutputParameter {
+  // Output directory. If not empty, we will save the results.
+ optional string output_directory = 1; + // Output name prefix. + optional string output_name_prefix = 2; + // Output format. + // VOC - PASCAL VOC output format. + // COCO - MS COCO output format. + optional string output_format = 3; + // If you want to output results, must also provide the following two files. + // Otherwise, we will ignore saving results. + // label map file. + optional string label_map_file = 4; + // A file which contains a list of names and sizes with same order + // of the input DB. The file is in the following format: + // name height width + // ... + optional string name_size_file = 5; + // Number of test images. It can be less than the lines specified in + // name_size_file. For example, when we only want to evaluate on part + // of the test images. + optional uint32 num_test_image = 6; + // The resize parameter used in saving the data. + // optional ResizeParameter resize_param = 7; +} + +message NonMaximumSuppressionParameter { + // Threshold to be used in nms. + optional float nms_threshold = 1 [default = 0.3]; + // Maximum number of results to be kept. + optional int32 top_k = 2; + // Parameter for adaptive nms. + optional float eta = 3 [default = 1.0]; +} + +message GeneralNmsParameter { + optional int32 post_top_k = 1 ; + optional float nms_threshold = 2 [default = 0]; + optional float iou_threshold_decay = 3 [default = 1.0]; + optional float coor_scale_factor = 4 [default = 1.0]; +} + +// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn +message DetectionOutputParameter { + optional int32 num_classes = 1; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional NonMaximumSuppressionParameter nms_param = 4; + optional SaveOutputParameter save_output_param = 5; + optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; + optional bool variance_encoded_in_target = 8 [default = true]; + optional int32 keep_top_k = 7; + optional float confidence_threshold = 9; + optional float nms_threshold = 13; + optional int32 top_k = 14; + optional int32 boxes = 15 [default = 1]; + optional bool relative = 17 [default = true]; + optional float objectness_threshold = 18 [default = 0.5]; + optional float class_threshold = 19 [default = 0.5]; + repeated float biases = 20; + optional GeneralNmsParameter general_nms_param = 21; + optional float objectness_score = 22; +} +message PSROIPoolingParameter { + required float spatial_scale = 1; + required int32 output_dim = 2; // output channel number + required int32 group_size = 3; // number of groups to encode position-sensitive score maps +} +// Message that stores parameters used by FreespaceExtractLayer +message FreespaceExtractParameter { + optional float org_height = 1; +} + +// Message that stores parameters used by DetectpostprocessLayer +message PostprocessParameter { + optional float nms_thresh = 1 [default = 0.3]; + optional float conf_thresh = 2 [default = 0.5]; + optional uint32 post_nms_topn = 3 [default = 100]; + optional uint32 cls_num = 4 [default = 12]; + repeated float bbox_reg_weights = 5; +} + +// Message that stores parameters used by SpatialTransformLayer +message SpatialTransformParameter { + optional uint32 output_h = 1 [default = 0]; + optional uint32 output_w = 2 [default = 0]; + optional float border_value = 3 [default = 0]; + repeated float affine_transform = 4; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} +message 
ROIAlignParameter { + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pooled_h = 1 [default = 0]; // The pooled output height + optional uint32 pooled_w = 2 [default = 0]; // The pooled output width + // Multiplicative spatial scale factor to translate ROI coords from their + // input scale to the scale used when pooling + optional float spatial_scale = 3 [default = 1]; + optional int32 sampling_ratio = 4 [default = -1]; + optional int32 roi_end_mode = 5 [default = 0]; +} + +message RegionParameter { + optional uint32 classes = 1 [default = 20]; // Category of classification + optional uint32 coords = 2 [default = 4]; // Coordinates of box + optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid + optional uint32 softmax = 4 [default = 0]; + optional string softmax_tree = 5 [default = ""]; + optional uint32 background = 6 [default = 0]; +} +message ReorgParameter{ + optional uint32 stride = 2 [default = 2]; + optional bool reverse = 1 [default = false]; +} +message ReverseParameter{ + repeated int32 axis = 1; +} +message InterpParameter{ + optional int32 height = 1 [default = 0];//Height of output + optional int32 width = 2 [default = 0];//Width of output + optional int32 zoom_factor = 3 [default = 1];//zoom factor + optional int32 shrink_factor = 4 [default = 1];//shrink factor + optional int32 pad_beg = 5 [default = 0];//padding at begin of input + optional int32 pad_end = 6 [default = 0];//padding at end of input +} +message ShuffleChannelParameter{ + optional uint32 group = 1[default = 1]; // The number of group +} +message UpsampleParameter{ + optional float scale = 1[default = 1]; + optional int32 stride = 2[default = 2]; + optional int32 stride_h = 3[default = 2]; + optional int32 stride_w = 4[default=2]; +} +message ROIPoolingParameter { + required int32 pooled_h = 1; + required int32 pooled_w = 2; + optional float spatial_scale = 3 [default=0.0625]; + optional float spatial_scale_h = 4; + optional float spatial_scale_w = 5; +} + +message YoloParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 coords = 2 [default = 4]; + optional int32 classes = 3 [default = 80]; + optional string yolo_version = 4 [default = "V3"]; + optional bool softmax = 5 [default = false]; + optional bool background = 6 [default = false]; + optional bool softmaxtree = 7 [default = false]; +} + +message YoloV3DetectionOutputParameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; +} + +message YoloV3DetectionOutputV2Parameter { + optional int32 boxes = 1 [default = 3]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 
post_nms_topn = 8 [default = 1024]; + repeated float biases_high = 9; + repeated float biases_mid = 10; + repeated float biases_low = 11; + optional int32 coords = 12 [default = 4]; + repeated float biases = 13; + optional bool resize_origin_img_to_net = 14 [default = false]; + optional int32 out_box_dim = 15 [default = 3]; +} + +message ProposalParameter { + optional float feat_stride = 1 [default = 16]; + optional float base_size = 2 [default = 16]; + optional float min_size = 3 [default = 16]; + repeated float ratio = 4; + repeated float scale = 5; + optional int32 pre_nms_topn = 6 [default = 3000]; + optional int32 post_nms_topn = 7 [default = 304]; + optional float iou_threshold = 8 [default = 0.7]; + optional bool output_actual_rois_num = 9 [default = false]; +} + +message FSRDetectionOutputParameter { + required int32 num_classes = 1; + required float score_threshold = 2; + required float iou_threshold = 3; + optional int32 batch_rois = 4 [default = 1]; +} + +message SSDDetectionOutputParameter { + required int32 num_classes= 1 [default = 2]; + optional bool share_location = 2 [default = true]; + optional int32 background_label_id = 3 [default = 0]; + optional float iou_threshold = 4 [default = 0.3]; + optional int32 top_k = 5 [default = 200]; + optional float eta = 6 [default = 1.0]; + optional bool variance_encoded_in_target = 7 [default = false]; + optional int32 code_type = 8 [default = 1]; + optional int32 keep_top_k = 9 [default = -1]; + optional float confidence_threshold = 10 [default = 0.0]; +} +message YoloV2DetectionOutputParameter { + optional int32 boxes = 1 [default = 5]; + optional int32 classes = 2 [default = 80]; + optional bool relative = 3 [default = true]; + optional float obj_threshold = 4 [default = 0.5]; + optional float score_threshold = 5 [default = 0.5]; + optional float iou_threshold = 6 [default = 0.45]; + optional int32 pre_nms_topn = 7 [default = 512]; + optional int32 post_nms_topn = 8 [default = 1024]; + repeated float biases = 9; + optional int32 coords = 10 [default = 4]; + optional bool resize_origin_img_to_net = 11 [default = false]; +} + +message QuantParameter { + optional float scale = 2; + optional bytes offset = 3; +} + +message BatchMatMulParameter{ + optional bool adj_x1 = 1 [default = false]; + optional bool adj_x2 = 2 [default = false]; +} + +message CondTakeParameter { + required string mode = 1; + required float val = 2; + optional float eps = 3 [default = 1e-06]; +} + +message MatrixInverseParameter { + optional bool adjoint = 1 [default = false]; +} + +message WarpPerspectiveParameter { + required int32 out_height = 1; + required int32 out_width = 2; + optional float constant = 3; + optional string border_type = 4 [default = 'BORDER_CONSTANT']; +} + +message SpatialTransformerParameter { + // How to use the parameter passed by localisation network + optional string transform_type = 1 [default = "affine"]; + // What is the sampling technique + optional string sampler_type = 2 [default = "bilinear"]; + + // If not set,stay same with the input dimension H and W + optional int32 output_H = 3; + optional int32 output_W = 4; + // If false, only compute dTheta, DO NOT compute dU + optional bool to_compute_dU = 5 [default = true]; + + // The default value for some parameters + optional double theta_1_1 = 6; + optional double theta_1_2 = 7; + optional double theta_1_3 = 8; + optional double theta_2_1 = 9; + optional double theta_2_2 = 10; + optional double theta_2_3 = 11; +} diff --git a/parser/proto/caffe/module.mk 
b/parser/proto/caffe/module.mk new file mode 100644 index 0000000..a4c6fdd --- /dev/null +++ b/parser/proto/caffe/module.mk @@ -0,0 +1,20 @@ +LOCAL_PATH := $(call my-dir) + +include $(CLEAR_VARS) + +LOCAL_MODULE := lib_caffe_parser + +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +LOCAL_SRC_FILES := \ + caffe.proto \ + +LOCAL_C_INCLUDES := \ + third_party/protobuf/include \ + +LOCAL_SHARED_LIBRARIES := \ + libprotobuf \ + +include $(BUILD_HOST_SHARED_LIBRARY) diff --git a/parser/proto/ge_ir.proto b/parser/proto/ge_ir.proto new file mode 100644 index 0000000..e7bfe0c --- /dev/null +++ b/parser/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
+message NamedAttrs
+{
+    string name = 1;
+    map<string, AttrDef> attr = 2;
+}
+
+// Shape / dimension description, using row-major order
+message ShapeDef
+{
+    repeated int64 dim = 1;  // Size of each dimension
+}
+
+// Multidimensional data description
+message TensorDescriptor
+{
+    string name = 1;     // Optional parameter, tensor name
+
+    DataType dtype = 2;  // tensor datatype
+    ShapeDef shape = 3;  // Shape / dimension
+    string layout = 4;   // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"
+
+    bool has_out_attr = 9;
+    int64 size = 10;
+    int64 weight_size = 11;
+    bool reuse_input = 12;
+    bool output_tensor = 13;
+    string device_type = 14;
+    bool input_tensor = 15;
+    int64 real_dim_cnt = 16;
+    int64 reuse_input_index = 17;
+    int64 data_offset = 18;
+    int64 cmps_size = 19;
+    string cmps_tab = 20;
+    int64 cmps_tab_offset = 21;
+
+    map<string, AttrDef> attr = 5;  // Set of extra parameter fields
+}
+
+// GeTensor definition
+message TensorDef
+{
+    TensorDescriptor desc = 1;  // Tensor description
+    bytes data = 2;             // Tensor data
+}
+
+
+// Operator description
+message OpDef
+{
+    string name = 1;  // name
+    string type = 2;  // type
+
+    repeated string input = 5;  // input original op name + outgoing index. op_name:index
+
+    map<string, AttrDef> attr = 10;  // Set of operator parameter fields
+
+    bool has_out_attr = 20;
+    int64 id = 21;
+    int64 stream_id = 22;
+    repeated string input_name = 23;
+    repeated string src_name = 24;
+    repeated int64 src_index = 25;
+    repeated string dst_name = 26;
+    repeated int64 dst_index = 27;
+    repeated int64 input_i = 28;
+    repeated int64 output_i = 29;
+    repeated int64 workspace = 30;
+    repeated int64 workspace_bytes = 31;
+    repeated bool is_input_const = 32;
+    repeated TensorDescriptor input_desc = 33;
+    repeated TensorDescriptor output_desc = 34;
+    repeated string subgraph_name = 35;
+}
+
+// Graph definition
+message GraphDef
+{
+    string name = 1;  // name
+
+    repeated string input = 4;   // Graph input
+    repeated string output = 5;  // Graph output
+
+    repeated OpDef op = 6;  // List of operators
+
+    map<string, AttrDef> attr = 11;  // Extended field
+}
+
+// model definition
+message ModelDef
+{
+    string name = 1;            // name
+    uint32 version = 2;         // IR Proto version
+    string custom_version = 3;  // User model version number, passed in by user
+
+    repeated GraphDef graph = 7;  // Graph definition; graph[0] represents the main graph of the model
+
+    map<string, AttrDef> attr = 11;  // Extended field
+}
+
diff --git a/parser/proto/insert_op.proto b/parser/proto/insert_op.proto
new file mode 100644
index 0000000..c635ca1
--- /dev/null
+++ b/parser/proto/insert_op.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+package domi;
+
+message InsertNewOps {
+  repeated AippOpParams aipp_op = 1;
+  repeated MultiShapeOpParams multi_shape_op = 2;
+}
+
+message AippOpParams {
+  enum InputFormat {
+    UNDEFINED = 0;
+    YUV420SP_U8 = 1;
+    XRGB8888_U8 = 2;
+    RGB888_U8 = 3;
+    YUV400_U8 = 4;
+    NC1HWC0DI_FP16 = 5;
+    NC1HWC0DI_S8 = 6;
+    ARGB8888_U8 = 7;
+    YUYV_U8 = 8;
+    YUV422SP_U8 = 9;
+    AYUV444_U8 = 10;
+    RAW10 = 11;
+    RAW12 = 12;
+    RAW16 = 13;
+    RAW24 = 14;
+    RGB16 = 15;
+    RGB20 = 16;
+    RGB24 = 17;
+    RGB8_IR = 18;
+    RGB16_IR = 19;
+    RGB24_IR = 20;
+  }
+
+  enum AippMode {
+    undefined = 0;
+    static = 1;
+    dynamic = 2;
+  }
+
+  // AIPP mode, distinguishing static AIPP from dynamic AIPP.
+  AippMode aipp_mode = 1;
+
+  // related_input_rank is required and must be an integer >= 0 and <= the number of input Data ops; the default value is 0.
+  // It identifies which model input this AIPP op processes; e.g., to apply AIPP to the second input, set related_input_rank to 1.
+  uint32 related_input_rank = 2;
+
+  // input_edge_idx is optional and must be an integer >= 0.
+  // It allows different AIPP processing for different output edges of the Data op; if unset, AIPP is applied by default to all output edges of the input selected by related_input_rank.
+  // The values must not exceed the number of output edges of the Data op.
+  repeated uint32 input_edge_idx = 3;
+
+  // [Begin] Dynamic AIPP parameters; they take no effect when static AIPP is configured.
+  uint32 max_src_image_size = 4;
+
+  // Whether rotation is supported; not supported by default. Enabling rotation costs extra space and performance.
+  bool support_rotation = 5;
+
+  // [End] Dynamic AIPP parameters
+
+
+  // [Begin] Static AIPP parameters; they take no effect when dynamic AIPP is configured.
+  InputFormat input_format = 51;
+  bool csc_switch = 52;
+  float cpadding_value = 53;
+  bool rbuv_swap_switch = 54;
+  bool ax_swap_switch = 55;
+  bool single_line_mode = 56;
+
+  int32 src_image_size_w = 57;
+  int32 src_image_size_h = 58;
+
+  bool crop = 59;
+  int32 load_start_pos_w = 60;
+  int32 load_start_pos_h = 61;
+  int32 crop_size_w = 62;
+  int32 crop_size_h = 63;
+
+  bool resize = 64;
+  int32 resize_output_w = 65;
+  int32 resize_output_h = 66;
+
+  bool padding = 67;
+  int32 left_padding_size = 68;
+  int32 right_padding_size = 69;
+  int32 top_padding_size = 70;
+  int32 bottom_padding_size = 71;
+
+  int32 mean_chn_0 = 10;
+  int32 mean_chn_1 = 11;
+  int32 mean_chn_2 = 12;
+  int32 mean_chn_3 = 19;
+  float min_chn_0 = 13;
+  float min_chn_1 = 14;
+  float min_chn_2 = 15;
+  float min_chn_3 = 20;
+  repeated float var_reci_chn_0 = 16;
+  repeated float var_reci_chn_1 = 17;
+  repeated float var_reci_chn_2 = 18;
+  repeated float var_reci_chn_3 = 21;
+
+  repeated int32 matrix_r0c0 = 30;
+  repeated int32 matrix_r0c1 = 31;
+  repeated int32 matrix_r0c2 = 32;
+  repeated int32 matrix_r1c0 = 33;
+  repeated int32 matrix_r1c1 = 34;
+  repeated int32 matrix_r1c2 = 35;
+  repeated int32 matrix_r2c0 = 36;
+  repeated int32 matrix_r2c1 = 37;
+  repeated int32 matrix_r2c2 = 38;
+  repeated int32 output_bias_0 = 39;
+  repeated int32 output_bias_1 = 40;
+  repeated int32 output_bias_2 = 41;
+  repeated int32 input_bias_0 = 42;
+  repeated int32 input_bias_1 = 43;
+  repeated int32 input_bias_2 = 44;
+
+  // [End] Static AIPP parameters
+
+  // The n number that is used for raw/rgbir data into f16 transformation.
+  // The transformation equation is x/(2^n). If set to 0, no transform is performed.
+  uint32 raw_rgbir_to_f16_n = 45;
+}
+
+message MultiShapeOpParams {
+  enum MultiShapeMode {
+    batch = 0;       // Dynamic batch
+    resolution = 1;  // Dynamic resolution (reserved for extension)
+  }
+
+  MultiShapeMode mode = 1;          // Mode
+  uint32 related_input_rank = 2;    // The model input to which the inserted op applies
+
+
+  repeated uint32 batch_list = 11;  // Values of batch_list; the number of batch_list entries must be between 2 and 8
+}
diff --git a/parser/proto/om.proto b/parser/proto/om.proto
new file mode 100644
index 0000000..e15e5f8
--- /dev/null
+++ b/parser/proto/om.proto
@@ -0,0 +1,396 @@
+/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the Apache License Version 2.0. You may not use this file except in compliance with the License.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + 
bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. 
+ CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/parser/proto/task.proto b/parser/proto/task.proto new file mode 100644 index 0000000..d0c0984 --- /dev/null +++ b/parser/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/parser/proto/tensorflow/attr_value.proto b/parser/proto/tensorflow/attr_value.proto new file mode 100644 index 0000000..1cc67d6 --- /dev/null +++ b/parser/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package 
domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/parser/proto/tensorflow/function.proto b/parser/proto/tensorflow/function.proto new file mode 100644 index 0000000..075897c --- /dev/null +++ b/parser/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. 
+ reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. 
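+
+  // Illustrative example (hypothetical names): if function_name = "Square"
+  // computes y = f(x) = x * x, then gradient_func = "SquareGrad" takes
+  // (x, dL/dy) and returns dL/dx = 2 * x * dL/dy, matching the
+  // N + M inputs -> N outputs contract described above.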
+} diff --git a/parser/proto/tensorflow/graph.proto b/parser/proto/tensorflow/graph.proto new file mode 100644 index 0000000..d639a7d --- /dev/null +++ b/parser/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/parser/proto/tensorflow/graph_library.proto b/parser/proto/tensorflow/graph_library.proto new file mode 100644 index 0000000..e393d38 --- /dev/null +++ b/parser/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/parser/proto/tensorflow/node_def.proto b/parser/proto/tensorflow/node_def.proto new file mode 100644 index 0000000..b9bc97e --- /dev/null +++ b/parser/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. 
+ string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/parser/proto/tensorflow/op_def.proto b/parser/proto/tensorflow/op_def.proto new file mode 100644 index 0000000..3485d04 --- /dev/null +++ b/parser/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. + // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". 
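+    //
+    // For example (hypothetical op): a variadic input of N tensors sharing one
+    // type would set number_attr = "N" and type_attr = "T", while a fixed
+    // single-tensor input could simply set type = DT_FLOAT.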
+ DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. 
printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/parser/proto/tensorflow/resource_handle.proto b/parser/proto/tensorflow/resource_handle.proto new file mode 100644 index 0000000..a345235 --- /dev/null +++ b/parser/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/parser/proto/tensorflow/tensor.proto b/parser/proto/tensorflow/tensor.proto new file mode 100644 index 0000000..d0a4d02 --- /dev/null +++ b/parser/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. 
This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/parser/proto/tensorflow/tensor_shape.proto b/parser/proto/tensorflow/tensor_shape.proto new file mode 100644 index 0000000..4225a2e --- /dev/null +++ b/parser/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. 
The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/parser/proto/tensorflow/types.proto b/parser/proto/tensorflow/types.proto new file mode 100644 index 0000000..ba7a72b --- /dev/null +++ b/parser/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). 
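+  // For example, DT_FLOAT = 1 pairs with DT_FLOAT_REF = 101 below; each *_REF
+  // value is its base type's value plus 100.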
+ DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/parser/proto/tensorflow/versions.proto b/parser/proto/tensorflow/versions.proto new file mode 100644 index 0000000..4806121 --- /dev/null +++ b/parser/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/parser/tensorflow/graph_functiondef.cc b/parser/tensorflow/graph_functiondef.cc new file mode 100644 index 0000000..0a242f8 --- /dev/null +++ b/parser/tensorflow/graph_functiondef.cc @@ -0,0 +1,547 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph_functiondef.h" +#include +#include "common/fmk_error_codes.h" +#include "graph/debug/ge_attr_define.h" +#include "framework/omg/parser/parser_types.h" +#include "parser/common/acl_graph_parser_util.h" +#include "common/types_map.h" +#include "common/util.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "framework/common/ge_inner_error_codes.h" + +namespace { +constexpr char UNKNOWN[] = "unknown"; +constexpr char UNDERLINE = '_'; +} // namespace +namespace ge { +using AttrValueMap = ::google::protobuf::Map; +vector arg_datetypes_; +vector result_datetypes_; + +string NameMapHelper::GetUniqueName(const string &name) { + if (used_names_.insert(name).second) { + return name; + } + int i = 0; + while (true) { + const string candidate = name + "_" + to_string(i); + if (used_names_.insert(candidate).second) { + return candidate; + } + ++i; + } +} + +string NameMapHelper::UniqueInputOrOutputName(const string &name) { + // Normalize first + string normalized = name; + if (name.empty()) { + normalized = UNKNOWN; + } + for (auto ch : normalized) { + if (!isalnum(ch)) { + ch = UNDERLINE; + } else if (isupper(ch)) { + ch = tolower(ch); + } + } + // uniquify + const string unique_name = GetUniqueName(normalized); + name_mapping_[name] = unique_name; + return unique_name; +} + +string NameMapHelper::UniqueNodeName(const string &name) { + // uniquify + const string unique_name = GetUniqueName(name); + name_mapping_[name] = unique_name; + return unique_name; +} + +string NameMapHelper::Renormalize(const string &name) const { + const auto iter = name_mapping_.find(name); + if (iter == name_mapping_.end()) return string(); + return iter->second; +} + +domi::Status ComputeArgRange(const domi::tensorflow::NodeDef &node_def, const domi::tensorflow::OpDef::ArgDef &arg_def, + const domi::tensorflow::OpDef &op_def, int *num) { + GE_CHECK_NOTNULL(num); + if (!arg_def.number_attr().empty()) { + // Same type repeated "num" times. 
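+    // For example (hypothetical values): an arg_def with number_attr = "N" on a
+    // node whose NodeDef carries attr N = 3 yields *num = 3 below.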
+ domi::tensorflow::AttrValue attr_value; + // Get attribute number_att, if the attribute does not exist, return failure + GE_IF_BOOL_EXEC( + !GraphToFunctionDef::FindAttrValue(&node_def, arg_def.number_attr(), attr_value), + GELOGE(domi::INTERNAL_ERROR, "In NodeDef %s Attr number_attr is not exist.", node_def.name().c_str()); + return domi::INTERNAL_ERROR); + *num = attr_value.i(); + } else if (!arg_def.type_list_attr().empty()) { + domi::tensorflow::AttrValue attr_value; + /// Get the attribute type_list_attr, if the attribute does not exist, return + /// failure + GE_IF_BOOL_EXEC( + !GraphToFunctionDef::FindAttrValue(&node_def, arg_def.type_list_attr(), attr_value), + GELOGE(domi::INTERNAL_ERROR, "In NodeDef %s Attr type_list_attr is not exist.", node_def.name().c_str()); + return domi::INTERNAL_ERROR); + *num = attr_value.list().type_size(); + } else if ((!arg_def.type_attr().empty()) || (arg_def.type() != DT_INVALID)) { + *num = 1; + } else { + GELOGE(domi::INTERNAL_ERROR, "In NodeDef %s Attr type_list_attr is not exist.", node_def.name().c_str()); + return domi::INTERNAL_ERROR; + } + return SUCCESS; +} + +using NameRangeMap = std::unordered_map>; + +domi::Status NameRangesHelper(const domi::tensorflow::NodeDef &node_def, + const google::protobuf::RepeatedPtrField &args, + const domi::tensorflow::OpDef &op_def, NameRangeMap *result) { + GE_CHECK_NOTNULL(result); + int start = 0; + int num = 0; + for (const auto &arg : args) { + GE_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); + (*result)[arg.name()] = std::make_pair(start, start + num); + start += num; + } + return SUCCESS; +} + +domi::Status NameRangesForNode(const domi::tensorflow::NodeDef &node_def, const domi::tensorflow::OpDef &op_def, + NameRangeMap *outputs) { + GE_IF_BOOL_EXEC(outputs == nullptr, return FAILED); + + return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); +} + +domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelper &node_names, + std::unordered_map &tensor_renaming, + std::unordered_map &return_values) { + GE_CHECK_NOTNULL(fdef); + // Detect missing function inputs.. + for (int i = 0; i < fdef->signature().input_arg_size(); ++i) { + const string &input_name = fdef->signature().input_arg(i).name(); + GE_IF_BOOL_EXEC(input_name.empty(), + GELOGE(domi::INTERNAL_ERROR, "In fdef %s input_name null .", fdef->signature().name().c_str()); + return domi::INTERNAL_ERROR); + } + + /// Remap input names. We do this as a second pass to allow the nodes to be in + /// any order. 
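+  /// For example, a data input recorded as "conv1:0" (name is illustrative only)
+  /// is looked up in tensor_renaming and rewritten to its function-local tensor
+  /// name, while a control input such as "^conv1" keeps its leading '^' and only
+  /// the node name is renormalized.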
+ for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) { + NodeDef *node_def = fdef->mutable_node_def(n_index); + for (int i = 0; i < node_def->input_size(); ++i) { + if (node_def->input(i).find("^") != string::npos) { + // Control input + const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); + + GE_IF_BOOL_EXEC(normalized.empty(), + GELOGE(domi::INTERNAL_ERROR, "Could not remap control input %s of node %s in function %s .", + node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); + return domi::INTERNAL_ERROR); + + *node_def->mutable_input(i) = "^" + normalized; + } else { + const auto iter = tensor_renaming.find(node_def->input(i)); + + GE_IF_BOOL_EXEC(iter == tensor_renaming.end(), + GELOGE(domi::INTERNAL_ERROR, "Could not remap input %s of node %s in function %s .", + node_def->input(i).c_str(), node_def->name().c_str(), name.c_str()); + return domi::INTERNAL_ERROR); + + *node_def->mutable_input(i) = iter->second; + } + } + } + + // Remap return values. + for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { + const string &ret_name = fdef->signature().output_arg(r).name(); + + GE_IF_BOOL_EXEC(ret_name.empty(), + GELOGE(domi::INTERNAL_ERROR, "Missing output %d to function %s .", r, name.c_str()); + return domi::INTERNAL_ERROR); + + const string &return_value = return_values[ret_name]; + + GE_IF_BOOL_EXEC(return_value.empty(), + GELOGE(domi::INTERNAL_ERROR, "Could not remap return value %d ,%s of %s in function %s .", r, + ret_name.c_str(), return_value.c_str(), name.c_str()); + return domi::INTERNAL_ERROR); + + const auto iter = tensor_renaming.find(return_value); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == tensor_renaming.end(), return domi::INTERNAL_ERROR, + "can not find value[%s] in tensor_renaming map.", return_value.c_str()); + + (*fdef->mutable_ret())[ret_name] = iter->second; + } + + return SUCCESS; +} + +// Add output operator for graph before converting func +domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph, + const vector &out_anchor) { + GE_CHECK_NOTNULL(graph); + int32_t index = 0; + result_datetypes_.clear(); + for (const auto &anchor : out_anchor) { + GE_CHECK_NOTNULL(anchor); + GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); + int32_t type = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()).GetDataType(); + auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); + GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), + GELOGE(PARAM_INVALID, "data_type %d not supported", type); + return PARAM_INVALID); + int32_t dtype = iter->second; + + string op_name = anchor->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetIdx()) + "_retval"; + ge::OpDescPtr op = nullptr; + GE_MAKE_SHARED(op = std::make_shared(op_name, ge::parser::NETOUTPUT), return FAILED); + graphStatus status = op->AddInputDesc(ge::GeTensorDesc()); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add input desc for op:%s failed.", op->GetName().c_str()); + return FAILED; + } + status = op->AddOutputDesc(ge::GeTensorDesc()); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str()); + return FAILED; + } + (void)ge::AttrUtils::SetInt(op, "T", static_cast(dtype)); + (void)ge::AttrUtils::SetInt(op, "ret_index", static_cast(index)); + ge::NodePtr res_node = graph->AddNode(op); + GE_CHECK_NOTNULL(res_node); + bool node_exists = false; + for (const ge::NodePtr &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->GetName() == 
anchor->GetOwnerNode()->GetName()) { + ge::OutDataAnchorPtr out_archor_ptr = node->GetOutDataAnchor(anchor->GetIdx()); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = res_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(in_archor_ptr); + ge::graphStatus ret = ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(domi::INTERNAL_ERROR, "Add edge failed,src op:%s,dst op:%s", node->GetName().c_str(), + res_node->GetName().c_str()); + return FAILED; + } + node_exists = true; + } + } + GE_IF_BOOL_EXEC(!node_exists, GELOGE(FAILED, "node not exists!"); return FAILED); + result_datetypes_.emplace_back(domi::tensorflow::DataType(dtype)); + + index++; + } + return SUCCESS; +} + +/// Add input operator for graph before converting function. +/// Input operator will generate input parameters during function conversion +domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vector &in_anchor) { + GE_CHECK_NOTNULL(graph); + int32_t index = 0; + arg_datetypes_.clear(); + for (const auto &anchor : in_anchor) { + GE_CHECK_NOTNULL(anchor); + GE_CHECK_NOTNULL(anchor->GetOwnerNode()->GetOpDesc()); + auto tensor_desc_ptr = anchor->GetOwnerNode()->GetOpDesc()->GetInputDescPtr(anchor->GetIdx()); + GE_CHECK_NOTNULL_EXEC(tensor_desc_ptr, return domi::FAILED); + + int32_t type = tensor_desc_ptr->GetDataType(); + auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find(type); + GE_IF_BOOL_EXEC(iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), + GELOGE(PARAM_INVALID, "data_type %d not supported", type); + return PARAM_INVALID); + int32_t dtype = iter->second; + + GE_CHECK_NOTNULL(anchor->GetPeerOutAnchor()); + string op_name = anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" + + to_string(anchor->GetPeerOutAnchor()->GetIdx()) + "_arg"; + ge::OpDescPtr op = nullptr; + GE_MAKE_SHARED(op = std::make_shared(op_name, ge::parser::DATA), return FAILED); + graphStatus status = op->AddOutputDesc(ge::GeTensorDesc()); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str()); + return FAILED; + } + + (void)ge::AttrUtils::SetInt(op, "T", (int32_t)dtype); + (void)ge::AttrUtils::SetInt(op, "arg_index", (int32_t)index); + ge::NodePtr arg_node = graph->AddNode(op); + GE_CHECK_NOTNULL(arg_node); + bool node_exists = false; + for (const auto &node : graph->GetDirectNode()) { + if (node->GetName() == anchor->GetOwnerNode()->GetName()) { + ge::OutDataAnchorPtr out_archor_ptr = arg_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(anchor->GetPeerOutAnchor()->GetIdx()); + GE_CHECK_NOTNULL(in_archor_ptr); + (void)ge::GraphUtils::RemoveEdge(in_archor_ptr->GetPeerOutAnchor(), in_archor_ptr); + ge::graphStatus ret = ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr); + if (ret != ge::GRAPH_SUCCESS) { + GELOGE(domi::INTERNAL_ERROR, "Add edge failed,src op:%s,dst op:%s", arg_node->GetName().c_str(), + node->GetName().c_str()); + return FAILED; + } + node_exists = true; + } + } + GE_IF_BOOL_EXEC(!node_exists, GELOGE(FAILED, "node not exists!"); return FAILED); + arg_datetypes_.emplace_back(domi::tensorflow::DataType(dtype)); + index++; + } + return SUCCESS; +} + +// Convert Davinci's graph to tensorflow's functiondef +domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph, const string &name, + FunctionDef *fdef) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(fdef); + fdef->mutable_signature()->set_name(name); + + 
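+  // Outline of the conversion below: DATA nodes become input_arg entries of the
+  // signature, NETOUTPUT nodes become output_arg entries, and every other node
+  // is re-emitted as a NodeDef restored from its original TensorFlow definition;
+  // input names are then rewritten in a second pass by RemapFunctionDef.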
std::unordered_map tensor_renaming; + std::unordered_map return_values; + NameMapHelper node_names; + + for (const ge::NodePtr &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->GetOpDesc()->GetType() == ge::parser::DATA) { + int64_t index = 0; + + int64_t type = 1; + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, + "Get type attr failed"); + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "arg_index", index), PARAM_INVALID, + "Get arg_index attr failed"); + + while (fdef->signature().input_arg_size() <= index) { + fdef->mutable_signature()->add_input_arg(); + } + domi::tensorflow::OpDef::ArgDef *argdef = fdef->mutable_signature()->mutable_input_arg(index); + argdef->set_type(domi::tensorflow::DataType(type)); + const string normalized = node_names.UniqueInputOrOutputName(node->GetName()); + argdef->set_name(normalized); + tensor_renaming[node->GetName() + ":0"] = normalized; + continue; + } + + if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { + int64_t index = 0; + int64_t type = 1; + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "T", type), PARAM_INVALID, + "Get type attr failed"); + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(node->GetOpDesc(), "ret_index", index), PARAM_INVALID, + "Get arg_index attr failed"); + + while (fdef->signature().output_arg_size() <= index) { + fdef->mutable_signature()->add_output_arg(); + } + + domi::tensorflow::OpDef::ArgDef *argdef = fdef->mutable_signature()->mutable_output_arg(index); + argdef->set_type(domi::tensorflow::DataType(type)); + const string normalized = node_names.UniqueInputOrOutputName(node->GetName()); + argdef->set_name(normalized); + + ge::OutDataAnchorPtr o_anchor = node->GetAllInDataAnchors().at(0)->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(o_anchor); + string n_name = o_anchor->GetOwnerNode()->GetName() + ":" + to_string(o_anchor->GetIdx()); + return_values[normalized] = n_name; + continue; + } + + // Analysis of nodedef of original tensorflow + ge::GeAttrValue::BYTES nodedef_bytes; + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetBytes(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_bytes), + PARAM_INVALID, "Get type attr nodedef failed."); + domi::tensorflow::NodeDef node_def_; + GE_CHK_BOOL_RET_STATUS(node_def_.ParseFromArray(nodedef_bytes.GetData(), nodedef_bytes.GetSize()), PARAM_INVALID, + "parse nodedef failed."); + + // Analysis of opdef of original tensorflow + string opdef_string; + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_OP_DEF, opdef_string), + PARAM_INVALID, "Get type attr op_def failed."); + + domi::tensorflow::OpDef op_def; + GE_CHK_BOOL_RET_STATUS(op_def.ParseFromString(opdef_string), PARAM_INVALID, "parse op_def failed."); + + // add nodedef + NodeDef *node_def = fdef->add_node_def(); + *node_def = node_def_; + + node_def->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_OP_DEF); + node_def->mutable_attr()->erase(ge::ATTR_NAME_OUTPUT_TENSOR_DESC); + node_def->mutable_attr()->erase(ge::ATTR_NAME_INPUT_TENSOR_DESC); + // No device information required for framework + node_def->clear_device(); + + node_def->set_name(node_names.UniqueNodeName(node->GetName())); + + // Reset input names based on graph rather than the NodeDef. + node_def->clear_input(); + + // Edges, indexed by dst_input. 
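+    // in_anchors[i] will hold the in-data anchor whose destination index is i,
+    // so regular inputs are emitted in positional order below.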
+ vector in_anchors; + ge::InControlAnchorPtr in_control_anchor; + + for (const auto &anchor : node->GetAllInDataAnchors()) { + if (static_cast(in_anchors.size()) <= anchor->GetIdx()) { + in_anchors.resize(anchor->GetIdx() + 1); + } + in_anchors[anchor->GetIdx()] = anchor; + } + + // Add regular inputs + for (auto anchor : in_anchors) { + GE_IF_BOOL_EXEC(anchor == nullptr, + GELOGE(domi::INTERNAL_ERROR, "Nonconsecutive input edges; missing input edge , for node %s .", + node_def_.name().c_str()); + return domi::INTERNAL_ERROR); + + if (anchor->GetPeerOutAnchor() != nullptr) { + string t_name = + anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + ":" + to_string(anchor->GetPeerOutAnchor()->GetIdx()); + node_def->add_input(t_name); + } + } + + // Add control inputs + GE_CHECK_NOTNULL(node->GetInControlAnchor()); + for (const auto &anchor : node->GetInControlAnchor()->GetPeerOutControlAnchors()) { + node_def->add_input("^" + anchor->GetOwnerNode()->GetName()); + } + + // Populate tensor_renaming. + NameRangeMap output_ranges; + GE_RETURN_IF_ERROR(NameRangesForNode(node_def_, op_def, &output_ranges)); + + for (const auto &output : output_ranges) { + for (int i = output.second.first; i < output.second.second; ++i) { + const string tensor_name = node_def->name() + ":" + output.first + ":" + to_string(i - output.second.first); + tensor_renaming[(node->GetName() + ":" + to_string(i))] = tensor_name; + } + } + } + + // Remap FunctionDef + GE_RETURN_IF_ERROR(RemapFunctionDef(fdef, name, node_names, tensor_renaming, return_values)); + + return SUCCESS; +} + +void SetInputOut(NodeDef *call_node_def, vector &in_anchor) { + GE_CHK_BOOL_EXEC(call_node_def != nullptr, return, "call_node_def is null."); + for (const auto &anchor : in_anchor) { + if ((anchor != nullptr) && (anchor->GetPeerOutAnchor() != nullptr)) { + call_node_def->add_input(anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" + + to_string(anchor->GetPeerOutAnchor()->GetIdx())); + } + } +} + +domi::Status GraphToFunctionDef::BuildFunctionDef(ge::ComputeGraphPtr &graph, const string &name_in, + FunctionDefLibrary *library, NodeDef *call_node_def, + vector &in_anchor, + vector &out_anchor) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(library); + GE_CHECK_NOTNULL(call_node_def); + // Current date / time base on the current system + string now_time = ge::parser::CurrentTimeInStr(); + static int i = 0; + const string name = name_in + now_time + to_string(i); + i++; + // set node_def + call_node_def->set_op(name); + call_node_def->set_name(name); + + // Add func property + domi::tensorflow::AttrValue value; + domi::tensorflow::NameAttrList *function = value.mutable_func(); + function->set_name(name); + *function->mutable_attr() = call_node_def->attr(); + GraphToFunctionDef::AddNodeAttr("function", value, call_node_def); + + // Add input for nodedef + SetInputOut(call_node_def, in_anchor); + + // Add input and output nodes to the graph + GE_RETURN_IF_ERROR(GraphToFunctionDef::RecordArg(graph, in_anchor)); + GE_RETURN_IF_ERROR(GraphToFunctionDef::RecordResult(graph, out_anchor)); + + domi::tensorflow::AttrValue tin_value; + domi::tensorflow::AttrValue tout_value; + // Add tin tout attribute + domi::tensorflow::AttrValue::ListValue list; + for (auto type : arg_datetypes_) { + tin_value.mutable_list()->clear_type(); + tin_value.mutable_list()->add_type(type); + } + if (!arg_datetypes_.empty()) { + GraphToFunctionDef::AddNodeAttr("Tin", tin_value, call_node_def); + } + for (auto type : result_datetypes_) { + 
tout_value.mutable_list()->clear_type(); + tout_value.mutable_list()->add_type(type); + } + if (!result_datetypes_.empty()) { + GraphToFunctionDef::AddNodeAttr("Tout", tout_value, call_node_def); + } + // Convert DaVinci graph to functiondef + FunctionDef *fdef = library->add_function(); + GE_RETURN_IF_ERROR(GraphToFunctionDef::DavGraphToFunctionDef(graph, name, fdef)); + + return SUCCESS; +} + +bool GraphToFunctionDef::FindAttrValue(const domi::tensorflow::NodeDef *node_def, const string attr_name, + domi::tensorflow::AttrValue &attr_value) { + if (node_def == nullptr) { + GELOGE(PARAM_INVALID, "Input param node is nullptr."); + return false; + } + const google::protobuf::Map &attr = node_def->attr(); + + const google::protobuf::Map::const_iterator it = attr.find(attr_name); + if (it != attr.end()) { + attr_value = it->second; + return true; + } + + return false; +} + +void GraphToFunctionDef::AddNodeAttr(const string &attr_name, const domi::tensorflow::AttrValue &value, + domi::tensorflow::NodeDef *node_def) { + GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); + node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); +} +} // namespace ge diff --git a/parser/tensorflow/graph_functiondef.h b/parser/tensorflow/graph_functiondef.h new file mode 100644 index 0000000..bf1bb26 --- /dev/null +++ b/parser/tensorflow/graph_functiondef.h @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_OPTIMIZE_GRAPH_FUNCTIONDEF_H +#define GE_GRAPH_OPTIMIZE_GRAPH_FUNCTIONDEF_H + +#include +#include +#include +#include +#include +#include "graph/anchor.h" +#include "graph/ge_attr_value.h" +#include "graph/graph.h" +#include "proto/tensorflow/graph.pb.h" +#include "register/register_error_codes.h" + +using domi::tensorflow::AttrValue; +using domi::tensorflow::AttrValue_ListValue; +using domi::tensorflow::DataType; +using domi::tensorflow::DT_INVALID; +using domi::tensorflow::FunctionDef; +using domi::tensorflow::FunctionDefLibrary; +using domi::tensorflow::NodeDef; +using std::string; +using std::to_string; +using std::vector; + +namespace ge { +class GraphToFunctionDef { + public: + static domi::Status RecordArg(ge::ComputeGraphPtr graph, + const vector &in_anchor); + + static domi::Status RecordResult(ge::ComputeGraphPtr graph, + const vector &out_anchor); + + static domi::Status DavGraphToFunctionDef(ge::ComputeGraphPtr graph, + const string &name, FunctionDef *fdef); + + static domi::Status BuildFunctionDef(ge::ComputeGraphPtr &graph, + const string &nme_in, + FunctionDefLibrary *library, + NodeDef *call_node_def, + vector &in_anchor, + vector &out_anchor); + + static bool FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, + const string attr_name, + domi::tensorflow::AttrValue &attr_value); + + static void AddNodeAttr(const string &attr_name, + const domi::tensorflow::AttrValue &value, + domi::tensorflow::NodeDef *node_def); +}; + +class NameMapHelper { + public: + NameMapHelper() = default; + + ~NameMapHelper() {} + + string UniqueInputOrOutputName(const string &name); + + string UniqueNodeName(const string &name); + + string Renormalize(const string &name) const; + + private: + string GetUniqueName(const string &name); + + std::set used_names_; + std::unordered_map name_mapping_; +}; +} // namespace ge + +#endif // GE_GRAPH_OPTIMIZE_GRAPH_FUNCTIONDEF_H diff --git a/parser/tensorflow/graph_insert_trans_op.h b/parser/tensorflow/graph_insert_trans_op.h new file mode 100644 index 0000000..abd6c2a --- /dev/null +++ b/parser/tensorflow/graph_insert_trans_op.h @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ +#define GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ +#include +#include +#include +#include "common/fmk_types.h" +#include "common/op/ge_op_utils.h" +#include "framework/omg/parser/parser_types.h" +#include "graph/compute_graph.h" +#include "graph/node.h" +#include "graph/types.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "register/op_registry.h" + +namespace ge { +enum InFmtSupportEnum { + InFmtSupportUndefined, + InFmtSupportElewise, + InFmtSupport4D, + InFmtSupport5D, + InFmtSupport4D_5D, + InFmtSupportNCHW_NC1HWC0 +}; + +enum InDtSupportEnum { + InDtSupportUndefined = 0, + InDtSupportAll = 1, +}; + +enum OutFmtSupportEnum { + OutFmtSupportUndefined = 0, + OutFmtSupportAsInput = 1, +}; + +enum OutDtSupportEnum { + OutDtSupportUndefined = 0, + OutDtSupportAsInput = 1, +}; + +struct OpSupportTranInfo { + InFmtSupportEnum inputFormatSupportEnum = InFmtSupportUndefined; + InDtSupportEnum inputDataTypeSupportEnum = InDtSupportUndefined; + OutFmtSupportEnum outputFormatSupportEnum = OutFmtSupportUndefined; + OutDtSupportEnum outputDataTypeSupportEnum = OutDtSupportUndefined; + + std::vector inputFormats; + std::vector inputDataTypes; + ge::Format limitOutputFormat = ge::FORMAT_RESERVED; + ge::DataType limitOutputDataType = ge::DT_UNDEFINED; +}; + +extern std::map g_OpSupportTranInfo; + +class OpTransAddSupportReg { + public: + template + OpTransAddSupportReg(const std::string &cceTbeTg, const std::string &opType, + InFmts inputFormats, InDts inputDataTypes, + OutFmts outputormat, OutDts outputDataType) { + auto cceTbeOpType = cceTbeTg + ":" + opType; + g_OpSupportTranInfo.erase(cceTbeOpType); + SetInputFormat(cceTbeOpType, inputFormats); + SetInputDataType(cceTbeOpType, inputDataTypes); + SetOutputFormat(cceTbeOpType, outputormat); + SetOutputDataType(cceTbeOpType, outputDataType); + } + ~OpTransAddSupportReg() = default; + + private: + void SetInputFormat(std::string opType, + const std::vector& supportFormat) { + auto& opInfo = g_OpSupportTranInfo[opType]; + for (auto& format : supportFormat) { + opInfo.inputFormats.push_back(format); + } + } + + void SetInputFormat(std::string opType, ge::Format supportFormat) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.inputFormats.push_back(supportFormat); + } + + void SetInputFormat(std::string opType, InFmtSupportEnum enumFormat) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.inputFormatSupportEnum = enumFormat; + switch (enumFormat) { + case InFmtSupportElewise: + opInfo.inputFormats = {ge::FORMAT_FRACTAL_Z, ge::FORMAT_HWCN, + ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, + ge::FORMAT_NCHW}; + break; + case InFmtSupport4D: + opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, + ge::FORMAT_NCHW}; + break; + case InFmtSupport5D: + opInfo.inputFormats = {ge::FORMAT_NC1HWC0}; + break; + case InFmtSupport4D_5D: + opInfo.inputFormats = {ge::FORMAT_HWCN, ge::FORMAT_NHWC, + ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0}; + break; + case InFmtSupportNCHW_NC1HWC0: + opInfo.inputFormats = {ge::FORMAT_NC1HWC0, ge::FORMAT_NCHW}; + break; + default: + break; + } + } + + void SetInputDataType(std::string opType, + const std::vector& supportDataType) { + auto& opInfo = g_OpSupportTranInfo[opType]; + for (auto& dataType : supportDataType) { + opInfo.inputDataTypes.push_back(dataType); + } + } + + void SetInputDataType(std::string opType, ge::DataType supportDataType) { + auto& opInfo = g_OpSupportTranInfo[opType]; + 
opInfo.inputDataTypes.push_back(supportDataType); + } + + void SetInputDataType(std::string opType, InDtSupportEnum enumDataType) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.inputDataTypeSupportEnum = enumDataType; + } + + void SetOutputFormat(std::string opType, ge::Format limitOutputormat) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.limitOutputFormat = limitOutputormat; + } + + void SetOutputFormat(std::string opType, OutFmtSupportEnum enumFormat) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.outputFormatSupportEnum = enumFormat; + } + + void SetOutputDataType(std::string opType, ge::DataType limitOutputDataType) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.limitOutputDataType = limitOutputDataType; + } + + void SetOutputDataType(std::string opType, OutDtSupportEnum enumDataType) { + auto& opInfo = g_OpSupportTranInfo[opType]; + opInfo.outputDataTypeSupportEnum = enumDataType; + } +}; + +#define TBE_SET_FORMAT_DATAYPE_INFO(cce_tbe, op, inputFormats, inputDataType, \ + outFormats, outputDataTypes) \ + TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(__COUNTER__, #cce_tbe, op, \ + inputFormats, inputDataType, \ + outFormats, outputDataTypes) +#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ_HELPER(ctr, cce_tbe, op, \ + inputFormats, inputDataType, \ + outFormats, outputDataTypes) \ + TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ + inputDataType, outFormats, outputDataTypes) +#define TBE_SET_FORMAT_DATAYPE_INFO_UNIQ(ctr, cce_tbe, op, inputFormats, \ + inputDataType, outFormats, \ + outputDataTypes) \ + OpTransAddSupportReg __gOpTransAddSupportReg##ctr( \ + cce_tbe, op, inputFormats, inputDataType, outFormats, outputDataTypes); +} // namespace domi +#endif // GE_GRAPH_OPTIMIZE_GRAPH_INSERT_TRANS_OP_H_ diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc new file mode 100644 index 0000000..7988b67 --- /dev/null +++ b/parser/tensorflow/graph_optimizer.cc @@ -0,0 +1,2003 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph_optimizer.h" +#include +#include +#include +#include +#include "./graph_insert_trans_op.h" +#include "cce/cce.h" +#include "cce/dnn.h" +#include "common/debug/log.h" +#include "common/math/math_util.h" +#include "common/op/ge_op_utils.h" +#include "common/op_map.h" +#include "common/types_map.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "framework/omg/parser/parser_types.h" +#include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_tensor.h" +#include "graph/types.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include "graph_functiondef.h" +#include "parser/common/acl_graph_parser_util.h" +#include "proto/tensorflow/attr_value.pb.h" +#include "register/op_registry.h" + +using domi::tensorflow::NodeDef; +using domi::tensorflow::TensorProto; +using domi::tensorflow::TensorShapeProto; +using domi::tensorflow::TensorShapeProto_Dim; + +using ge::FORMAT_NC1HWC0; +using ge::FORMAT_NCHW; +using ge::FORMAT_NHWC; + +using ge::AttrUtils; +using ge::Buffer; +using ge::ComputeGraph; +using ge::ComputeGraphPtr; +using ge::GE_TENSORFLOW_DATA_TYPE_MAP; +using ge::GeShape; +using ge::GeTensorDesc; +using ge::GeTensorPtr; +using ge::GRAPH_SUCCESS; +using ge::GraphToFunctionDef; +using ge::GraphUtils; +using ge::InControlAnchorPtr; +using ge::InDataAnchorPtr; +using ge::is_dataset_op_vec; +using ge::local_framework_op_vec; +using ge::NodePtr; +using ge::OpDesc; +using ge::OpDescPtr; +using ge::OpUtils; +using ge::OutControlAnchorPtr; +using ge::OutDataAnchorPtr; +using ge::TensorUtils; + +using ge::ATTR_NAME_INPUT_DATATYPE; +using ge::ATTR_NAME_OUTPUT_DATATYPE; + +namespace ge { +REGISTER_OPTYPE_DEFINE(TF_MAXIMUM_GRAD, "MaximumGrad"); +REGISTER_OPTYPE_DEFINE(TF_MATMUL, "Matmul"); +REGISTER_OPTYPE_DEFINE(TFRELU6, "Relu6"); +REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul"); +} // namespace ge + +namespace ge { +namespace { +const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; +} // namespace + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map g_OpSupportTranInfo = {}; + +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::MUL, + std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, + ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), + InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::L2LOSS, + std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, + ge::FORMAT_HWCN}), // inputformats + ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) + +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONVGRADFILTER, 
InFmtSupportUndefined, InDtSupportUndefined, + ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, + ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, + ge::DT_FLOAT) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) + +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, + ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, + ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, + ge::DT_FLOAT16) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) + +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, + ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, + OutFmtSupportAsInput, OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, + OutFmtSupportAsInput, OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, + OutFmtSupportAsInput, OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, + OutFmtSupportAsInput, OutDtSupportUndefined) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, + OutDtSupportUndefined) + +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, + OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::APPLYRMSPROP, + std::vector({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, + ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), + ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, + OutDtSupportAsInput) 
+TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, + OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::STREAMMERGE, + std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), + InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) +TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::MEMCPYASYNC, + std::vector({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), + InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) + +bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) { + static bool fmtInited = false; + GE_IF_BOOL_EXEC( + !fmtInited, fmtInited = true; + if (domi::OpRegistry().Instance()->GetImplyType(ge::parser::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { + auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::parser::MUL); + if (it != g_OpSupportTranInfo.end()) { + auto &fmts = it->second.inputFormats; + auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); + fmts.erase(itFmt); + } + }) + string cceTbeOpType = "TBE"; + GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN, + cceTbeOpType = "CCE";) + 
cceTbeOpType = cceTbeOpType + ":" + opType; + GE_IF_BOOL_EXEC(g_OpSupportTranInfo.find(cceTbeOpType) != g_OpSupportTranInfo.end(), + opSupportInfo = g_OpSupportTranInfo[cceTbeOpType]; + return true;) + return false; +} + +Status ParserGraphOptimizer::Optimize() { return SUCCESS; } + +Status ParserGraphOptimizer::OptimizeAfterCal() { return SUCCESS; } + +void SetStringAttr(const string &originalType, OpDescPtr &opDesc, + google::protobuf::Map *tfAttr, + const pair &attr) { + string s; + (void)AttrUtils::GetStr(opDesc, attr.first, s); + + if (originalType == "ParallelMapDataset" || originalType == "FilterDataset" || + originalType == "MapAndBatchDatasetV2") { + ::domi::tensorflow::NameAttrList *nameAttrList = (*tfAttr)[attr.first].mutable_func(); + nameAttrList->set_name(s); + } else { + (*tfAttr)[attr.first].set_s(s); + } +} + +void SetIntAttr(const string &originalType, OpDescPtr &opDesc, + google::protobuf::Map *tfAttr, + const pair &attr) { + int32_t i = 0; + (void)AttrUtils::GetInt(opDesc, attr.first, i); + + if (originalType == "Pack" && (attr.first == "axis" || attr.first == "N")) { + (*tfAttr)[attr.first].set_i(i); + } else if (originalType == "TruncatedNormal" && (attr.first == "seed" || attr.first == "seed2")) { + (*tfAttr)[attr.first].set_i(i); + } else { + (*tfAttr)[attr.first].set_type((domi::tensorflow::DataType)i); + } +} + +void SetSqueezeDims(const string &originalType, google::protobuf::Map *tfAttr, + const pair &attr, const vector &intList, + const domi::tensorflow::AttrValue &attrValue, domi::tensorflow::AttrValue_ListValue *list) { + if (originalType == "Squeeze" && (attr.first == "squeeze_dims")) { + for (auto i : intList) { + list->add_i(i); + } + (*tfAttr)[attr.first] = attrValue; + } +} + +void SetListIntAttr(const string &originalType, OpDescPtr &opDesc, + google::protobuf::Map *tfAttr, + const pair &attr) { + vector intList; + (void)AttrUtils::GetListInt(opDesc, attr.first, intList); + + domi::tensorflow::AttrValue attrValue; + domi::tensorflow::AttrValue_ListValue *list = attrValue.mutable_list(); + + vector::iterator iter = std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), originalType); + if (iter != is_dataset_op_vec.end()) { + if (attr.first == "Toutput_types" || attr.first == "output_types") { + for (auto i : intList) { + list->add_type((domi::tensorflow::DataType)i); + } + (*tfAttr)[attr.first] = attrValue; + } else if ((originalType == "ParallelMapDataset" || originalType == "FilterDataset" || + originalType == "MapAndBatchDatasetV2") && + attr.first == "Targuments") { + (*tfAttr)[attr.first] = attrValue; + } + } else { + SetSqueezeDims(originalType, tfAttr, attr, intList, attrValue, list); + } +} + +void SetListListIntAttr(const string &originalType, OpDescPtr &opDesc, + google::protobuf::Map *tfAttr, + const pair &attr) { + vector> intListList; + (void)AttrUtils::GetListListInt(opDesc, attr.first, intListList); + + domi::tensorflow::AttrValue attrValue; + domi::tensorflow::AttrValue_ListValue *list = attrValue.mutable_list(); + + if ((originalType == "IteratorV2" || originalType == "BatchDatasetV2" || originalType == "IteratorGetNext" || + originalType == "ParallelMapDataset" || originalType == "DeviceQueueDataset" || originalType == "QueueDataset" || + originalType == "FilterDataset" || originalType == "MapAndBatchDatasetV2") && + attr.first == "output_shapes") { + for (size_t ill = 0; ill < intListList.size(); ill++) { + TensorShapeProto *tensorShape = list->add_shape(); + auto intList_ = intListList[ill]; + for (auto i : intList_) { + 
TensorShapeProto_Dim *dim = tensorShape->add_dim(); + dim->set_size(i); + } + } + (*tfAttr)[attr.first] = attrValue; + } else if (originalType == "TensorDataset" && attr.first == "output_shapes") { + domi::tensorflow::TensorShapeProto *tensorShape = list->add_shape(); + domi::tensorflow::TensorShapeProto_Dim *dim = tensorShape->add_dim(); + dim->set_size(0); + (*tfAttr)[attr.first] = attrValue; + } +} + +void SetTensorValue(const ge::ConstGeTensorPtr &geTensor, domi::tensorflow::TensorProto *tfTensor, int32_t dataCount) { + if (dataCount > 1) { + tfTensor->set_tensor_content(geTensor->GetData().data(), geTensor->GetData().size()); + } else { + switch (geTensor->GetTensorDesc().GetDataType()) { + case ge::DT_FLOAT: { + float f = *(reinterpret_cast(geTensor->GetData().data())); + tfTensor->add_float_val(f); + break; + } + + case ge::DT_INT32: { + int32_t i = *(reinterpret_cast(geTensor->GetData().data())); + tfTensor->add_int_val(i); + break; + } + + case ge::DT_BOOL: { + bool b = *(reinterpret_cast(geTensor->GetData().data())); + tfTensor->add_bool_val(b); + break; + } + + case ge::DT_INT64: { + int64_t i = *(reinterpret_cast(geTensor->GetData().data())); + tfTensor->add_int64_val(i); + break; + } + + case ge::DT_FLOAT16: { + int32_t f = *(reinterpret_cast(geTensor->GetData().data())); + tfTensor->add_half_val(f); + break; + } + + default: { + GELOGW("SetTensorValue not support the data type %s.", + ge::TypeUtils::DataTypeToSerialString(geTensor->GetTensorDesc().GetDataType()).c_str()); + } + } + } +} + +Status SetTensorAttr(ge::OpDescPtr &opDesc, google::protobuf::Map *tfAttr, + const pair &attr) { + ge::ConstGeTensorPtr ge_tensor; + (void)ge::AttrUtils::GetTensor(opDesc, attr.first, ge_tensor); + + domi::tensorflow::TensorProto *tf_tensor = (*tfAttr)[attr.first].mutable_tensor(); + + // Set datatype + domi::tensorflow::DataType datatype; + auto ge_datatype = ge_tensor->GetTensorDesc().GetDataType(); + int32_t data_count = 1; + switch (ge_datatype) { + case ge::DataType::DT_FLOAT: + datatype = domi::tensorflow::DataType::DT_FLOAT; + data_count = ge_tensor->GetData().size() / sizeof(float); + break; + case ge::DataType::DT_FLOAT16: + datatype = domi::tensorflow::DataType::DT_HALF; + data_count = ge_tensor->GetData().size() / sizeof(int16_t); + break; + case ge::DataType::DT_INT32: + datatype = domi::tensorflow::DataType::DT_INT32; + data_count = ge_tensor->GetData().size() / sizeof(int32_t); + break; + case ge::DataType::DT_INT64: + datatype = domi::tensorflow::DataType::DT_INT64; + data_count = ge_tensor->GetData().size() / sizeof(int64_t); + break; + case ge::DataType::DT_UINT8: + datatype = domi::tensorflow::DataType::DT_UINT8; + data_count = ge_tensor->GetData().size() / sizeof(uint8_t); + break; + case ge::DataType::DT_BOOL: + datatype = domi::tensorflow::DataType::DT_BOOL; + data_count = ge_tensor->GetData().size() / sizeof(bool); + break; + default: + GELOGE(PARAM_INVALID, "NO SUPPORT datatype = %s", ge::TypeUtils::DataTypeToSerialString(ge_datatype).c_str()); + return PARAM_INVALID; + } + tf_tensor->set_dtype(datatype); + + domi::tensorflow::TensorShapeProto *tf_shape = tf_tensor->mutable_tensor_shape(); + for (auto dim : ge_tensor->GetTensorDesc().GetShape().GetDims()) { + domi::tensorflow::TensorShapeProto_Dim *tf_dims = tf_shape->add_dim(); + tf_dims->set_size(dim); + } + + SetTensorValue(ge_tensor, tf_tensor, data_count); + return SUCCESS; +} + +Status SetNodedefProto(domi::tensorflow::NodeDef &proto, ge::NodePtr n, string original_type) { + GELOGI("graph_optimizer.cpp && 
SetNodedefProto"); + // Set proto head + Status ret; + auto op_desc = n->GetOpDesc(); + GELOGI("n->GetName =%s, original_type =%s", n->GetName().c_str(), original_type.c_str()); + proto.set_name(n->GetName()); + proto.set_op(original_type); + + // Set input + auto input_names = op_desc->GetInputName(); + + for (auto anchor : n->GetAllInDataAnchors()) { + GE_IF_BOOL_EXEC(anchor == nullptr || anchor->GetPeerOutAnchor() == nullptr || + anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr || + anchor->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc() == nullptr, + continue); + OutDataAnchorPtr src_anchor = anchor->GetPeerOutAnchor(); + NodePtr src_node = anchor->GetPeerOutAnchor()->GetOwnerNode(); + OpDescPtr src_opdesc = src_node->GetOpDesc(); + GELOGI("inedge src:%s, src_out_index:%d, dst:%s, dst_in_index:%d", src_opdesc->GetName().c_str(), + src_anchor->GetIdx(), op_desc->GetName().c_str(), anchor->GetIdx()); + string inputName; + inputName = src_opdesc->GetName() + ":" + "output:" + std::to_string(src_anchor->GetIdx()); + GELOGI("inputName =%s\n", inputName.c_str()); + proto.add_input(inputName); + } + + // Set device + proto.set_device("CPU"); + + // Set proto attr + google::protobuf::Map *tf_attr = proto.mutable_attr(); + map allattrs = op_desc->GetAllAttrs(); + allattrs.erase(ge::ATTR_NAME_FRAMEWORK_FWK_TYPE); + allattrs.erase(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE); + if (original_type == "Add") { + allattrs.erase(ge::ATTR_NAME_MODE); + } else if (original_type == "IteratorGetNext") { + allattrs.erase("output_num"); + } + + for (const auto &attr : allattrs) { + ge::GeAttrValue::ValueType v_t = attr.second.GetValueType(); + switch (v_t) { + case ge::GeAttrValue::ValueType::VT_STRING: { + SetStringAttr(original_type, op_desc, tf_attr, attr); + + break; + } + + case ge::GeAttrValue::ValueType::VT_INT: { + SetIntAttr(original_type, op_desc, tf_attr, attr); + + break; + } + case ge::GeAttrValue::ValueType::VT_BOOL: { + bool i = false; + (void)ge::AttrUtils::GetBool(op_desc, attr.first, i); + (*tf_attr)[attr.first].set_b(i); + break; + } + case ge::GeAttrValue::ValueType::VT_LIST_INT: { + SetListIntAttr(original_type, op_desc, tf_attr, attr); + + break; + } + case ge::GeAttrValue::ValueType::VT_LIST_LIST_INT: { + SetListListIntAttr(original_type, op_desc, tf_attr, attr); + + break; + } + case ge::GeAttrValue::ValueType::VT_TENSOR: { + ret = SetTensorAttr(op_desc, tf_attr, attr); + GE_IF_BOOL_EXEC(ret != SUCCESS, return ret); + break; + } + default: + break; + } + } + + return SUCCESS; +} + +typedef Status (*PIOListHandle)(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc); + +Status GatherV2IOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int tparams; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tparams", tparams)), return PARAM_INVALID, "Get Tparams error."); + int tindices; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tindices", tindices)), return PARAM_INVALID, "Get Tindices error."); + int taxis; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Taxis", taxis)), return PARAM_INVALID, "Get Taxis error."); + + // input_list - eg:{1, 3, 3} + input_list.push_back(tparams); + input_list.push_back(tindices); + input_list.push_back(taxis); + // output_list - eg:{3} + output_list.push_back(tparams); + + return SUCCESS; +} + +Status ConstIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int dtype; + 
GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), return PARAM_INVALID, "Get dtype error."); + // output_list - {3} + output_list.push_back(dtype); + + return SUCCESS; +} + +Status MaxMinIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int attrT; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", attrT)), return PARAM_INVALID, "Get Tparams error."); + + // input_list + input_list.push_back(attrT); + input_list.push_back(attrT); + + // output_list + output_list.push_back(attrT); + + return SUCCESS; +} + +Status CastIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int srcT; + int dstT; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "SrcT", srcT)), return PARAM_INVALID, "Get srcT error."); + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "DstT", dstT)), return PARAM_INVALID, "Get dstT error."); + input_list.push_back(srcT); + output_list.push_back(dstT); + + return SUCCESS; +} + +Status AddIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) { + int type; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", type)), return PARAM_INVALID, "Get T error."); + + input_list.push_back(type); + input_list.push_back(type); + + output_list.push_back(type); + + return SUCCESS; +} + +Status LessIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int dtype; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dtype)), return PARAM_INVALID, "Get dtype error."); + + input_list.push_back(dtype); + input_list.push_back(dtype); + output_list.push_back(domi::tensorflow::DataType::DT_BOOL); + + return SUCCESS; +} + +Status MulIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, ge::OpDescPtr &opDesc) { + int dataType; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), return PARAM_INVALID, + "Get Tparams error."); + + input_list.push_back(dataType); + input_list.push_back(dataType); + + output_list.push_back(dataType); + + return SUCCESS; +} + +Status RealDivIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int t; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), return PARAM_INVALID, "Get beta error."); + + input_list.push_back(t); + input_list.push_back(t); + + output_list.push_back(t); + + return SUCCESS; +} + +Status SelectIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int t; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), return PARAM_INVALID, "Get e error."); + + input_list.push_back(domi::tensorflow::DataType::DT_BOOL); + input_list.push_back(t); + input_list.push_back(t); + + output_list.push_back(t); + + return SUCCESS; +} + +Status SqrtIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int dataType; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, ge::ATTR_NAME_T, dataType)), return PARAM_INVALID, + "Get Tparam error."); + + input_list.push_back(dataType); + + output_list.push_back(dataType); + + return SUCCESS; +} + +Status TruncatedNormalIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int t; + int dtype; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), return PARAM_INVALID, 
"Get T error."); + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "dtype", dtype)), return PARAM_INVALID, "Get e error."); + + input_list.push_back(t); + + output_list.push_back(dtype); + + return SUCCESS; +} + +Status PackIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int t; + int n; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), return PARAM_INVALID, "Get T error."); + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "N", n)), return PARAM_INVALID, "Get N error."); + + for (int i = 0; i < n; i++) { + input_list.push_back(t); + } + + output_list.push_back(t); + + return SUCCESS; +} + +Status DropOutGenMaskIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + input_list.push_back(domi::tensorflow::DT_INT64); + input_list.push_back(domi::tensorflow::DT_FLOAT); + output_list.push_back(domi::tensorflow::DT_UINT8); + + return SUCCESS; +} + +Status ExpandDimsIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + int dataType; + int dimType; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dataType)), return PARAM_INVALID, "Get T error."); + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "Tdim", dimType)), return PARAM_INVALID, "Get Tdim error."); + // input_list - x y data type + input_list.push_back(dataType); + input_list.push_back(dimType); + // output_list - z data type + output_list.push_back(dataType); + + return SUCCESS; +} + +Status SqueezeIOList(ge::GeAttrValue::LIST_INT &input_list, ge::GeAttrValue::LIST_INT &output_list, + ge::OpDescPtr &opDesc) { + // Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE + int dataType; + vector dimTypeList; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", dataType)), return PARAM_INVALID, "Get T error."); + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetListInt(opDesc, "squeeze_dims", dimTypeList)), return PARAM_INVALID, + "Get squeeze_dims error."); + for (auto i : dimTypeList) { + GELOGI("squeeze_dims = %d.\n", i); + } + + // input_list - x y data type + input_list.push_back(dataType); + // output_list - z data type + output_list.push_back(dataType); + + return SUCCESS; +} + +Status TopKV2IOList(ge::GeAttrValue::LIST_INT &inputList, ge::GeAttrValue::LIST_INT &outputList, + ge::OpDescPtr &opDesc) { + int t; + GE_CHK_BOOL_EXEC((ge::AttrUtils::GetInt(opDesc, "T", t)), return PARAM_INVALID, "Get T error."); + + // input_list - eg:{1, 3} + inputList.push_back(t); + inputList.push_back(domi::tensorflow::DataType::DT_INT32); + + // output_list - eg:{1, 3} + outputList.push_back(t); + outputList.push_back(domi::tensorflow::DataType::DT_INT32); + + return SUCCESS; +} + +void CreateIOListFuncMap(map &mOpIOListFuncMap) { + mOpIOListFuncMap.insert({"GatherV2", GatherV2IOList}); + mOpIOListFuncMap.insert({"Const", ConstIOList}); + mOpIOListFuncMap.insert({"Maximum", MaxMinIOList}); + mOpIOListFuncMap.insert({"Minimum", MaxMinIOList}); + mOpIOListFuncMap.insert({"Cast", CastIOList}); + mOpIOListFuncMap.insert({"Add", AddIOList}); + mOpIOListFuncMap.insert({"Less", LessIOList}); + mOpIOListFuncMap.insert({"Mul", MulIOList}); + mOpIOListFuncMap.insert({"RealDiv", RealDivIOList}); + mOpIOListFuncMap.insert({"Select", SelectIOList}); + mOpIOListFuncMap.insert({"TruncatedNormal", TruncatedNormalIOList}); + mOpIOListFuncMap.insert({"Pack", PackIOList}); + mOpIOListFuncMap.insert({"DropOutGenMask", DropOutGenMaskIOList}); + mOpIOListFuncMap.insert({"ExpandDims", 
ExpandDimsIOList}); + mOpIOListFuncMap.insert({"Squeeze", SqueezeIOList}); + mOpIOListFuncMap.insert({"TopKV2", TopKV2IOList}); +} + +Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map &mOpIOListFuncMap) { + Status ret; + auto opDesc = n->GetOpDesc(); + GELOGI("n->GetName() = %s.\n", n->GetName().c_str()); + // Set - NodeDef PROTO + domi::tensorflow::NodeDef proto; + ge::GeAttrValue::LIST_INT inputList; + ge::GeAttrValue::LIST_INT outputList; + ret = SetNodedefProto(proto, n, originalType); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "SetNodedefProto failed."); + + // Set inputList & outputList + // Set - TENSORFLOW_IN_DATATYPE/TENSORFLOW_OUT_DATATYPE + PIOListHandle funcPtr = nullptr; + map::iterator it = mOpIOListFuncMap.find(originalType); + if (it != mOpIOListFuncMap.end()) { + funcPtr = it->second; + } + + if (funcPtr != nullptr) { + ret = ((PIOListHandle)funcPtr)(inputList, outputList, opDesc); + if (ret != SUCCESS) { + return ret; + } + } + + vector::iterator iter = std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), originalType); + if (iter == is_dataset_op_vec.end()) { + (void)ge::AttrUtils::SetListInt(opDesc, ge::T_IN_DATATYPE, inputList); + (void)ge::AttrUtils::SetListInt(opDesc, ge::T_OUT_DATATYPE, outputList); + } + + // Set size + for (auto ge_desc : opDesc->GetAllOutputsDescPtr()) { + int64_t real_size = 1; + int64_t tmp_dim = 0; + auto data_type = ge_desc->GetDataType(); + + uint32_t size_type = 1; + bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type); + GE_IF_BOOL_EXEC(!type_ret, GELOGE(PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s", + ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return PARAM_INVALID); + + // calculate size + for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) { + tmp_dim = ge_desc->GetShape().GetDim(j); + GE_CHECK_GE(tmp_dim, 0); + FMK_INT64_MULCHECK(real_size, tmp_dim); + real_size *= tmp_dim; + } + ge::TensorUtils::SetSize(*ge_desc, real_size * size_type); + ge::TensorUtils::SetRealDimCnt(*ge_desc, ge_desc->GetShape().GetDimNum()); + } + + // Serial - nodedef proto + string nodefStr; + GE_IF_BOOL_EXEC(!proto.SerializeToString(&nodefStr), GELOGE(PARAM_INVALID, "Serialize nodedef to string failed."); + return PARAM_INVALID); + + // Set - ATTR_NAME_FRAMEWORK_NODE_DEF + ge::GeAttrValue::BYTES nodeDefBytes; + (void)ge::AttrUtils::SetZeroCopyBytes( + opDesc, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + nodeDefBytes.CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + + // print proto + string nodefstr; + google::protobuf::TextFormat::PrintToString(proto, &nodefstr); + GELOGI("---> ! 
CreateNodeDefBytes() nodefstr : %s", nodefstr.c_str()); + return SUCCESS; +} + +Status CreateOpDefBytes(ge::NodePtr n, string original_type) { + auto opDesc = n->GetOpDesc(); + GELOGI("n->GetName() =%s, original_type =%s", n->GetName().c_str(), original_type.c_str()); + + // Set - OpDef PROTO + domi::tensorflow::OpDef proto; + proto.set_name(original_type); + + if (original_type == "Const") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("output"); + outputArgdef->set_type_attr("dtype"); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("value"); + attr1->set_type("tensor"); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("dtype"); + attr2->set_type("type"); + } else if (original_type == "TensorDataset") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef = proto.add_input_arg(); + inputArgdef->set_name("components"); + inputArgdef->set_type_list_attr("Toutput_types"); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("Toutput_types"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + attr1->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_shapes"); + attr2->set_type("list(shape)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + // Set stateful + proto.set_is_stateful(true); + } else if (original_type == "QueueDataset") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef = proto.add_input_arg(); + inputArgdef->set_name("input_dataset"); + inputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("sourcedata"); + attr1->set_type("string"); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_types"); + attr2->set_type("list(type)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("output_shapes"); + attr3->set_type("list(shape)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + + // Set stateful + proto.set_is_stateful(true); + } else if (original_type == "DeviceQueueDataset") { + // Set output_arg + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("channel_name"); + attr1->set_type("string"); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_types"); + attr2->set_type("list(type)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("output_shapes"); + attr3->set_type("list(shape)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + + // Set stateful + 
proto.set_is_stateful(true); + } else if (original_type == "ParallelMapDataset") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); + inputArgdef1->set_name("input_dataset"); + inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); + inputArgdef2->set_name("other_arguments"); + inputArgdef2->set_type_list_attr("Targuments"); + + domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); + inputArgdef3->set_name("num_parallel_calls"); + inputArgdef3->set_type(::domi::tensorflow::DataType::DT_INT32); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); + attr0->set_name("f"); + attr0->set_type("func"); + + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("Targuments"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_types"); + attr2->set_type("list(type)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("output_shapes"); + attr3->set_type("list(shape)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr4 = proto.add_attr(); + attr4->set_name("use_iter_op_parallelism"); + attr4->set_type("bool"); + ::domi::tensorflow::AttrValue *default_value = attr4->mutable_default_value(); + default_value->set_b(true); + } else if (original_type == "BatchDatasetV2") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); + inputArgdef1->set_name("input_dataset"); + inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); + inputArgdef2->set_name("batch_size"); + inputArgdef2->set_type(::domi::tensorflow::DataType::DT_INT64); + + domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); + inputArgdef3->set_name("drop_remainder"); + inputArgdef3->set_type(::domi::tensorflow::DataType::DT_BOOL); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("output_types"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + attr1->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_shapes"); + attr2->set_type("list(shape)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + } else if (original_type == "IteratorV2") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_RESOURCE); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("shared_name"); + attr1->set_type("string"); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("container"); + attr2->set_type("string"); + + domi::tensorflow::OpDef_AttrDef *attr3 = 
proto.add_attr(); + attr3->set_name("output_types"); + attr3->set_type("list(type)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr4 = proto.add_attr(); + attr4->set_name("output_shapes"); + attr4->set_type("list(shape)"); + attr4->set_has_minimum(true); + attr4->set_minimum(1); + + // Set stateful + proto.set_is_stateful(true); + } else if (original_type == "MakeIterator") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); + inputArgdef1->set_name("dataset"); + inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); + inputArgdef2->set_name("iterator"); + inputArgdef2->set_type(::domi::tensorflow::DataType::DT_RESOURCE); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("_kernel"); + attr1->set_type("dp"); + + // Set stateful + proto.set_is_stateful(true); + } else if (original_type == "IteratorGetNext") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *input_argdef_1 = proto.add_input_arg(); + input_argdef_1->set_name("iterator"); + input_argdef_1->set_type(::domi::tensorflow::DataType::DT_RESOURCE); + + domi::tensorflow::OpDef::ArgDef *output_argdef = proto.add_output_arg(); + output_argdef->set_name("components"); + output_argdef->set_type_list_attr("output_types"); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("output_types"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + attr1->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_shapes"); + attr2->set_type("list(shape)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("_kernel"); + attr3->set_type("dp"); + + // Set stateful + proto.set_is_stateful(true); + } else if (original_type == "FilterDataset") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); + inputArgdef1->set_name("input_dataset"); + inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); + inputArgdef2->set_name("other_arguments"); + inputArgdef2->set_type_list_attr("Targuments"); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); + attr0->set_name("predicate"); + attr0->set_type("func"); + + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("Targuments"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_types"); + attr2->set_type("list(type)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("output_shapes"); + attr3->set_type("list(shape)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + } else if (original_type == "MapAndBatchDatasetV2") { + // Set input_arg & output_arg + domi::tensorflow::OpDef::ArgDef *inputArgdef1 = proto.add_input_arg(); + inputArgdef1->set_name("input_dataset"); + 
inputArgdef1->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + domi::tensorflow::OpDef::ArgDef *inputArgdef2 = proto.add_input_arg(); + inputArgdef2->set_name("other_arguments"); + inputArgdef2->set_type_list_attr("Targuments"); + + domi::tensorflow::OpDef::ArgDef *inputArgdef3 = proto.add_input_arg(); + inputArgdef3->set_name("batch_size"); + inputArgdef3->set_type(::domi::tensorflow::DataType::DT_INT64); + + domi::tensorflow::OpDef::ArgDef *inputArgdef4 = proto.add_input_arg(); + inputArgdef4->set_name("num_parallel_calls"); + inputArgdef4->set_type(::domi::tensorflow::DataType::DT_INT64); + + domi::tensorflow::OpDef::ArgDef *inputArgdef5 = proto.add_input_arg(); + inputArgdef5->set_name("drop_remainder"); + inputArgdef5->set_type(::domi::tensorflow::DataType::DT_BOOL); + + domi::tensorflow::OpDef::ArgDef *outputArgdef = proto.add_output_arg(); + outputArgdef->set_name("handle"); + outputArgdef->set_type(::domi::tensorflow::DataType::DT_VARIANT); + + // Set domi::AttrDef + domi::tensorflow::OpDef_AttrDef *attr0 = proto.add_attr(); + attr0->set_name("f"); + attr0->set_type("func"); + + domi::tensorflow::OpDef_AttrDef *attr1 = proto.add_attr(); + attr1->set_name("Targuments"); + attr1->set_type("list(type)"); + attr1->set_has_minimum(true); + + domi::tensorflow::OpDef_AttrDef *attr2 = proto.add_attr(); + attr2->set_name("output_types"); + attr2->set_type("list(type)"); + attr2->set_has_minimum(true); + attr2->set_minimum(1); + + domi::tensorflow::OpDef_AttrDef *attr3 = proto.add_attr(); + attr3->set_name("output_shapes"); + attr3->set_type("list(shape)"); + attr3->set_has_minimum(true); + attr3->set_minimum(1); + } + // set - opdef + string opdefString; + GE_IF_BOOL_EXEC(!proto.SerializeToString(&opdefString), GELOGE(PARAM_INVALID, "Serialize opdef to string failed."); + return PARAM_INVALID); + + (void)ge::AttrUtils::SetStr(opDesc, ge::ATTR_NAME_FRAMEWORK_OP_DEF, opdefString); + + // print proto + string opdefstr; + google::protobuf::TextFormat::PrintToString(proto, &opdefstr); + GELOGI("---> ! 
CreateOpDefBytes() opdefstr :\n"); + GELOGI("%s", opdefstr.c_str()); + return SUCCESS; +} + +Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_path) { + GELOGI("func_bin_path = %s", func_bin_path.c_str()); + auto opDesc = n->GetOpDesc(); + + std::string func_string; + if (original_type == "ParallelMapDataset" || original_type == "MapAndBatchDatasetV2") { + GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), "f", func_string) != true, "get func string failed."); + } else if (original_type == "FilterDataset") { + GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), "predicate", func_string) != true, "get func string failed."); + } + GELOGI("func_string = %s", func_string.c_str()); + + std::string file = func_bin_path + "/" + func_string + ".bin"; + GELOGI("file = %s", file.c_str()); + + char *buf = nullptr; + int32_t len = 0; + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::parser::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false, + "read bytes file error!"); + + GELOGI("len =%d\n", len); + + ge::GeAttrValue::BYTES funcDefBytes; + funcDefBytes = ge::Buffer::CopyFrom((std::uint8_t *)buf, len); + (void)ge::AttrUtils::SetBytes(opDesc, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes); + GELOGI("funcDefBytes.GetSize() =%zu", funcDefBytes.GetSize()); + + // print proto + if (funcDefBytes.GetSize() > 0 && funcDefBytes.GetData() != nullptr) { + domi::tensorflow::FunctionDefLibrary funcdeflib; + (void)funcdeflib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()); + + string funcdeflibrarystr; + google::protobuf::TextFormat::PrintToString(funcdeflib, &funcdeflibrarystr); + GELOGI("---> !CreateFuncDefBytes() funcdeflibrarystr :"); + } + + delete[] buf; + return SUCCESS; +} + +Status ParserGraphOptimizer::MakeTfProtoDef() { + GE_CHK_STATUS_RET(graph_->TopologicalSorting(), "graph sort failed."); + + map mOpIOListFuncMap; + CreateIOListFuncMap(mOpIOListFuncMap); + + for (ge::NodePtr n : graph_->GetDirectNode()) { + if (n->GetType() != ge::parser::FRAMEWORKOP) continue; + std::string original_type; + GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true, + "get original type failed."); + + // create frameworkop nodedefbytes & TFindatatype & TFoutdatatype + vector::iterator iter = + std::find(local_framework_op_vec.begin(), local_framework_op_vec.end(), original_type); + if (iter != local_framework_op_vec.end()) { + Status ret = CreateNodeDefBytes(n, original_type, mOpIOListFuncMap); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create nodedefBytes failed!"); + + vector::iterator iter_dataset = + std::find(is_dataset_op_vec.begin(), is_dataset_op_vec.end(), original_type); + if (original_type == "Const" || iter_dataset != is_dataset_op_vec.end()) { + ret = CreateOpDefBytes(n, original_type); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create opdefBytes failed!"); + if (original_type == "ParallelMapDataset" || original_type == "FilterDataset" || + original_type == "MapAndBatchDatasetV2") { + ret = CreateFuncDefBytes(n, original_type, GetFuncBinPath()); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Create funcdefBytes failed!"); + } + } + } + } + + return SUCCESS; +} + +Status ParserGraphOptimizer::FusionFmkop() { + GELOGI("graph_optimizer.cpp && FustionFmkop()"); + GELOGI("GetLocalFmkopFlag() =%d", GetLocalFmkopFlag()); + GE_IF_BOOL_EXEC(GetLocalFmkopFlag() == 1, MakeTfProtoDef()); + + GE_CHECK_NOTNULL(graph_); + std::unordered_map> node_cluser_Map; + GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find 
framework node to be fused fail."); + GE_IF_BOOL_EXEC(node_cluser_Map.size() == 0, return SUCCESS); + + for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) { + GE_CHK_STATUS_RET(UpdateGraph(it->second), "fusion framework nodes failed. node:%s", (it->first).c_str()); + } + // fuse all fmkop and then delete node + for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) { + for (auto node : it->second) { + GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}), "Isolate removed node: %s, type: %s failed", + node->GetName().c_str(), node->GetType().c_str()); + GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node), + "Remove node: %s, type: %s without relink failed", node->GetName().c_str(), + node->GetType().c_str()); + } + } + + return SUCCESS; +} + +Status ParserGraphOptimizer::MarkForFusion(unordered_map> &node_cluser_Map) { + GE_CHECK_NOTNULL(graph_); + bool hasGetNext = false; + for (auto node : graph_->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); + string type = ""; + GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); + if (type == "IteratorGetNext") { + hasGetNext = true; + break; + } + } + for (auto node : graph_->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) + string type = ""; + GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); + if (type == "IteratorGetNext") { + vector temp_node_cluser; + for (auto in_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + NodePtr src_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + temp_node_cluser.push_back(src_node); + } + temp_node_cluser.push_back(node); + if (temp_node_cluser.size() > 1) { + vector node_cluser; + node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); + node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser; + } + temp_node_cluser.clear(); + GELOGI("MarkForFusion, IteratorGetNext graph mark success."); + } + + if (!hasGetNext && (type == "Iterator" || type == "IteratorV2")) { + GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluser_Map), "find framework node to be fused fail."); + GELOGI("MarkForFusion, Iterator init graph mark success."); + } + } + return SUCCESS; +} + +// find frameworkOP +Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map> &node_cluser_Map) { + vector temp_node_cluser; + + for (auto node : graph_->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + OpDescPtr temp_node_desc_ptr = node->GetOpDesc(); + GE_CHECK_NOTNULL(temp_node_desc_ptr); + GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::parser::DATA_TYPE, continue); + + if (temp_node_desc_ptr->GetType() == ge::parser::FRAMEWORK_OP_TYPE && + (temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) { + temp_node_cluser.push_back(node); + } else { + if (temp_node_cluser.size() > 1) { + vector node_cluser; + node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); + node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser; + } + temp_node_cluser.clear(); + } + } + if (temp_node_cluser.size() > 1) { + vector node_cluser; + node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end()); + node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser; + } + return SUCCESS; +} + +Status CollectNodeFuncs(vector &nodes, 
FunctionDefLibrary *library) { + for (auto node : nodes) { + GE_CHECK_NOTNULL(node); + OpDescPtr opDef = node->GetOpDesc(); + string funcdefStr; + ge::GeAttrValue::BYTES funcDefBytes; + + GE_IF_BOOL_EXEC( + AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib; + string str(reinterpret_cast(funcDefBytes.GetData()), funcDefBytes.GetSize()); + GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC( + funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib))); + } + return SUCCESS; +} + +Status ParserGraphOptimizer::UpdateGraph(vector &nodes) { + ComputeGraphPtr sub_graph = nullptr; + GE_MAKE_SHARED(sub_graph = std::make_shared("subGraph"), sub_graph = nullptr; return PARAM_INVALID); + + unordered_map node_map; + vector input_anchors; + vector output_anchors; + map> output_in_map; + vector input_control_anchors; + vector output_control_anchors; + + GE_CHK_STATUS_RET(InsertNode(sub_graph, nodes, input_anchors, output_anchors, output_in_map, input_control_anchors, + output_control_anchors, node_map), + "insert node to sub_graph failed."); + GE_CHK_STATUS_RET(LinkInnerAnchor(node_map), "Link inner anchor failed."); + + std::unique_ptr node_def(new (std::nothrow) NodeDef()); // tensorflow NodeDef + GE_CHECK_NOTNULL(node_def); + std::unique_ptr func_def_lib(new (std::nothrow) FunctionDefLibrary()); + GE_CHECK_NOTNULL(func_def_lib); + // convert graph to FunctionDef + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, return PARAM_INVALID, "node size must greater than 0 ."); + GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed."); + GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(), + node_def.get(), input_anchors, output_anchors), + "Build functiondef failed."); + string nodefStr; + string funcdefStr; + + GE_IF_BOOL_EXEC(!node_def->SerializeToString(&nodefStr), GELOGE(PARAM_INVALID, "Serialize nodedef to string failed."); + return PARAM_INVALID); + + GE_IF_BOOL_EXEC(!func_def_lib->SerializeToString(&funcdefStr), + GELOGE(PARAM_INVALID, "Serialize func_def to string failed."); + return PARAM_INVALID); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(nodes.size() == 0, return PARAM_INVALID, "nodes is empty."); + + OpDescPtr fusion_node_opdef = nullptr; + GE_MAKE_SHARED( + fusion_node_opdef = std::make_shared(nodes[0]->GetOpDesc()->GetName(), nodes[0]->GetOpDesc()->GetType()), + fusion_node_opdef = nullptr; + return FAILED); + + std::string type = ""; + GE_CHK_STATUS_RET(ge::parser::GetOriginalType(nodes[0], type)); + (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); + + (void)AttrUtils::SetZeroCopyBytes( + fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, + Buffer::CopyFrom(reinterpret_cast(funcdefStr.data()), funcdefStr.length())); + (void)AttrUtils::SetZeroCopyBytes( + fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, + Buffer::CopyFrom(reinterpret_cast(nodefStr.data()), nodefStr.length())); + + (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); + + // reconstruct fusion_node and edges + GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef), + "rebuild output edges to fusion node failed.") + GE_CHK_STATUS_RET(RebuildInputAnchors(input_anchors, fusion_node_opdef), + "rebuild input edges to fusion node failed."); + NodePtr fusion_node = graph_->AddNode(fusion_node_opdef); + + // add 
Anchors + GE_CHK_STATUS_RET(RebuildFusionNode(input_anchors, output_anchors, output_in_map, input_control_anchors, + output_control_anchors, fusion_node), + "rebuild node failed!"); + + return SUCCESS; +} + +Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector &nodes, + vector &input_anchors, + vector &output_anchors, + map> &output_in_map, + vector &input_control_anchors, + vector &output_control_anchors, + unordered_map &node_map) { + GE_CHECK_NOTNULL(sub_graph); + for (NodePtr node : nodes) { + GE_CHECK_NOTNULL(node); + OpDescPtr op_def = node->GetOpDesc(); + NodePtr new_node = sub_graph->AddNode(op_def); + node_map[node->GetName()] = new_node; + + // Input + for (auto in_anchor : node->GetAllInDataAnchors()) { // data + OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), input_anchors.emplace_back(in_anchor)); + } + // Output + for (auto out_anchor : node->GetAllOutDataAnchors()) { + bool hasOutNode = false; + // data anchor + for (auto peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_in_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), output_in_map[out_anchor].emplace_back(peer_in_anchor); hasOutNode = true); + } + GE_IF_BOOL_EXEC(hasOutNode == true, output_anchors.emplace_back(out_anchor)); + } + + InControlAnchorPtr node_in_control = node->GetInControlAnchor(); + GE_IF_BOOL_EXEC( + node_in_control != nullptr, for (auto peer_out_anchor + : node_in_control->GetPeerOutControlAnchors()) { + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control)); + }); + OutControlAnchorPtr node_out_control = node->GetOutControlAnchor(); + GE_IF_BOOL_EXEC( + node_out_control != nullptr, for (auto peer_in_control_anchor + : node_out_control->GetPeerInControlAnchors()) { + vector::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode()); + GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control)); + }); + } + return SUCCESS; +} + +Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map &node_map) { + for (auto node : graph_->GetDirectNode()) { + GE_IF_BOOL_EXEC(node_map.count(node->GetName()) == 0, continue); + NodePtr dst = node_map[node->GetName()]; + for (auto in_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + GE_IF_BOOL_EXEC(node_map.count(peer_out_anchor->GetOwnerNode()->GetName()) == 0, continue); + NodePtr src = node_map[peer_out_anchor->GetOwnerNode()->GetName()]; + + GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(src->GetOutDataAnchor(peer_out_anchor->GetIdx()), + dst->GetInDataAnchor(in_anchor->GetIdx())) != GRAPH_SUCCESS, + GELOGE(FAILED, + "LinkInnerAnchor Link data anchor failed, src node: %s, " + "dst node: %s.", + src->GetName().c_str(), dst->GetName().c_str()); + return FAILED); + } + + InControlAnchorPtr node_in_control = node->GetInControlAnchor(); + GE_IF_BOOL_EXEC( + node_in_control != nullptr, for (auto peer_out_ctl_anchor + : node_in_control->GetPeerOutControlAnchors()) { + GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue); + NodePtr src_ctrl = 
node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()]; + GE_IF_BOOL_EXEC( + ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS, + GELOGE(FAILED, + "LinkInnerAnchor Link control anchor failed, src node: " + "%s, dst node: %s.", + src_ctrl->GetName().c_str(), dst->GetName().c_str()); + return FAILED); + }); + } + return SUCCESS; +} + +// rebuild output anchor +Status ParserGraphOptimizer::RebuildOutputAnchors(vector &output_anchors, + ge::OpDescPtr fusion_op_desc) { + ge::GeAttrValue::LIST_INT output_list; + GE_CHECK_NOTNULL(fusion_op_desc); + + // create input desc + for (auto out_anchor : output_anchors) { + NodePtr src_node = out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + + GeTensorDesc src_out_desc = src_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED); + + ge::DataType data_type = src_out_desc.GetDataType(); + auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); + GE_IF_BOOL_EXEC( + iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), + GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return PARAM_INVALID); + + int32_t dtype = iter->second; + output_list.push_back((int64_t)dtype); + GELOGI("FUNCDEF: output_list push_back %d.", dtype); + } + GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list)); + + return SUCCESS; +} +// rebuild input desc +Status ParserGraphOptimizer::RebuildInputAnchors(vector &input_anchors, + ge::OpDescPtr fusion_op_desc) { + ge::GeAttrValue::LIST_INT input_list; + GE_CHECK_NOTNULL(fusion_op_desc); + // add input desc + for (auto in_anchor : input_anchors) { + NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(dst_node); + + auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx()); + GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((fusion_op_desc->AddInputDesc(*tensorDescPtr)) != GRAPH_SUCCESS, return FAILED, + "Add fusion_op_desc AddInputDesc failed"); + ge::DataType data_type = tensorDescPtr->GetDataType(); + auto iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type); + GE_IF_BOOL_EXEC( + iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(), + GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return PARAM_INVALID); + + int32_t dtype = iter->second; + input_list.push_back((int64_t)dtype); + GELOGI("FUNCDEF: input_list push_back %d.", dtype); + } + GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list)); + + return SUCCESS; +} + +Status ParserGraphOptimizer::RebuildFusionNode(vector &input_anchors, + vector &output_anchors, + map> &output_in_map, + vector &input_control_anchors, + vector &output_control_anchors, + ge::NodePtr fusion_node) { + int32_t src_index = 0; + + for (auto out_anchor : output_anchors) { + for (auto in_anchor : output_in_map[out_anchor]) { + (void)in_anchor->Unlink(out_anchor); + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutDataAnchor(src_index), in_anchor), + "Add anchor between fusion node and in anchor node!"); + } + src_index++; + } + src_index = 0; + for (auto in_anchor : input_anchors) { + OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor(); + out_anchor->Unlink(in_anchor); + 
GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_anchor, fusion_node->GetInDataAnchor(src_index)), + "Add anchor between out anchor node and fusion node!"); + src_index++; + } + + for (auto out_control_anchor : output_control_anchors) { + for (auto in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { + in_control_anchor->Unlink(out_control_anchor); + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutControlAnchor(), in_control_anchor), + "Add anchor between fusion node and in control anchor node!"); + } + } + for (auto in_control_anchor : input_control_anchors) { + for (auto out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + out_control_anchor->Unlink(in_control_anchor); + GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_control_anchor, fusion_node->GetInControlAnchor()), + "Add anchor between out control anchor node and fusion node!"); + } + } + return SUCCESS; +} + +Status ParserGraphOptimizer::Insert4DTo5DTransOp(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, + enum ge::Format src_out_format, enum ge::DataType src_out_data_type, + enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type) { + bool isNCHWFP32To5DFP16 = (src_out_format == ge::FORMAT_NCHW && dst_in_format == ge::FORMAT_NC1HWC0); + if (isNCHWFP32To5DFP16) { + NodePtr cast_node = nullptr; + + if (src_out_data_type != dst_in_data_type) { + OpDescPtr cast_opdesc = CreateCastOp(src_out_data_type, dst_in_data_type, ge::FORMAT_NCHW); + cast_node = graph_->AddNode(cast_opdesc); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return INTERNAL_ERROR, "graph add cast node fail."); + } + + OpDescPtr trans_data_opdesc = CreateTransDataOp(FORMAT_NCHW); + NodePtr trans_data_node = graph_->AddNode(trans_data_opdesc); + GE_CHK_BOOL_EXEC(trans_data_node != nullptr, return INTERNAL_ERROR, "graph add TransData node node fail."); + GE_CHK_STATUS_RET(NewNodeAddEdges(src_anchor, dst_anchor, nullptr, cast_node, trans_data_node), + "NewNodeAddEdges ret fail."); + + return SUCCESS; + } + + OpDescPtr translateto5D = CreateTranslateOp(src_out_format, src_out_data_type, dst_in_format, dst_in_data_type); + GE_CHECK_NOTNULL(translateto5D); + NodePtr transNode = graph_->AddNode(translateto5D); + GE_CHECK_NOTNULL(transNode); + GELOGI("Create 4D To 5D fp32 node susscess!"); + + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transNode->GetOutDataAnchor(0), dst_anchor), return INTERNAL_ERROR); + + GELOGI("Create 4D To 5D susscess!"); + return SUCCESS; +} + +Status ParserGraphOptimizer::InsertFZ2HWCK(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { + GELOGI("In InsertFZ2HWCK !"); + GE_IF_BOOL_EXEC( + srcOutFormat == ge::FORMAT_FRACTAL_Z, NodePtr transHalfNode = nullptr; + if (srcOutDatatype == ge::DT_FLOAT) { + // create FZ fp32->FZ fp16 node + OpDescPtr translatetoHalf = CreateTranslateOp(srcOutFormat, srcOutDatatype, srcOutFormat, ge::DT_FLOAT16); + transHalfNode = graph_->AddNode(translatetoHalf); + GE_CHECK_NOTNULL(transHalfNode); + GELOGI("Create FZ fp32 to FZ fp16 node susscess!"); + // create FZ fp16->HWCK fp32 node + } + + OpDescPtr translatetoHWCK = CreateTranslateOp(srcOutFormat, ge::DT_FLOAT16, dstInFormat, dstInDatatype); + NodePtr transHWCKNode = graph_->AddNode(translatetoHWCK); GELOGI("Create FZ 16 to HWCK fp32 node susscess!"); + 
GE_CHECK_NOTNULL(transHWCKNode); if (transHalfNode) { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHalfNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHalfNode->GetOutDataAnchor(0), transHWCKNode->GetInDataAnchor(0)), + return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCKNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } else { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCKNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCKNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } GELOGI("Create InsertFZ2HWCK success!");) + return SUCCESS; +} + +Status ParserGraphOptimizer::InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { + GELOGI("In Insert 5D To 4D !"); + GE_IF_BOOL_EXEC( + srcOutFormat == ge::FORMAT_NC1HWC0, NodePtr cast_node = nullptr; + if (srcOutDatatype == ge::DT_FLOAT && dstInDatatype == ge::DT_FLOAT) { + auto cast_opdesc = CreateCastOp(ge::DT_FLOAT, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0); + cast_node = graph_->AddNode(cast_opdesc); + + srcOutDatatype = ge::DT_FLOAT16; + } NodePtr transHalfNode = nullptr; + OpDescPtr translateto4D = CreateTranslateOp(srcOutFormat, srcOutDatatype, dstInFormat, dstInDatatype); + NodePtr trans4DNode = graph_->AddNode(translateto4D); GELOGI("Create 5D To 4D fp32 node susscess!"); + GE_CHECK_NOTNULL(trans4DNode); + + if (cast_node) { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, cast_node->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), trans4DNode->GetInDataAnchor(0)), + return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(trans4DNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } else { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, trans4DNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(trans4DNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } GELOGI("Create 5D To 4D susscess!");) + return SUCCESS; +} + +Status ParserGraphOptimizer::InsertHWCK2FZ(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype) { + GELOGI("In InsertHWCK2FZ !"); + GE_IF_BOOL_EXEC( + srcOutFormat == ge::FORMAT_HWCN, NodePtr transHalfNode = nullptr; + OpDescPtr translatetoFZ = CreateTranslateOp(srcOutFormat, srcOutDatatype, dstInFormat, ge::DT_FLOAT16); + NodePtr transHWCK2FZNode = graph_->AddNode(translatetoFZ); GELOGI("Create HWCK fp32 to FZ 16 node susscess!"); + GE_CHECK_NOTNULL(transHWCK2FZNode); + + ge::NodePtr translateHalftoFp32Node = nullptr; if (dstInDatatype == ge::DT_FLOAT) { + // create FZ fp16 ->FZ fp32 node + ge::OpDescPtr translateHalftoFp32 = CreateTranslateOp(dstInFormat, ge::DT_FLOAT16, dstInFormat, dstInDatatype); + translateHalftoFp32Node = graph_->AddNode(translateHalftoFp32); + GE_CHECK_NOTNULL(translateHalftoFp32Node); + GELOGI("Create FZ fp32 to FZ fp16 node susscess!"); + // create FZ fp16->HWCK fp32 node + } + + if (translateHalftoFp32Node) { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCK2FZNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC( + 
GraphUtils::AddEdge(transHWCK2FZNode->GetOutDataAnchor(0), translateHalftoFp32Node->GetInDataAnchor(0)), + return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(translateHalftoFp32Node->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } else { + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(src_anchor, transHWCK2FZNode->GetInDataAnchor(0)), return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(transHWCK2FZNode->GetOutDataAnchor(0), dst_anchor) != SUCCESS, + return INTERNAL_ERROR); + } GELOGI("Create InsertHWCK2FZ success!");) + return SUCCESS; +} + +Status ParserGraphOptimizer::Insert5DTo4DTransOp(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, + enum ge::Format src_out_format, enum ge::DataType src_out_data_type, + enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type) { + // Status ret; + NodePtr permute_node = nullptr; + NodePtr cast_node = nullptr; + + OpDescPtr trans_data_opdesc = CreateTransDataOp(FORMAT_NC1HWC0); + NodePtr trans_data_node = graph_->AddNode(trans_data_opdesc); + GE_CHK_BOOL_EXEC(trans_data_node != nullptr, return INTERNAL_ERROR, "graph add TransData node node fail."); + + if (src_out_data_type != dst_in_data_type) { + OpDescPtr cast_opdesc = CreateCastOp(src_out_data_type, dst_in_data_type, ge::FORMAT_NCHW); + cast_node = graph_->AddNode(cast_opdesc); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return INTERNAL_ERROR, "graph add cast node fail."); + } + + if (dst_in_format == FORMAT_NHWC) { + OpDescPtr permute_opdec = CreatePermuteOp(FORMAT_NCHW, dst_in_format); + permute_node = graph_->AddNode(permute_opdec); + GE_CHK_BOOL_EXEC(permute_node != nullptr, return INTERNAL_ERROR, "graph add permute node fail."); + } + + GE_CHK_STATUS_RET(NewNodeAddEdges(src_anchor, dst_anchor, trans_data_node, cast_node, permute_node), + "NewNodeAddEdges ret fail."); + + return SUCCESS; +} + +Status ParserGraphOptimizer::NewNodeAddEdges(OutDataAnchorPtr src_anchor, InDataAnchorPtr dst_anchor, NodePtr first, + NodePtr second, NodePtr third) { + GE_CHECK_NOTNULL(src_anchor); + GE_CHECK_NOTNULL(dst_anchor); + OutDataAnchorPtr add_in_anchor = nullptr; + InDataAnchorPtr add_out_anchor = nullptr; + NodePtr src_node = src_anchor->GetOwnerNode(); + NodePtr dst_node = dst_anchor->GetOwnerNode(); + + if (first != nullptr) { + Status status = GraphUtils::AddEdge(src_anchor, first->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + src_anchor->GetIdx(), 0); + if (second != nullptr) { + status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), second->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", 0, + 0); + if (third != nullptr) { + status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), third->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, 0); + status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx()); + } else { + status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx()); + } + } else { + if (third != nullptr) { + status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), 
third->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, 0); + status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx()); + } else { + status = GraphUtils::AddEdge(first->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx()); + } + } + } else { + if (second != nullptr) { + Status status = GraphUtils::AddEdge(src_anchor, second->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, + "graph add src to cast edge fail, src index:%d, dst index:%d.", src_anchor->GetIdx(), 0); + GE_IF_BOOL_EXEC( + third != nullptr, status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), third->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, 0); + status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx());); + GE_IF_BOOL_EXEC(third == nullptr, status = GraphUtils::AddEdge(second->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, + "graph add edge fail, src index:%d, dst index:%d.", 0, 0);); + } else { + if (third != nullptr) { + Status status = GraphUtils::AddEdge(src_anchor, third->GetInDataAnchor(0)); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, 0); + status = GraphUtils::AddEdge(third->GetOutDataAnchor(0), dst_anchor); + GE_CHK_BOOL_EXEC(status == SUCCESS, return INTERNAL_ERROR, "graph add edge fail, src index:%d, dst index:%d.", + 0, dst_anchor->GetIdx()); + } + } + } + return SUCCESS; +} + +OpDescPtr ParserGraphOptimizer::CreateTranslateOp(enum ge::Format inFormat, enum ge::DataType inDatatype, + enum ge::Format outFormat, enum ge::DataType outDatatype) { + /** + * 0. FP32 <-> FP16 + * 1. from HWCK(FP32) to FracZ(FP16); + * 2. from FracZ(FP16) to HWCK(FP32); + * 3. from NHWC(FP32) to NC1HWC0(FP16); + * 4. from NC1HWC0(FP32) to NHWC(FP32); + * 5. 
from NC1HWC0(FP16) to NHWC(FP32) + */ + static uint32_t transop_count = 0; + OpDescPtr op_def = nullptr; + std::stringstream sstmp; + sstmp << "translate_" << ge::parser::TRANSDATA << "_" << transop_count++; + GE_MAKE_SHARED(op_def = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSLATE), op_def = nullptr; + return op_def); + GELOGI( + "create translate op:%s, input format:%s, input datatype:%s, output " + "format:%s, output datatype:%s.", + op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), + ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), + ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); + + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), return nullptr, + "SetInt ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), return nullptr, + "SetInt ATTR_NAME_INPUT_DATATYPE failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), return nullptr, + "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), return nullptr, + "SetInt ATTR_NAME_OUTPUT_DATATYPE failed."); + if (inDatatype != ge::DT_FLOAT16) { + GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), return nullptr, + "create translate op:add input desc fail."); + } else { + GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat, ge::DT_FLOAT16)), return nullptr, + "create translate op:add input desc fail."); + } + if (outDatatype != ge::DT_FLOAT16) { + GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddOutputDesc(GeTensorDesc(GeShape(), outFormat)), return nullptr, + "create translate op:add output desc fail."); + } else { + GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddOutputDesc(GeTensorDesc(GeShape(), outFormat, ge::DT_FLOAT16)), + return nullptr, "create translate op:add output desc fail."); + } + return op_def; +} + +OpDescPtr ParserGraphOptimizer::CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format) { + static uint32_t transop_count = 0; + + std::stringstream sstmp; + sstmp << "transdata_" << ge::parser::PERMUTE << "_" << transop_count++; + + OpDescPtr op_desc = nullptr; + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::PERMUTE), op_desc = nullptr; + return op_desc); + GELOGI("create permute op:%s", op_desc->GetName().c_str()); + + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, + "SetInt ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, + "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); + + GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1)); + GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1)); + + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr, + "create permute op:add input desc fail."); + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr, + "create permute op:add output desc fail."); + + return op_desc; +} + +OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, enum ge::DataType output_data_type, + enum ge::Format
format) { + static uint32_t transop_count = 0; + std::stringstream sstmp; + sstmp << "transdata_" << ge::parser::CAST << "_" << transop_count++; + + OpDescPtr op_desc = nullptr; + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::CAST), op_desc = nullptr; + return op_desc); + GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(), + ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(), + ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str()); + + if (!(AttrUtils::SetInt(op_desc, ge::CAST_ATTR_SRCT, (int64_t)input_data_type) && + AttrUtils::SetInt(op_desc, ge::CAST_ATTR_DSTT, (int64_t)output_data_type) && + AttrUtils::SetInt(op_desc, ge::CAST_ATTR_DST_TYPE, (int64_t)output_data_type) && + AttrUtils::SetBool(op_desc, ge::CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE fail, node: %s.", + op_desc->GetName().c_str()); + return nullptr; + } + + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), format, input_data_type)), return nullptr, + "create cast op:add input desc fail."); + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), format, output_data_type)), return nullptr, + "create cast op:add output desc fail."); + + return op_desc; +} +OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) { + static uint32_t transop_count = 0; + std::stringstream sstmp; + sstmp << "transdata_" << ge::parser::TRANSDATA << "_" << transop_count++; + + OpDescPtr op_desc = nullptr; + GE_MAKE_SHARED(op_desc = std::make_shared(sstmp.str().c_str(), ge::parser::TRANSDATA), op_desc = nullptr; + return op_desc); + + GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(), + ge::TypeUtils::FormatToSerialString(input_format).c_str()); + enum ge::Format output_format = FORMAT_NC1HWC0; + if (input_format != FORMAT_NCHW) { + input_format = FORMAT_NC1HWC0; + output_format = FORMAT_NCHW; + } + + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, + "SetInt of ATTR_NAME_INPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, + "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr, + "create transdata op:add input desc fail."); + GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr, + "create transdata op:add output desc fail."); + + return op_desc; +} +} // namespace ge diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h new file mode 100644 index 0000000..9f73d69 --- /dev/null +++ b/parser/tensorflow/graph_optimizer.h @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ +#define GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ +#include +#include +#include +#include +#include "framework/omg/parser/parser_types.h" +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/node.h" +#include "omg/omg_inner_types.h" + +using std::map; +using std::string; +using std::unordered_map; +using std::vector; + +namespace ge { +class ParserGraphOptimizer { + public: + explicit ParserGraphOptimizer(ge::ComputeGraphPtr graph, domi::FrameworkType type = domi::TENSORFLOW) + : graph_(graph), fmktype_(type), local_fmk_op_flag_(false) {} + + ~ParserGraphOptimizer() {} + + domi::Status Optimize(); + + domi::Status OptimizeAfterCal(); + + domi::Status FusionFmkop(); + + inline bool IsHCOMOp(const string &op_type) { + return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || + (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || + (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); + } + + void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } + + const bool GetLocalFmkopFlag() const { return local_fmk_op_flag_; } + + void SetFuncBinPath(std::string isFuncBinPath) { func_bin_path_ = isFuncBinPath; } + const std::string GetFuncBinPath() const { return func_bin_path_; } + + domi::Status InsertHWCK2FZ(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); + + domi::Status Insert4DTo5DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format src_out_format, enum ge::DataType src_out_data_type, + enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); + + domi::Status InsertFZ2HWCK(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); + + domi::Status Insert5DTo4DTransOp(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format src_out_format, enum ge::DataType src_out_data_type, + enum ge::Format dst_in_format, enum ge::DataType dst_in_data_type); + + ge::OpDescPtr CreateCastOp(enum ge::DataType input_datatype, enum ge::DataType output_datatype, ge::Format format); + + ge::OpDescPtr CreatePermuteOp(enum ge::Format input_format, enum ge::Format output_format); + + ge::OpDescPtr CreateTransDataOp(enum ge::Format input_format); + + domi::Status NewNodeAddEdges(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, ge::NodePtr first, + ge::NodePtr second, ge::NodePtr third); + + domi::Status InsertVar5DTo4D(ge::OutDataAnchorPtr src_anchor, ge::InDataAnchorPtr dst_anchor, + enum ge::Format srcOutFormat, enum ge::DataType srcOutDatatype, + enum ge::Format dstInFormat, enum ge::DataType dstInDatatype); + + ge::OpDescPtr CreateTranslateOp(enum ge::Format inFormat, ge::DataType inDatatype, enum ge::Format outFormat, + ge::DataType outDatatype); + + private: + ge::ComputeGraphPtr graph_; + domi::FrameworkType fmktype_; + // local fmkop flag + bool local_fmk_op_flag_; + std::string func_bin_path_; + + domi::Status FindFmkNodeCluser(unordered_map> &node_cluser_Map); + + domi::Status MarkForFusion(unordered_map> &node_cluser_Map); + + domi::Status UpdateGraph(vector &nodes); + + domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector &nodes, + vector 
&input_anchors, vector &output_anchors, + map> &output_in_map, + vector &input_control_anchors, + vector &output_control_anchors, + unordered_map &node_map); + + domi::Status LinkInnerAnchor(unordered_map &node_map); + + domi::Status RebuildOutputAnchors(vector &output_anchors, ge::OpDescPtr fusion_op_desc); + + domi::Status RebuildInputAnchors(vector &input_anchors, ge::OpDescPtr fusion_op_desc); + + domi::Status RebuildFusionNode(vector &input_anchors, + vector &output_anchors, + map> &output_in_map, + vector &input_control_anchors, + vector &output_control_anchors, ge::NodePtr fusion_node); + + domi::Status MakeTfProtoDef(); +}; +} // namespace ge +#endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ diff --git a/parser/tensorflow/iterator_fusion_pass.cc b/parser/tensorflow/iterator_fusion_pass.cc new file mode 100644 index 0000000..0324050 --- /dev/null +++ b/parser/tensorflow/iterator_fusion_pass.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "iterator_fusion_pass.h" + +#include + +#include "common/debug/log.h" +#include "framework/omg/parser/parser_types.h" +#include "common/util.h" +#include "graph_optimizer.h" +#include "framework/common/ge_inner_error_codes.h" + +namespace ge { +Status IteratorFusionPass::Run(ge::ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + domi::FrameworkType fmk_type = static_cast(fmk_type_); + std::unique_ptr graph_optimizer(new (std::nothrow) ParserGraphOptimizer(graph, fmk_type)); + if (graph_optimizer == nullptr) { + return FAILED; + } + + graph_optimizer->SetLocalFmkopFlag(local_fmk_op_flag_); + return graph_optimizer->FusionFmkop(); +} +} // namespace ge diff --git a/parser/tensorflow/iterator_fusion_pass.h b/parser/tensorflow/iterator_fusion_pass.h new file mode 100644 index 0000000..c193acf --- /dev/null +++ b/parser/tensorflow/iterator_fusion_pass.h @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_ITERATOR_FUSION_PASS_H_ +#define GE_GRAPH_PASSES_ITERATOR_FUSION_PASS_H_ + +#include "framework/common/ge_types.h" +#include "inc/graph_pass.h" + +namespace ge { +class IteratorFusionPass : public GraphPass { + public: + IteratorFusionPass(ge::FrameworkType type, bool local_fmk_op_flag) + : fmk_type_(type), local_fmk_op_flag_(local_fmk_op_flag) {} + + virtual ~IteratorFusionPass() {} + + Status Run(ge::ComputeGraphPtr graph) final; + + private: + ge::FrameworkType fmk_type_; + bool local_fmk_op_flag_; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_ITERATOR_FUSION_PASS_H_ diff --git a/parser/tensorflow/proto/ge_ir.proto b/parser/tensorflow/proto/ge_ir.proto new file mode 100644 index 0000000..e7bfe0c --- /dev/null +++ b/parser/tensorflow/proto/ge_ir.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package ge.proto; + +enum DataType +{ + DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. + DT_FLOAT = 1; // float type + DT_FLOAT16 = 2; // fp16 type + DT_INT8 = 3; // int8 type + DT_UINT8 = 4; // uint8 type + DT_INT16 = 5; // int16 type + DT_UINT16 = 6; // uint16 type + DT_INT32 = 7; // + DT_INT64 = 8; // int64 type + DT_UINT32 = 9; // unsigned int32 + DT_UINT64 = 10; // unsigned int64 + DT_BOOL = 11; // bool type + DT_DOUBLE = 12; // double type + DT_STRING = 13; // string type + DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ + DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ + DT_COMPLEX64 = 16; // complex64 type + DT_COMPLEX128 = 17; // complex128 type + DT_QINT8 = 18; // qint8 type + DT_QINT16 = 19; // qint16 type + DT_QINT32 = 20; // qint32 type + DT_QUINT8 = 21; // quint8 type + DT_QUINT16 = 22; // quint16 type + DT_RESOURCE = 23; // resource type + DT_STRING_REF = 24; // string_ref type + DT_DUAL = 25; /**< dual output type */ +} + +message AttrDef +{ + message ListValue + { + enum ListValueType{ + VT_LIST_NONE = 0; + VT_LIST_STRING = 1; + VT_LIST_INT = 2; + VT_LIST_FLOAT = 3; + VT_LIST_BOOL = 4; + VT_LIST_BYTES = 5; + VT_LIST_TENSOR_DESC = 6; + VT_LIST_TENSOR = 7; + VT_LIST_GRAPH = 8; + VT_LIST_NAMED_ATTRS = 9; + VT_LIST_DATA_TYPE = 10; + } + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3; // "list(int)" + repeated float f = 4; // "list(float)" + repeated bool b = 5; // "list(bool)" + repeated bytes bt = 7; + repeated TensorDescriptor td = 8; + repeated TensorDef t = 9; + repeated GraphDef g = 10; + repeated NamedAttrs na = 11; + repeated int64 dt = 12; // list ge::DataType + + ListValueType val_type = 20; + } + + message ListListInt{ + message ListInt{ + repeated int64 list_i = 1; // list int + } + repeated ListInt list_list_i = 1; // list list int + } + + oneof value + { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; // Used to support attr nesting + TensorDescriptor td = 11; // GeTensorDesc type + TensorDef t = 12; // GeTensor type + GraphDef g = 13; // Graph type + ListListInt list_list_int = 14; // List List Int type + int64 dt = 15; // ge::DataType + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
+message NamedAttrs +{ + string name = 1; + map attr = 2; +} + +// Shape / dimension description, using row-major order +message ShapeDef +{ + repeated int64 dim = 1; // Size of each dimension +} + +// Multidimensional data description +message TensorDescriptor +{ + string name = 1; // Optional parameter, tensor name + + DataType dtype = 2; // tensor datatype + ShapeDef shape = 3; // Shape / dimension + string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" + + bool has_out_attr = 9; + int64 size = 10; + int64 weight_size = 11; + bool reuse_input = 12; + bool output_tensor = 13; + string device_type = 14; + bool input_tensor =15; + int64 real_dim_cnt = 16; + int64 reuse_input_index = 17; + int64 data_offset = 18; + int64 cmps_size = 19; + string cmps_tab = 20; + int64 cmps_tab_offset = 21; + + map attr = 5; // Set of extra parameter fields +} + +// GeTensor definition +message TensorDef +{ + TensorDescriptor desc = 1; // Tensor description + bytes data = 2; // Tensor data +} + + +// Operator description +message OpDef +{ + string name = 1; // name + string type = 2; // type + + repeated string input = 5; // input original op name + outgoing index. op_name:index + + map attr = 10; // Set of operator parameter fields + + bool has_out_attr = 20; + int64 id = 21; + int64 stream_id =22; + repeated string input_name = 23; + repeated string src_name = 24; + repeated int64 src_index = 25; + repeated string dst_name = 26; + repeated int64 dst_index = 27; + repeated int64 input_i = 28; + repeated int64 output_i = 29; + repeated int64 workspace = 30; + repeated int64 workspace_bytes = 31; + repeated bool is_input_const = 32; + repeated TensorDescriptor input_desc = 33; + repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; +} + +// Graph definition +message GraphDef +{ + string name = 1; // name + + repeated string input = 4; // Graph input + repeated string output = 5; // Graph output + + repeated OpDef op = 6; // List of operators + + map attr = 11; // Extended field +} + +// model definition +message ModelDef +{ + string name = 1; // name + uint32 version = 2; // IR Proto version + string custom_version = 3; // User model version number, passed in by user + + repeated GraphDef graph = 7; // Graph definition, graph[0] represents the main graph in modeldef + + map attr = 11; // Extended field +} + diff --git a/parser/tensorflow/proto/insert_op.proto b/parser/tensorflow/proto/insert_op.proto new file mode 100644 index 0000000..c635ca1 --- /dev/null +++ b/parser/tensorflow/proto/insert_op.proto @@ -0,0 +1,136 @@ +syntax = "proto3"; + +package domi; + +message InsertNewOps { + repeated AippOpParams aipp_op = 1; + repeated MultiShapeOpParams multi_shape_op = 2; +} + +message AippOpParams { + enum InputFormat { + UNDEFINED = 0; + YUV420SP_U8 = 1; + XRGB8888_U8 = 2; + RGB888_U8 = 3; + YUV400_U8 = 4; + NC1HWC0DI_FP16 = 5; + NC1HWC0DI_S8 = 6; + ARGB8888_U8 = 7; + YUYV_U8 = 8; + YUV422SP_U8 = 9; + AYUV444_U8 = 10; + RAW10 = 11; + RAW12 = 12; + RAW16 = 13; + RAW24 = 14; + RGB16 = 15; + RGB20 = 16; + RGB24 = 17; + RGB8_IR = 18; + RGB16_IR = 19; + RGB24_IR = 20; + } + + enum AippMode { + undefined = 0; + static = 1; + dynamic = 2; + } + + // AIPP mode: distinguishes static AIPP from dynamic AIPP + AippMode aipp_mode = 1; + + // related_input_rank is required; integer; valid range: >= 0, <= the number of Data operators; default value is 0. + // It identifies which input of the model is processed by AIPP; e.g. to apply AIPP to the 2nd input, set related_input_rank to 1. + uint32 related_input_rank = 2; + + // input_edge_idx is optional; integer; valid range: >= 0. + // It allows different outputs of the Data operator to get different AIPP processing; if not set, AIPP is applied by default to all outputs of the model input specified by related_input_rank. + // The configured values must be <= the number of output edges of the Data operator. + repeated uint32 input_edge_idx = 3; + + // [Begin] dynamic AIPP parameters; invalid when static AIPP is configured + uint32 max_src_image_size = 4; + + // Whether rotation is supported; disabled by default. Enabling rotation incurs extra space and performance cost. + bool support_rotation = 5; + + // [End] dynamic AIPP parameters + + + // [Begin] static AIPP parameters; invalid when dynamic AIPP is configured + InputFormat input_format = 51; + bool csc_switch = 52; + float cpadding_value = 53; + bool rbuv_swap_switch = 54; + bool ax_swap_switch = 55; + bool single_line_mode = 56; + + int32 src_image_size_w = 57; + int32 src_image_size_h = 58; + + bool crop = 59; + int32 load_start_pos_w = 60; + int32 load_start_pos_h = 61; + int32 crop_size_w = 62; + int32 crop_size_h = 63; + + bool resize = 64; + int32 resize_output_w = 65; + int32 resize_output_h = 66; + + bool padding = 67; + int32 left_padding_size = 68; + int32 right_padding_size = 69; + int32 top_padding_size = 70; + int32 bottom_padding_size = 71; + + int32 mean_chn_0 = 10; + int32 mean_chn_1 = 11; + int32 mean_chn_2 = 12; + int32 mean_chn_3 = 19; + float min_chn_0 = 13; + float min_chn_1 = 14; + float min_chn_2 = 15; + float min_chn_3 = 20; + repeated float var_reci_chn_0 = 16; + repeated float var_reci_chn_1 = 17; + repeated float var_reci_chn_2 = 18; + repeated float var_reci_chn_3 = 21; + + repeated int32 matrix_r0c0 = 30; + repeated int32 matrix_r0c1 = 31; + repeated int32 matrix_r0c2 = 32; + repeated int32 matrix_r1c0 = 33; + repeated int32 matrix_r1c1 = 34; + repeated int32 matrix_r1c2 = 35; + repeated int32 matrix_r2c0 = 36; + repeated int32 matrix_r2c1 = 37; + repeated int32 matrix_r2c2 = 38; + repeated int32 output_bias_0 = 39; + repeated int32 output_bias_1 = 40; + repeated int32 output_bias_2 = 41; + repeated int32 input_bias_0 = 42; + repeated int32 input_bias_1 = 43; + repeated int32 input_bias_2 = 44; + + // [End] static AIPP parameters + + // The n number that is used for raw/rgbir data into f16 transformation. + // The transformation equation is x/(2^n). If set to 0, no transform is performed. + uint32 raw_rgbir_to_f16_n = 45; +} + +message MultiShapeOpParams { + enum MultiShapeMode { + batch = 0; // dynamic batch + resolution = 1; // dynamic resolution, reserved for extension + } + + MultiShapeMode mode = 1; // operator mode + uint32 related_input_rank = 2; // which model input the inserted operator is attached to + + + repeated uint32 batch_list = 11; // batch_list values; the number of batch_list entries must be between 2 and 8 +} diff --git a/parser/tensorflow/proto/om.proto b/parser/tensorflow/proto/om.proto new file mode 100644 index 0000000..e15e5f8 --- /dev/null +++ b/parser/tensorflow/proto/om.proto @@ -0,0 +1,396 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +enum TargetType +{ + MINI = 0; + TINY = 1; + LITE = 2; +} + +// offline model +message ModelDef { + string name = 1; + uint32 version = 2; + + uint64 memory_size = 10; + uint32 stream_num = 11; + uint32 event_num = 12; + uint64 weight_size = 13; + uint32 label_num = 15; + repeated OpDef op = 20; + TargetType target_type = 23; + + map attr = 30; +}; + +// operator define +message OpDef { + string name = 1; + string type = 2; + + uint32 id = 3; + uint32 stream_id = 4; + + repeated string input_name = 5; + + repeated string src_name = 8; + repeated int32 src_index = 9; + repeated int64 input = 10; + repeated int64 output = 11; + repeated TensorDescriptor input_desc = 12; + repeated TensorDescriptor output_desc = 13; + repeated WeightDef weights = 14; + repeated string dst_name = 15; + repeated int32 dst_index = 16; + + repeated int64 workspace = 20; + repeated uint32 workspace_bytes = 21; + + repeated string weight_name = 22; + repeated bool is_input_const = 23; + + map attr = 30; + + QuantizeFactorParams quantize_factor = 31; + + oneof op_params { + // start at 100 here + SendOpParams sender_param = 100; + RecvOpParams receiver_param = 200; + ConvolutionOpParams convolution_param = 300; + PoolingOpParams pooling_param = 400; + EltwiseOpParams eltwise_param = 500; + BatchNormOpParams batchnorm_param = 600; + ScaleOpParams scale_param = 700; + FullConnectionOpParams full_connection_param = 800; + SoftmaxOpParams softmax_param = 900; + ActivationOpParams activation_param = 1000; + ReshapeOpParams reshape_param = 1100; + } +}; + +message SendOpParams { + uint32 event_id = 1; +}; + +message RecvOpParams { + uint32 event_id = 1; +}; + +enum QuantizeScaleType +{ + VECTOR_SCALE = 0; + SCALAR_SCALE = 1; +} + +enum QuantizeScaleMode +{ + NORMAL_MODE = 0; + SQRT_MODE = 1; +} + +enum QuantizeAlgorithm +{ + NON_OFFSET_ALGO = 0; + HALF_OFFSET_ALGO = 1; + ALL_OFFSET_ALGO = 2; +} +message QuantizeFactor +{ + QuantizeScaleMode scale_mode = 1; + bytes scale_value = 2; + int64 scale_offset = 3; + bytes offset_data_value = 4; + int64 offset_data_offset = 5; + bytes offset_weight_value = 6; + int64 offset_weight_offset = 7; + bytes offset_pad_value = 8; + int64 offset_pad_offset = 9; +}; + +message QuantizeCalcFactor +{ + bytes offsetw = 1; + int64 offsetw_offset = 2; + bytes offsetd = 3; + int64 offsetd_offset = 4; + bytes scalereq = 5; + int64 scaledreq_offset = 6; + bytes offsetdnext = 7; + int64 offsetdnext_offset = 8; +} + +message QuantizeFactorParams +{ + QuantizeAlgorithm quantize_algo = 1; + QuantizeScaleType scale_type = 2; + QuantizeFactor quantize_param = 3; + QuantizeFactor dequantize_param = 4; + QuantizeFactor requantize_param = 5; + QuantizeCalcFactor quantizecalc_param = 6; +}; + +message ConvolutionOpParams { + int32 mode = 1; + int32 algo = 2; + int32 pad_mode = 3; + uint32 group = 4; + uint32 num_output = 5; + + repeated uint32 pad = 10; + repeated uint32 stride = 11; + repeated uint32 dilation = 12; + repeated uint32 kernel = 13; + + float alpha = 20; + float beta = 21; + + WeightDef filter = 40; + WeightDef bias = 41; + + bool relu_flag = 62; + repeated uint32 adj = 70; + repeated uint32 target_shape = 71; + repeated uint32 before_pad = 72; +}; + +message PoolingOpParams { + int32 mode = 1; + int32 nan_opt = 2; + int32 pad_mode = 3; + bool global_pooling = 4; + + repeated uint32 window = 10; + repeated uint32 pad = 11; + repeated uint32 stride = 12; + 
bool ceil_mode = 13; + int32 data_mode = 14; + + float alpha = 20; + float beta = 21; + repeated uint32 before_pad = 22; +}; + +message EltwiseOpParams { + int32 mode = 1; + repeated float coeff = 2; + float alpha = 3; + float beta = 4; + repeated WeightDef weight = 5; + bool relu_flag = 6; +}; + +message ActivationOpParams { + int32 mode = 1; + float coef = 2; + float alpha = 3; + float beta = 4; +}; + +message BatchNormOpParams { + int32 mode = 1; + + float alpha = 2; + float beta = 3; + double epsilon = 4;//optinal,[default = 1e-5] + bool use_global_stats = 5; //optinal,by default true,testing mode + float moving_average_fraction = 6; //optinal,[default = .999]; + + WeightDef estimated_mean = 7; + WeightDef estimated_variance = 8; + + WeightDef scale = 9; + WeightDef bias = 10; +}; + +message ScaleOpParams { + WeightDef scale = 1; + WeightDef bias = 2; +}; + +message ReshapeOpParams { + float alpha = 1; + float beta = 2; + ShapeDef shape = 3; + int32 axis = 4; + int32 num_axes = 5; + int32 format = 6; +}; + +message SoftmaxOpParams { + int32 algo = 1; + int32 mode = 2; + float alpha = 3; + float beta = 4; +}; + +message FullConnectionOpParams { + WeightDef filter = 1; + WeightDef bias = 2; + uint32 num_output = 3; + bool relu_flag = 12; +}; + +message FlattenOpParams { + float alpha = 1; + float beta = 2; + int32 start_axis = 3; + int32 end_axis = 4; +} + +message AddLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message MulLimitedOpParams { + float alpha = 1; + float beta = 2; + int32 axis = 3; + bool broadcast = 4; + + repeated WeightDef weight = 10; +}; + +message AddOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message MulOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message SubOpParams { + float alpha = 1; + float beta = 2; + + repeated WeightDef weight = 10; +}; + +message BiasAddOpParams { + float alpha = 1; + float beta = 2; + + WeightDef bias = 10; +}; + +message MatMulOpParams { + float alpha = 1; + float beta = 2; + bool transposeX = 3; + bool transposeW = 4; + + WeightDef filter = 10; + WeightDef bias = 12; +}; + +message RsqrtOpParams { + float alpha = 1; + float beta = 2; +}; + + +message WeightDef { + int32 format = 1; + int32 data_type = 2; + ShapeDef shape = 3; + bytes data = 4; + int64 data_offset = 5; + uint32 cmps_size = 6; + bytes cmps_tab = 7; + int64 cmps_tab_offset = 10; + CompressInfo cmps_info = 8; + AllOffsetQuantizeInfo alloffset_quantize_info = 11; +} + +message ShapeDef { + repeated int64 dim = 1; +} + +enum DeviceType { + NPU = 0; // In default, we will use NPU. 
+ CPU = 1; // CPU +} + +message AllOffsetQuantizeInfo { + float scale = 1; + int32 offset = 2; +} + +message TensorDescriptor { + int32 format = 1; + int32 data_type = 2; + repeated int64 dim = 3; + uint32 size = 4; + bool reuse_input = 5; + bool output_tensor = 7; + DeviceType device_type = 8; + bool input_tensor = 9; + uint32 real_dim_cnt = 10; + uint32 reuse_input_index = 11; + AllOffsetQuantizeInfo alloffset_quantize_info = 12; +} + +message CompressInfo { + int32 blockRow = 1; // block row + int32 blockCol = 2; // block col + int32 fractalK = 3; // fractal K + int32 fractalN = 4; // fractal N + int32 lastFractalK = 5; // K of last fractal + int32 lastFractalN = 6; // N of last fractal + int32 cubeSize = 7; // cube's length + int32 loadDir = 8; // data load directtiono 0:col load 1:row load +} + +message AttrDef { + message ListValue { + repeated string s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated uint32 u = 6 [packed = true]; // "list(uint)" + repeated bytes bt = 7; + } + + oneof value { + string s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + uint32 u = 6; // "uint32" + bytes bt = 7; + ListValue list = 1; // any "list(...)" + NamedAttrs func = 10; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NamedAttrs { + string name = 1; + map attr = 2; +} + diff --git a/parser/tensorflow/proto/task.proto b/parser/tensorflow/proto/task.proto new file mode 100644 index 0000000..d0c0984 --- /dev/null +++ b/parser/tensorflow/proto/task.proto @@ -0,0 +1,165 @@ +/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * Apache License for more details at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +syntax = "proto3"; + +package domi; + +message ModelTaskDef { + string version = 1; + + map attr = 9; // Extended field + repeated TaskDef task = 10; + + uint64 memory_size = 11; + uint32 stream_num = 12; + uint32 event_num = 13; + uint64 weight_size = 14; + + repeated bytes op = 15; // input/output opdef in bytes + + uint64 base_addr = 16; // base addr + uint64 weight_addr = 17; // weight addr + uint32 batch_num = 18; +} + + +message TaskDef { + uint32 id = 1; + uint32 type = 2; + + uint32 stream_id = 10; + uint32 event_id = 11; + + KernelDef kernel = 20; + KernelExDef kernel_ex = 21; + KernelHcclDef kernel_hccl = 25; + EventExDef event_ex = 26; + LogTimeStampDef log_timestamp = 28; + + uint32 label_id = 30; + + MemcpyAsyncDef memcpy_async = 31; + StreamSwitchDef stream_switch = 32; + StreamActiveDef stream_active = 33; + bytes private_def = 34; + uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future + StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; +} + +message KernelDef { + KernelContext context = 1; + + string stub_func = 10; + uint32 block_dim = 11; + uint32 args_size = 12; + bytes args = 13; + bytes sm_desc = 14; + bytes flowtable = 15; + string so_name = 16; + string kernel_name = 17; + bytes kernel_ext_info = 18; + uint32 kernel_ext_info_size = 19; +} + +message KernelContext { + uint32 kernel_type = 1; + uint32 op_id = 2; // OP type in CCE + uint32 kernel_func_id = 3; + uint32 op_index = 4; // TE/Custom operator + bool is_flowtable = 5; // Identify whether args is a flowtable structure + bytes args_offset = 6; // args offset information + uint32 args_count = 7; // args count + repeated uint32 origin_op_index = 8; +} + + +message KernelExDef { + uint32 flags = 1; + + uint32 op_index = 4; + uint32 args_size = 12; + bytes args = 13; + bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput + uint32 task_info_size = 15; + bytes kernel_ext_info = 16; + uint32 kernel_ext_info_size = 17; +} + + +message KernelHcclDef { + uint32 op_index = 8; + string hccl_type = 9; +} + + +message EventExDef { + uint32 op_index = 1; + uint32 event_type = 2; +} + +message LogTimeStampDef { + uint64 logid = 1; + bool notify = 2; + uint32 flat = 3; +} + +message MemcpyAsyncDef { + uint64 dst = 1; + uint64 dst_max = 2; + uint64 src = 3; + uint64 count = 4; + uint32 kind = 5; + uint32 op_index = 6; +} + +message StreamSwitchDef { + uint32 op_index = 1; + uint32 true_stream_id = 2; + int64 value = 3; + uint64 value_ptr = 4; + uint32 data_type = 5; +} + +message StreamActiveDef { + uint32 op_index = 1; + uint32 active_stream_id = 2; +} + +message StreamSwitchNDef { + uint32 op_index = 1; + uint32 size = 2; + repeated int64 target_value = 3; + repeated uint32 true_stream_id = 4; + uint32 element_size = 5; + uint32 data_type = 6; +} + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/parser/tensorflow/proto/tensorflow/attr_value.proto b/parser/tensorflow/proto/tensorflow/attr_value.proto new file mode 100644 index 0000000..1cc67d6 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/attr_value.proto @@ -0,0 +1,62 @@ +syntax = 
"proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensor.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + repeated NameAttrList func = 9; // "list(attr)" + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. +message NameAttrList { + string name = 1; + map attr = 2; +} diff --git a/parser/tensorflow/proto/tensorflow/function.proto b/parser/tensorflow/proto/tensorflow/function.proto new file mode 100644 index 0000000..075897c --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/function.proto @@ -0,0 +1,100 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "FunctionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "node_def.proto"; +import "op_def.proto"; + +// A library is a set of named functions. +message FunctionDefLibrary { + repeated FunctionDef function = 1; + repeated GradientDef gradient = 2; +} + +// A function can be instantiated when the runtime can bind every attr +// with a value. When a GraphDef has a call to a function, it must +// have binding for every attr defined in the signature. +// * device spec, etc. +message FunctionDef { + // The definition of the function's name, arguments, return values, + // attrs etc. + OpDef signature = 1; + + // Attributes specific to this function definition. + map attr = 5; + + // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. 
+ reserved 2; + + // In both of the following fields, there is the need to specify an + // output that is used as either the input to another node (in + // `node_def`) or as a return value of the function (in `ret`). + // Unlike the NodeDefs in GraphDef, we need to be able to specify a + // list in some cases (instead of just single outputs). Also, we + // need to be able to deal with lists of unknown length (so the + // output index may not be known at function definition time). So + // we use the following format instead: + // * "fun_in" where "fun_in" is the name of a function input arg in + // the `signature` field above. This represents that input, whether + // it is a single tensor or a list. + // * "fun_in:0" gives the first element of a function input arg (a + // non-list input is considered a list of length 1 for these + // purposes). + // * "node:out" where "node" is the name of a node in `node_def` and + // "out" is the name one of its op's output arguments (the name + // comes from the OpDef of the node's op). This represents that + // node's output, whether it is a single tensor or a list. + // Note: We enforce that an op's output arguments are never + // renamed in the backwards-compatibility test. + // * "node:out:0" gives the first element of a node output arg (a + // non-list output is considered a list of length 1 for these + // purposes). + // + // NOT CURRENTLY SUPPORTED (but may be in the future): + // * "node:out:-1" gives last element in a node output list + // * "node:out:1:" gives a list with all but the first element in a + // node output list + // * "node:out::-1" gives a list with all but the last element in a + // node output list + + // The body of the function. Unlike the NodeDefs in a GraphDef, attrs + // may have values of type `placeholder` and the `input` field uses + // the "output" format above. + + // By convention, "op" in node_def is resolved by consulting with a + // user-defined library first. If not resolved, "func" is assumed to + // be a builtin op. + repeated NodeDef node_def = 3; + + // A mapping from the output arg names from `signature` to the + // outputs from `node_def` that should be returned by the function. + map ret = 4; +} + +// GradientDef defines the gradient function of a function defined in +// a function library. +// +// A gradient function g (specified by gradient_func) for a function f +// (specified by function_name) must follow the following: +// +// The function 'f' must be a numerical function which takes N inputs +// and produces M outputs. Its gradient function 'g', which is a +// function taking N + M inputs and produces N outputs. +// +// I.e. if we have +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// then, g is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the +// loss function). dL/dx_i is the partial derivative of L with respect +// to x_i. +message GradientDef { + string function_name = 1; // The function name. + string gradient_func = 2; // The gradient function's name. 
+} diff --git a/parser/tensorflow/proto/tensorflow/graph.proto b/parser/tensorflow/proto/tensorflow/graph.proto new file mode 100644 index 0000000..d639a7d --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "GraphProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "node_def.proto"; +import "function.proto"; +import "versions.proto"; + +// Represents the graph of operations +message GraphDef { + repeated NodeDef node = 1; + + // Compatibility versions of the graph. See core/public/version.h for version + // history. The GraphDef version is distinct from the TensorFlow version, and + // each release of TensorFlow will support a range of GraphDef versions. + VersionDef versions = 4; + + // Deprecated single version field; use versions above instead. Since all + // GraphDef changes before "versions" was introduced were forward + // compatible, this field is entirely ignored. + int32 version = 3 [deprecated = true]; + + // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. + // + // "library" provides user-defined functions. + // + // Naming: + // * library.function.name are in a flat namespace. + // NOTE: We may need to change it to be hierarchical to support + // different orgs. E.g., + // { "/google/nn", { ... }}, + // { "/google/vision", { ... }} + // { "/org_foo/module_bar", { ... }} + // map named_lib; + // * If node[i].op is the name of one function in "library", + // node[i] is deemed as a function call. Otherwise, node[i].op + // must be a primitive operation supported by the runtime. + // + // + // Function call semantics: + // + // * The callee may start execution as soon as some of its inputs + // are ready. The caller may want to use Tuple() mechanism to + // ensure all inputs are ready in the same time. + // + // * The consumer of return values may start executing as soon as + // the return values the consumer depends on are ready. The + // consumer may want to use Tuple() mechanism to ensure the + // consumer does not start until all return values of the callee + // function are ready. + FunctionDefLibrary library = 2; +}; diff --git a/parser/tensorflow/proto/tensorflow/graph_library.proto b/parser/tensorflow/proto/tensorflow/graph_library.proto new file mode 100644 index 0000000..e393d38 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/graph_library.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package domi.tensorflow; + +import "graph.proto"; + +message GeGraphDef { + string name = 1; + GraphDef graph = 2; +} + +message GraphDefLibrary { + repeated GeGraphDef graph_def = 1; +}; \ No newline at end of file diff --git a/parser/tensorflow/proto/tensorflow/node_def.proto b/parser/tensorflow/proto/tensorflow/node_def.proto new file mode 100644 index 0000000..b9bc97e --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/node_def.proto @@ -0,0 +1,63 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. 
+ // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= PARTIAL_SPEC + // + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) + // * "/job:worker/device:GPU:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // Add some examples here showing best practices. + map attr = 5; +}; diff --git a/parser/tensorflow/proto/tensorflow/op_def.proto b/parser/tensorflow/proto/tensorflow/op_def.proto new file mode 100644 index 0000000..3485d04 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/op_def.proto @@ -0,0 +1,164 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "OpDefProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "attr_value.proto"; +import "types.proto"; + +// Defines an operation. A NodeDef in a GraphDef specifies an Op by +// using the "op" field which should match the name of a OpDef. +// LINT.IfChange +message OpDef { + // Op names starting with an underscore are reserved for internal use. + // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". + string name = 1; + + // For describing inputs and outputs. + message ArgDef { + // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". + string name = 1; + + // Human readable description. + string description = 2; + + // Describes the type of one or more tensors that are accepted/produced + // by this input/output arg. The only legal combinations are: + // * For a single tensor: either the "type" field is set or the + // "type_attr" field is set to the name of an attr with type "type". + // * For a sequence of tensors with the same type: the "number_attr" + // field will be set to the name of an attr with type "int", and + // either the "type" or "type_attr" field will be set as for + // single tensors. 
+ // * For a sequence of tensors, the "type_list_attr" field will be set + // to the name of an attr with type "list(type)". + DataType type = 3; + string type_attr = 4; // if specified, attr must have type "type" + string number_attr = 5; // if specified, attr must have type "int" + // If specified, attr must have type "list(type)", and none of + // type, type_attr, and number_attr may be specified. + string type_list_attr = 6; + + // For inputs: if true, the inputs are required to be refs. + // By default, inputs can be either refs or non-refs. + // For outputs: if true, outputs are refs, otherwise they are not. + bool is_ref = 16; + }; + + // Description of the input(s). + repeated ArgDef input_arg = 2; + + // Description of the output(s). + repeated ArgDef output_arg = 3; + + // Description of the graph-construction-time configuration of this + // Op. That is to say, this describes the attr fields that will + // be specified in the NodeDef. + message AttrDef { + // A descriptive name for the argument. May be used, e.g. by the + // Python client, as a keyword argument name, and so should match + // the regexp "[a-z][a-z0-9_]+". + string name = 1; + + // One of the type names from attr_value.proto ("string", "list(string)", + // "int", etc.). + string type = 2; + + // A reasonable default for this attribute if the user does not supply + // a value. If not specified, the user must supply a value. + AttrValue default_value = 3; + + // Human-readable description. + string description = 4; + + + // --- Constraints --- + // These constraints are only in effect if specified. Default is no + // constraints. + + // For type == "int", this is a minimum value. For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. 
+ + // Ops are marked as stateful if their behavior depends on some state beyond + // their input tensors (e.g. variable reading op) or if they have + // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops + // must always produce the same output for the same input and have + // no side-effects. + // + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. + string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/parser/tensorflow/proto/tensorflow/resource_handle.proto b/parser/tensorflow/proto/tensorflow/resource_handle.proto new file mode 100644 index 0000000..a345235 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandle"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandleProto { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/parser/tensorflow/proto/tensorflow/tensor.proto b/parser/tensorflow/proto/tensorflow/tensor.proto new file mode 100644 index 0000000..d0a4d02 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/tensor.proto @@ -0,0 +1,94 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "resource_handle.proto"; +import "tensor_shape.proto"; +import "types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. 
We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized raw tensor content from either Tensor::AsProtoTensorContent or + // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation + // can be used for all tensor types. The purpose of this representation is to + // reduce serialization overhead during RPC call by avoiding serialization of + // many repeated small items. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll + // have some pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. + repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; + + // DT_UINT32 + repeated uint32 uint32_val = 16 [packed = true]; + + // DT_UINT64 + repeated uint64 uint64_val = 17 [packed = true]; +}; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/parser/tensorflow/proto/tensorflow/tensor_shape.proto b/parser/tensorflow/proto/tensorflow/tensor_shape.proto new file mode 100644 index 0000000..4225a2e --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package domi.tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). 
Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. + bool unknown_rank = 3; +}; diff --git a/parser/tensorflow/proto/tensorflow/types.proto b/parser/tensorflow/proto/tensorflow/types.proto new file mode 100644 index 0000000..ba7a72b --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/types.proto @@ -0,0 +1,74 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types + DT_UINT32 = 22; + DT_UINT64 = 23; + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). 
+ DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; + DT_UINT32_REF = 122; + DT_UINT64_REF = 123; +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/c/c_api.h, +// https://www.tensorflow.org/code/tensorflow/go/tensor.go, +// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, +// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, +// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, +// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/parser/tensorflow/proto/tensorflow/versions.proto b/parser/tensorflow/proto/tensorflow/versions.proto new file mode 100644 index 0000000..4806121 --- /dev/null +++ b/parser/tensorflow/proto/tensorflow/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package domi.tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). + repeated int32 bad_consumers = 3; +}; diff --git a/parser/tensorflow/scope/scope_pass_manager.cc b/parser/tensorflow/scope/scope_pass_manager.cc new file mode 100644 index 0000000..84ecbe0 --- /dev/null +++ b/parser/tensorflow/scope/scope_pass_manager.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "parser/tensorflow/scope/scope_pass_manager.h"
+#include "common/ge/ge_util.h"
+#include "common/util.h"
+#include "common/util/error_manager/error_manager.h"
+#include "framework/common/debug/ge_log.h"
+#include "register/scope/scope_graph_impl.h"
+#include "register/scope/scope_pass_impl.h"
+
+namespace ge {
+shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) {
+ GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr");
+ scope_graph_ = ge::MakeShared<ScopeGraph>();
+ if (scope_graph_ == nullptr) {
+ GELOGE(FAILED, "Scope graph make shared failed.");
+ return nullptr;
+ }
+ Status ret = scope_graph_->Init();
+ if (ret != SUCCESS) {
+ GELOGE(FAILED, "Scope graph init failed.");
+ return nullptr;
+ }
+
+ auto &impl = scope_graph_->impl_;
+ impl->BuildScopeGraph(graph_def);
+
+ return scope_graph_;
+}
+
+Status ScopePassManager::AddPass(unique_ptr<ScopeBasePass> &pass) {
+ GE_CHECK_NOTNULL(pass);
+
+ graph_passes_.push_back(std::move(pass));
+
+ return SUCCESS;
+}
+
+Status ScopePassManager::Run(shared_ptr<ScopeGraph> &graph) {
+ GE_CHECK_NOTNULL(graph);
+ bool not_changed = true;
+
+ for (auto &pass : graph_passes_) {
+ GE_CHECK_NOTNULL(pass);
+ auto &impl = pass->impl_;
+ if (impl == nullptr) {
+ GELOGE(ge::MEMALLOC_FAILED, "ScopeBasePass is not properly initialized.");
+ continue;
+ }
+
+ Status status = impl->Run(graph);
+ if (status == SUCCESS) {
+ GELOGI("Run scope pass:%s success.", pass->PassName().c_str());
+ not_changed = false;
+ } else if (status != domi::SCOPE_NOT_CHANGED) {
+ // exception
+ ErrorManager::GetInstance().ATCReportErrMessage("E12003", {"passname"}, {pass->PassName()});
+ GELOGE(FAILED, "Pass Run failed, pass name:%s", pass->PassName().c_str());
+ return status;
+ }
+ }
+
+ return not_changed ? domi::SCOPE_NOT_CHANGED : SUCCESS;
+}
+} // namespace ge
diff --git a/parser/tensorflow/scope/scope_pass_manager.h b/parser/tensorflow/scope/scope_pass_manager.h
new file mode 100644
index 0000000..45342e9
--- /dev/null
+++ b/parser/tensorflow/scope/scope_pass_manager.h
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef PARSER_TENSORFLOW_SCOPE_SCOPE_PASS_MANAGER_H_
+#define PARSER_TENSORFLOW_SCOPE_SCOPE_PASS_MANAGER_H_
+
+#include <vector>
+#include "external/register/scope/scope_fusion_pass_register.h"
+#include "proto/tensorflow/graph.pb.h"
+
+using std::shared_ptr;
+using std::unique_ptr;
+
+namespace ge {
+/**
+ * @ingroup domi_omg
+ * @brief manage passes
+ */
+class ScopePassManager {
+ public:
+ ScopePassManager() : scope_graph_(nullptr) {}
+ ScopePassManager(const ScopePassManager &scope_pass_manager) = delete;
+ ScopePassManager &operator=(const ScopePassManager &scope_pass_manager) = delete;
+ ~ScopePassManager() {}
+
+ shared_ptr<ScopeGraph> BuildScopeGraph(domi::tensorflow::GraphDef *graph_def);
+
+ domi::Status AddPass(unique_ptr<ScopeBasePass> &pass);
+ domi::Status Run(shared_ptr<ScopeGraph> &graph);
+
+ std::shared_ptr<ScopeGraph> scope_graph_;
+
+ private:
+ std::vector<std::unique_ptr<ScopeBasePass>> graph_passes_;
+};
+} // namespace ge
+
+#endif // PARSER_TENSORFLOW_SCOPE_SCOPE_PASS_MANAGER_H_
diff --git a/parser/tensorflow/tensorflow_arg_parser.cc b/parser/tensorflow/tensorflow_arg_parser.cc
new file mode 100644
index 0000000..35a34c8
--- /dev/null
+++ b/parser/tensorflow/tensorflow_arg_parser.cc
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/debug/log.h"
+#include "parser/common/op_def/arg_op.h"
+#include "framework/common/debug/ge_log.h"
+#include "framework/omg/parser/parser_inner_ctx.h"
+#include "graph/compute_graph.h"
+#include "graph/ge_tensor.h"
+#include "parser/common/op_parser_factory.h"
+#include "parser/tensorflow/tensorflow_op_parser.h"
+#include "parser/tensorflow/tensorflow_parser_register.h"
+
+using domi::tensorflow::AttrValue;
+
+namespace ge {
+namespace {
+const char *const kSerializeFormat = "serialize_format";
+} // namespace
+Status ParseParams(const Message *op_src, ArgOpOperator *op) {
+ GE_CHECK_NOTNULL(op_src);
+ GE_CHECK_NOTNULL(op);
+ const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
+ GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
+ domi::tensorflow::AttrValue output_attr_value;
+ if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
+ GE_CHK_STATUS_RET(
+ TensorFlowUtil::TransTensorDescriptor(output_attr_value, op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG),
+ "trans output_attr_value failed, op: %s", node->name().c_str());
+ // For the needs of the Data operator, copy the output description to the input description
+ GE_CHK_STATUS_RET(TensorFlowUtil::TransTensorDescriptor(output_attr_value, op, TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG),
+ "trans output_attr_value failed, op: %s", node->name().c_str());
+
+ domi::tensorflow::AttrValue_ListValue attr_list = output_attr_value.list();
+ GetParserContext().format =
+ static_cast<domi::domiTensorFormat_t>(attr_list.func(0).attr().at(kSerializeFormat).i());
+ } else {
+ /// An _Arg constructed from an inference function does not have input_tensor_desc;
+ /// set input & output tensor desc for adding input & output tensor desc for op desc
+ ge::GeTensorDesc tensor_desc;
+ op->InputTensorDesc(tensor_desc);
+ op->OutputTensorDesc(tensor_desc);
+ }
+
+ domi::tensorflow::AttrValue index_attr_value;
+ if (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value)) {
+ op->Index(index_attr_value.i());
+ }
+
+ GELOGI("In _ArgOp trans success. op name : %s.", node->name().c_str());
+
+ return SUCCESS;
+}
+
+DOMI_REGISTER_TENSORFLOW_PARSER(ge::parser::ARG, ArgOpOperator).SetParseParamsFn(ParseParams);
+} // namespace ge
diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
new file mode 100644
index 0000000..e9fe078
--- /dev/null
+++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "tensorflow_auto_mapping_parser_adapter.h" + +#include "framework/omg/parser/parser_types.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_parser_factory.h" +#include "register/op_registry.h" +#include "register/register.h" + + +using domi::TENSORFLOW; +using namespace ge::parser; + +using ge::parser::PLACEHOLDERWITHDEFAULT; + +namespace ge { +namespace { +const char *const kTfAttrT = "T"; +} // namespace + +Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + if (op_src == nullptr) { + GELOGE(PARAM_INVALID, "Op src is null"); + return PARAM_INVALID; + } + const NodeDef *node = reinterpret_cast(op_src); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + if (op_dest == nullptr) { + GELOGE(FAILED, "Op dest is null"); + return PARAM_INVALID; + } + + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + Status ret = domi::AutoMappingFn(op_src, op); + if (ret != SUCCESS) { + GELOGE(FAILED, "Tensorflow auto mapping parser params failed"); + return FAILED; + } + op.BreakConnect(); + + // add dynamic input/output + if (op_dest->GetType() == IDENTITYN) { + uint32_t dynamic_tensor_num = 0; + domi::tensorflow::AttrValue attr_num; + if (!(TensorFlowUtil::FindAttrValue(node, kTfAttrT, attr_num))) { + GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_dest->GetName().c_str(), kTfAttrT); + } + dynamic_tensor_num = attr_num.list().type_size(); + + GE_CHK_STATUS_RET(op_dest->AddDynamicInputDesc("x", dynamic_tensor_num), "AddDynamicInputDesc failed"); + GE_CHK_STATUS_RET(op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num), "AddDynamicInputDesc failed"); + GELOGI("add dynamic intput and output for op [%s], type[%s], number:%u", op_dest->GetName().c_str(), + op_dest->GetType().c_str(), dynamic_tensor_num); + } + + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PLACEHOLDERWITHDEFAULT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EXPANDDIMS, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SIZE, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPE, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, GUARANTEECONST, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BROADCASTARGS, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PREVENTGRADIENT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, RANK, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BROADCASTGRADIENTARGS, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, STOPGRADIENT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, DESTROYTEMPORARYVARIABLE, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SNAPSHOT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EMPTY, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITYN, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, CONTROLTRIGGER, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SWITCH, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, LOOPCOND, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, NEXTITERATION, 
TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFNEXTITERATION, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EXIT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFEXIT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, CONSTANT, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PARALLELCONCATSTART, TensorFlowAutoMappingParserAdapter); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BITCAST, TensorFlowAutoMappingParserAdapter); +} diff --git a/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h new file mode 100644 index 0000000..3afea84 --- /dev/null +++ b/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_AUTO_MAPPING_PARSER_ADAPTER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_AUTO_MAPPING_PARSER_ADAPTER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowAutoMappingParserAdapter : public TensorFlowOpParser { + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; +}; +} // namespace ge +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_AUTO_MAPPING_PARSER_ADAPTER_H_ diff --git a/parser/tensorflow/tensorflow_constant_parser.cc b/parser/tensorflow/tensorflow_constant_parser.cc new file mode 100644 index 0000000..20216d2 --- /dev/null +++ b/parser/tensorflow/tensorflow_constant_parser.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/tensorflow/tensorflow_constant_parser.h" +#include +#include +#include +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/op/ge_op_utils.h" +#include "parser/common/op_def/constant_op.h" +#include "parser/common/op_def/ir_pb_converter.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_tensor.h" +#include "graph/utils/attr_utils.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_types.h" +#include "register/tensor_assign.h" + +using domi::tensorflow::NodeDef; +using domi::TENSORFLOW; +using ge::parser::CONSTANTOP; + +namespace ge { +Status TensorFlowConstantParser::ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(op); + domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DTYPE, attr), + op->DType(domi::TensorAssign::ConvertTensorflowDataType(domi::tensorflow::DT_FLOAT)); + return SUCCESS); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_TYPE), + "check Attr dtype fail"); + + domi::tensorflow::DataType tf_type = attr.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type); + + op->DType(type); + + return SUCCESS; +} + +Status TensorFlowConstantParser::ParseValue(const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &opDesc) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(opDesc); + domi::tensorflow::AttrValue attr_value; + // Check that the attribute value must exist and get the value of value + GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_VALUE, attr_value), + domi::FAILED, "nodeDef %s Attr %s is not exist.", node->name().c_str(), + TENSORFLOW_ATTR_VALUE.c_str()); + // Check that the value attribute must be tensor + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR), + "check Attr %s failed", TENSORFLOW_ATTR_VALUE.c_str()); + + const domi::tensorflow::TensorProto &tensor = attr_value.tensor(); + + GeTensorPtr weight = ge::MakeShared(); + GE_CHECK_NOTNULL(weight); + int64_t dataType = 0; + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(opDesc, TENSORFLOW_ATTR_DTYPE, dataType), INTERNAL_ERROR, + "get dtype fail"); + GE_CHK_STATUS_RET(domi::TensorAssign::SetGeTensorDataType(dataType, weight), "set ge tensor data type fail"); + + GE_CHK_STATUS_RET(domi::TensorAssign::SetGeTensor(tensor, weight), "set ge tensor fail"); + GELOGI("TensorFlowConstantParser::ParseValue. 
TF op node name = %s", opDesc->GetName().c_str()); + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensor(opDesc, ATTR_NAME_WEIGHTS, weight), INTERNAL_ERROR, + "set tensor fail"); + return domi::SUCCESS; +} + +Status TensorFlowConstantParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + GE_CHECK_NOTNULL(op_dest); + const NodeDef *node = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + ConstantOperator op; + op.Name(node->name()); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseDType(node, &op), "Parse dtype for node %s failed.", node->name().c_str()); + GE_CHK_STATUS_RET(ConvertToOpDesc(op, op_dest), "ConvertToOpDesc ret fail"); + GE_CHK_STATUS_RET(ParseValue(node, op_dest), "ParseValue ret fail"); + for (const auto &output_desc : op_dest->GetAllOutputsDescPtr()) { + // Fixed input ND + output_desc->SetFormat(ge::Format::FORMAT_ND); + output_desc->SetOriginFormat(ge::Format::FORMAT_ND); + } + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, CONSTANTOP, TensorFlowConstantParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_constant_parser.h b/parser/tensorflow/tensorflow_constant_parser.h new file mode 100644 index 0000000..319c845 --- /dev/null +++ b/parser/tensorflow/tensorflow_constant_parser.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ + +#include +#include "common/op_def/constant_op.h" +#include "parser/common/data_op_parser.h" +#include "parser/tensorflow/tensorflow_op_parser.h" + +using domi::tensorflow::NodeDef; + +namespace ge { +class TensorFlowConstantParser : public TensorFlowOpParser { + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + private: + Status ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op); + Status ParseValue(const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &opDesc); +}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_CONSTANT_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_custom_parser_adapter.cc b/parser/tensorflow/tensorflow_custom_parser_adapter.cc new file mode 100644 index 0000000..856b96b --- /dev/null +++ b/parser/tensorflow/tensorflow_custom_parser_adapter.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_parser_factory.h" +#include "register/op_registry.h" + +using domi::ParseParamFunc; +using domi::ParseParamByOpFunc; + +namespace ge { +Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + GE_CHECK_NOTNULL(op_src); + const NodeDef *node_src = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); + GE_CHECK_NOTNULL(op_dest); + + ParseParamFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest->GetType(), node_src->op()); + GE_CHECK_NOTNULL(custom_op_parser); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); + + op.BreakConnect(); + + return SUCCESS; +} + +Status TensorFlowCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { + GELOGI("Tensorflow custom op begin to parse params: op node name = %s, op type = %s.", + op_src.GetName().c_str(), op_src.GetOpType().c_str()); + GE_CHECK_NOTNULL(op_dest); + + ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); + GE_CHECK_NOTNULL(custom_op_parser); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed"); + + op_src.BreakConnect(); + + return SUCCESS; +} +} // namespace ge diff --git a/parser/tensorflow/tensorflow_custom_parser_adapter.h b/parser/tensorflow/tensorflow_custom_parser_adapter.h new file mode 100644 index 0000000..3653718 --- /dev/null +++ b/parser/tensorflow/tensorflow_custom_parser_adapter.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_CUSTOM_PARSER_ADAPTER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_CUSTOM_PARSER_ADAPTER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowCustomParserAdapter : public TensorFlowOpParser { + public: + /** + * @ingroup domi_omg + * @brief Parsing model file information + * @param [in] op_src model data to be parsed + * @param [out] op_dest model data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + /** + * @ingroup domi_omg + * @brief parse params of the operation + * @param [in] op_src params to be parsed + * @param [out] op_dest params after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest); +}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_CUSTOM_PARSER_ADAPTER_H_ diff --git a/parser/tensorflow/tensorflow_data_parser.cc b/parser/tensorflow/tensorflow_data_parser.cc new file mode 100644 index 0000000..8de0575 --- /dev/null +++ b/parser/tensorflow/tensorflow_data_parser.cc @@ -0,0 +1,152 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
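[Editor's note] The adapter above only dispatches to a parse-params callback that a plugin has already placed in the OpRegistry. A hedged sketch of what such a registration typically looks like through the register module's REGISTER_CUSTOM_OP builder; the op and type names are invented and the include path may differ by build setup.

#include "register/register.h"
#include "graph/operator.h"

// Hypothetical callback matching domi::ParseParamFunc: op_src is the TensorFlow NodeDef.
domi::Status ParseParamsMyOp(const google::protobuf::Message *op_src, ge::Operator &op) {
  // read attributes from op_src and set them on op here
  (void)op_src;
  (void)op;
  return domi::SUCCESS;
}

REGISTER_CUSTOM_OP("MyCustomOp")            // GE op type the node is mapped to (invented name)
    .FrameworkType(domi::TENSORFLOW)        // source framework
    .OriginOpType("MyTfOp")                 // original TensorFlow op type (invented name)
    .ParseParamsFn(ParseParamsMyOp);        // picked up via GetParseParamFunc in the adapter above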
+ */ + +#include "parser/tensorflow/tensorflow_data_parser.h" +#include +#include "common/debug/log.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_types.h" + +using domi::tensorflow::AttrValue; +using domi::tensorflow::NodeDef; +using domi::TENSORFLOW; +using ge::parser::DATA; + +namespace ge { +namespace { +const int64_t kValidShapeMinValue = -2; +} // namespace +Status TensorFlowDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_def) { + GE_CHECK_NOTNULL(op_src); + const NodeDef *node_src = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); + GE_CHECK_NOTNULL(op_def); + GE_RETURN_WITH_LOG_IF_ERROR(ParseInputFromModel(op_src, op_def), "parse shape of data op %s from model failed", + op_def->GetName().c_str()); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseInputFromUser(op_src, op_def), "parse shape of data op %s from user failed", + op_def->GetName().c_str()); + + GE_RETURN_WITH_LOG_IF_ERROR(CheckInputShape(op_def->GetName()), + "input node %s :check user designated input shape not match input shape defined in model", + op_def->GetName().c_str()); + + // Parse data dimension values and add them to op_def + GE_RETURN_WITH_LOG_IF_ERROR(ParseShape(user_input_dims_v, op_def), "TensorFlowDataParser::ParseShape failed"); + + return SUCCESS; +} + +Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, ge::OpDescPtr &op_def) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op_def); + + const NodeDef *node = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node); + + domi::tensorflow::AttrValue attr_value; + if (TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DTYPE, attr_value)) { + // Check dtype attribute must be type + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TYPE), + "check Attr %s failed", TENSORFLOW_ATTR_DTYPE.c_str()); + + domi::tensorflow::DataType tf_type = attr_value.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type); + CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, + GELOGE(domi::PARAM_INVALID, + "Data type %s of node %s is not supported.", + DataType_Name(tf_type).c_str(), + node->name().c_str()); + return domi::PARAM_INVALID); + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetInt(op_def, DATA_ATTR_NAME_DATA_TYPE, static_cast(type)), FAILED, + "SetInt failed"); + } + + if (!TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_SHAPE, attr_value)) { + GELOGE(domi::PARAM_INVALID, "input data node %s do not find shape.", node->name().c_str()); + return domi::PARAM_INVALID; + } + + // Check shape attribute must be shape + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), + "check Attr %s failed", TENSORFLOW_ATTR_SHAPE.c_str()); + + const TensorShapeProto &data_shape = attr_value.shape(); + for (auto i = 0; i < data_shape.dim_size(); i++) { + model_input_dims_v.push_back(data_shape.dim(i).size()); + } + + return SUCCESS; +} + +Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def) { + GE_CHECK_NOTNULL(op_def); + (void)op_src; + const ge::ParserContext &ctx = GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + // User not designate the input_shape + std::string name = op_def->GetName(); + if 
(input_dims.count(name) == 0) { + GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str()); + for (uint32_t i = 0; i < model_input_dims_v.size(); i++) { + user_input_dims_v.push_back(model_input_dims_v[i]); + } + + return SUCCESS; + } + + /* User designate the input_shape by passing '--input_shape=xxx:x,x,x,x' */ + // Two cases below both OK: + // 1. the input_shape not defined in the model(dimension is 0). + // 2. the input_shape defined in the model(dimension greater than 0), and the dimension matches with user + // designate_dim. + std::vector designated_dims = input_dims.at(name); + size_t input_dim_size_ = designated_dims.size(); + + GE_CHK_BOOL_RET_STATUS(model_input_dims_v.empty() || input_dim_size_ == model_input_dims_v.size(), + domi::PARAM_INVALID, + "user designated input_dim_num %zu does match input_dim_num %zu defined by model", + input_dim_size_, + model_input_dims_v.size()); + + // replace with the user designated_dims + user_input_dims_v.swap(designated_dims); + + return SUCCESS; +} + +Status TensorFlowDataParser::CheckInputShape(const std::string &name) { + const ge::ParserContext &ctx = GetParserContext(); + if (!ctx.is_dynamic_input) { + for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { + // if input_shape has some placeholders, user should designate them. + // dim i = 0, means empty tensor. + // dim i = -1 or -2, means unknown shape. + GE_CHK_BOOL_RET_STATUS(user_input_dims_v[i] >= kValidShapeMinValue, domi::PARAM_INVALID, + "parse data node %s: shape contains placeholder ,but not designated by user", + name.c_str()); + } + } + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, DATA, TensorFlowDataParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_data_parser.h b/parser/tensorflow/tensorflow_data_parser.h new file mode 100644 index 0000000..baea90e --- /dev/null +++ b/parser/tensorflow/tensorflow_data_parser.h @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
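[Editor's note] The two parse steps above implement a simple precedence rule: a shape passed via --input_shape overrides the shape recorded in the model, but only when the model recorded no shape or the ranks agree. A distilled, dependency-free sketch of that rule (not the shipped code):

#include <cstdint>
#include <vector>

// Returns false when the user-designated rank conflicts with the rank stored in the model.
static bool ChooseDataShape(const std::vector<int64_t> &model_dims,
                            const std::vector<int64_t> &user_dims,
                            std::vector<int64_t> &out_dims) {
  if (user_dims.empty()) {          // no --input_shape entry for this Data node: keep the model shape
    out_dims = model_dims;
    return true;
  }
  if (!model_dims.empty() && model_dims.size() != user_dims.size()) {
    return false;                   // mirrors the PARAM_INVALID branch above
  }
  out_dims = user_dims;             // the user shape wins and may resolve -1/-2 placeholders
  return true;
}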
+ */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_DATA_PARSER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_DATA_PARSER_H_ + +#include +#include +#include "parser/common/data_op_parser.h" +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowDataParser : public TensorFlowOpParser, public DataOpParser { + public: + /** + * @ingroup domi_omg + * @brief parse weight + * @param [in] v_input_const weight data to be parsed + * @param [out] op_dest weight data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_def) override; + + private: + /** + * @ingroup domi_omg + * @brief Parsing input from model + * @param [in] op_src model to be parsed + * @param [out] op_def input information after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseInputFromModel(const Message *op_src, ge::OpDescPtr &op_def); + + /** + * @ingroup domi_omg + * @brief parse input set by users + * @param [in] op_src model to be parsed + * @param [out] op_def input information after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def); + + /** + * @ingroup domi_omg + * @brief Check whether the input shape entered by the user matches the input shape defined by the model + * @return SUCCESS match + * @return FAILED not match + * @author + */ + Status CheckInputShape(const std::string &name); + + std::vector model_input_dims_v; + + std::vector user_input_dims_v; +}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_DATA_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_enter_parser.cc b/parser/tensorflow/tensorflow_enter_parser.cc new file mode 100644 index 0000000..af065bd --- /dev/null +++ b/parser/tensorflow/tensorflow_enter_parser.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/tensorflow/tensorflow_enter_parser.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "graph/debug/ge_attr_define.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_types.h" + +using domi::TENSORFLOW; +using ge::parser::ENTER; +using ge::parser::REFENTER; + +namespace ge { +Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op_desc); + const std::string name = op_desc->GetName(); + + const NodeDef *node = reinterpret_cast(op_src); + domi::tensorflow::AttrValue str_attr; + if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { + GELOGE(FAILED, "In NodeDef %s attr [%s] not exist.", name.c_str(), ENTER_ATTR_FRAME_NAME.c_str()); + return FAILED; + } + std::string frame_name = str_attr.s(); + GELOGI("Enter node: %s, attr frame_name: %s", name.c_str(), frame_name.c_str()); + if (!ge::AttrUtils::SetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) { + GELOGE(FAILED, "Set attr ENTER_ATTR_FRAME_NAME fail, node: %s", name.c_str()); + return FAILED; + } + + domi::tensorflow::AttrValue bool_attr; + if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_CONSTANT_FLAG, bool_attr)) { + GELOGE(FAILED, "In NodeDef %s attr [%s] not exist.", name.c_str(), ENTER_ATTR_CONSTANT_FLAG.c_str()); + return FAILED; + } + bool is_constant = bool_attr.b(); + GELOGI("Enter node: %s, attr is_constant: %s", name.c_str(), is_constant ? "true" : "false"); + if (!ge::AttrUtils::SetBool(op_desc, ENTER_ATTR_CONSTANT_FLAG, is_constant)) { + GELOGE(FAILED, "Set attr ENTER_ATTR_CONSTANT_FLAG fail, node: %s", name.c_str()); + return FAILED; + } + + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, ENTER, TensorFlowEnterParser); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFENTER, TensorFlowEnterParser); +} // namespace ge \ No newline at end of file diff --git a/parser/tensorflow/tensorflow_enter_parser.h b/parser/tensorflow/tensorflow_enter_parser.h new file mode 100644 index 0000000..5f78e41 --- /dev/null +++ b/parser/tensorflow/tensorflow_enter_parser.h @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_ENTER_PARSER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_ENTER_PARSER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowEnterParser : public TensorFlowOpParser { + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) override; +}; +} // namespace ge +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_ENTER_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_fill_parser.cc b/parser/tensorflow/tensorflow_fill_parser.cc new file mode 100644 index 0000000..d467ed3 --- /dev/null +++ b/parser/tensorflow/tensorflow_fill_parser.cc @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Copyright (c) <2018>, +#include "common/debug/log.h" +#include "common/op/attr_value_util.h" +#include "parser/common/op_def/fill_op.h" +#include "common/util.h" +#include "parser/tensorflow/tensorflow_parser_register.h" +#include "framework/omg/parser/parser_types.h" + +using ge::parser::ALPHA_DEFAULT_VALUE; +using ge::parser::BETA_DEFAULT_VALUE; +using ge::parser::FILL; + +namespace ge { +/* +node { + name: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/zeros" + op: "Fill" + input: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/concat" + input: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/zeros/Const" + device: "/device:GPU:2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +*/ +domi::Status ParseParams(const NodeDef *node, FillOperator *op) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(op); + op->Name(node->name()); + + domi::tensorflow::DataType data_type; + GE_RETURN_IF_ERROR(TensorFlowUtil::ParseDataType(node, TENSORFLOW_ATTR_T, data_type)); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(data_type); + CHECK_FALSE_EXEC( + type != ge::DataType::DT_UNDEFINED, + GELOGE(PARAM_INVALID, "Data type %s of node %s is not supported.", DataType_Name(data_type).c_str(), + node->name().c_str()); + return PARAM_INVALID); + + op->DataType(type); + + op->Alpha(ge::parser::ALPHA_DEFAULT_VALUE); + op->Beta(ge::parser::BETA_DEFAULT_VALUE); + + return domi::SUCCESS; +} + +DOMI_REGISTER_TENSORFLOW_PARSER(FILL, FillOperator).SetParseParamsFn(ParseParams); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_frameworkop_parser.cc b/parser/tensorflow/tensorflow_frameworkop_parser.cc new file mode 100644 index 0000000..343c579 --- /dev/null +++ b/parser/tensorflow/tensorflow_frameworkop_parser.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/debug/log.h" +#include "parser/common/op_def/frameworkop_op.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_types.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "parser/tensorflow/tensorflow_parser_register.h" +#include "proto/tensorflow/tensor_shape.pb.h" + +using domi::tensorflow::TensorShapeProto; +using domi::tensorflow::AttrValue; +using domi::TENSORFLOW; +using ge::parser::FRAMEWORKOP; + +namespace ge { +Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op); + const NodeDef *node = reinterpret_cast(op_src); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + string type = node->op(); + + // Parsing input / output desc in attr + domi::tensorflow::AttrValue input_attr_value; + domi::tensorflow::AttrValue output_attr_value; + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { + GE_CHK_STATUS_RET( + TensorFlowUtil::TransTensorDescriptor(input_attr_value, op, TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG, type), + "trans input_attr_value failed, op: %s", node->name().c_str()); + } else { + GELOGD("Frameworkop has no input tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str()); + /// _Retval constructed from inference function do not has input_tensor_dec + /// set input tensor desc for adding input tensor desc for op desc + if (type == "_Retval") { + ge::GeTensorDesc tensor_desc; + op->InputTensorDesc(tensor_desc); + } + } + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { + GE_CHK_STATUS_RET( + TensorFlowUtil::TransTensorDescriptor(output_attr_value, op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG, type), + "trans output_attr_value failed, op: %s", node->name().c_str()); + } else { + GELOGD("Frameworkop has no output tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str()); + } + + // Add index attribute, only Retval needs to be added + domi::tensorflow::AttrValue index_attr_value; + GE_IF_BOOL_EXEC(((type == "_Retval") && (TensorFlowUtil::FindAttrValue(node, ATTR_NAME_INDEX, index_attr_value))), + op->Index(index_attr_value.i())); + + NodeDef *pkg_node = new (std::nothrow) NodeDef(); + GE_CHECK_NOTNULL(pkg_node); + + pkg_node->CopyFrom(*node); + + domi::tensorflow::AttrValue attr_v; + // Get the property opdef, if the property does not exist, return failure + if (TensorFlowUtil::FindAttrValue(pkg_node, ge::ATTR_NAME_FRAMEWORK_OP_DEF, attr_v)) { + op->TfOpDef(attr_v.s()); + } else { + GE_CHK_BOOL_EXEC(type == "_Retval", + GE_DELETE_NEW_SINGLE(pkg_node); + return PARAM_INVALID, "In NodeDef %s Attr opdef is not exist.", pkg_node->name().c_str()); + } + + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_OP_DEF); + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_OUTPUT_TENSOR_DESC); + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_INPUT_TENSOR_DESC); + pkg_node->mutable_attr()->erase(ge::VAR_ATTR_NAME); + + // Get property func def 
+ domi::tensorflow::AttrValue func_attr_v; + GE_IF_BOOL_EXEC(TensorFlowUtil::FindAttrValue(pkg_node, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, func_attr_v), + op->FuncDefPkg(func_attr_v.s()); + pkg_node->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_FUNC_DEF)); + GELOGD("pkg_node name is %s, op is %s.", pkg_node->name().c_str(), pkg_node->op().c_str()); + if (pkg_node->op() == "DPOP") { + pkg_node->set_op(pkg_node->name()); + } + + // Serialize nodedef into string and package as a whole + string serialized_node; + GE_IF_BOOL_EXEC(!pkg_node->SerializeToString(&serialized_node), + GELOGE(PARAM_INVALID, "In FrameworkOp trans NodeDef to string failed."); + GE_DELETE_NEW_SINGLE(pkg_node); return PARAM_INVALID); + + op->NodeDefPkg(serialized_node); + + string node_def_pkg = op->GetNodeDefPkg(); + + GELOGD("In FrameworkOp trans NodeDef to string success.op name : %s. nodedef_pkg [%s]", node->name().c_str(), + node_def_pkg.c_str()); + + // The framework operator of tensorflow preserves its framework type + op->Frameworktype(TENSORFLOW); + + op->OriginalType(type); + + // Add shape attribute, only variables need to be added + domi::tensorflow::AttrValue shape_value; + if (TensorFlowUtil::FindAttrValue(node, VAR_ATTR_SHAPE, shape_value)) { + vector shape_v; + TensorShapeProto shape_proto = shape_value.shape(); + for (auto dim : shape_proto.dim()) { + shape_v.push_back(dim.size()); + } + op->AttrVector(VAR_ATTR_SHAPE, shape_v); + } + + GE_DELETE_NEW_SINGLE(pkg_node); + return SUCCESS; +} + +DOMI_REGISTER_TENSORFLOW_PARSER(FRAMEWORKOP, FrameworkOpOperator).SetParseParamsFn(ParseParams); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc b/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc new file mode 100644 index 0000000..26d3c2b --- /dev/null +++ b/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
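[Editor's note] The framework-op parser above packages the pruned NodeDef as a serialized string on the operator. A minimal sketch, using only standard protobuf APIs, of how a later consumer can recover it; the helper name is hypothetical.

#include <string>
#include "proto/tensorflow/node_def.pb.h"

// Inverse of the SerializeToString() call above; returns false if the package is corrupt.
bool UnpackNodeDef(const std::string &node_def_pkg, domi::tensorflow::NodeDef &node_def) {
  return node_def.ParseFromString(node_def_pkg);
}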
+ */ + +#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_parser_factory.h" +#include "register/op_registry.h" + +using domi::FusionParseParamFunc; +using domi::FusionParseParamByOpFunc; + +namespace ge { +Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector &v_input_const, + ge::NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_dest = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_dest); + + std::vector inside_nodes; + for (auto inside_node : v_input_const) { + GE_CHECK_NOTNULL(inside_node); + const google::protobuf::Message *node_src = reinterpret_cast(inside_node); + inside_nodes.push_back(node_src); + } + std::string ori_type = op_dest->GetType(); + (void)ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, ori_type); + FusionParseParamFunc + custom_op_parser = domi::OpRegistry::Instance()->GetFusionParseParamFunc(op_dest->GetType(), ori_type); + GE_CHECK_NOTNULL(custom_op_parser); + GELOGI("Get fusion parser succ, node: %s.", node->GetName().c_str()); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_BOOL_RET_STATUS(custom_op_parser(inside_nodes, op) == SUCCESS, FAILED, "Custom parser params failed"); + + op.BreakConnect(); + GELOGI("Run fusion parser succ, node: %s.", node->GetName().c_str()); + return SUCCESS; +} + +Status TensorFlowFusionCustomParserAdapter::ParseParams(const std::vector &v_input_const, + ge::NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_dest = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_dest); + + GELOGI("Custom fusion begin to parse params, node: %s.", node->GetName().c_str()); + std::string ori_type = op_dest->GetType(); + (void)ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, ori_type); + FusionParseParamByOpFunc + custom_op_parser = domi::OpRegistry::Instance()->GetFusionParseParamByOpFunc(op_dest->GetType(), ori_type); + GE_CHECK_NOTNULL(custom_op_parser); + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_BOOL_RET_STATUS(custom_op_parser(v_input_const, op) == SUCCESS, FAILED, "Custom parser params failed"); + + for (const auto &op_src : v_input_const) { + op_src.BreakConnect(); + } + op.BreakConnect(); + GELOGI("Run fusion parser succ, node: %s.", node->GetName().c_str()); + return SUCCESS; +} +} // namespace ge diff --git a/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h b/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h new file mode 100644 index 0000000..9243ca5 --- /dev/null +++ b/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_FUSION_CUSTOM_PARSER_ADAPTER_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_FUSION_CUSTOM_PARSER_ADAPTER_H_ + +#include "parser/tensorflow/tensorflow_fusion_op_parser.h" + +namespace ge { +class TensorFlowFusionCustomParserAdapter : public TensorFlowFusionOpParser { + public: + /** + * @ingroup domi_parser + * @brief Parsing model file information + * @param [in] v_input_const model data to be parsed + * @param [out] node model data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const vector &v_input_const, ge::NodePtr &node) override; + + /** + * @ingroup domi_parser + * @brief Parsing model file information + * @param [in] v_input_const ge operators which save model data to be parsed + * @param [out] node model data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const std::vector &v_input_const, ge::NodePtr &node); +}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_FUSION_CUSTOM_PARSER_ADAPTER_H_ diff --git a/parser/tensorflow/tensorflow_fusion_op_parser.cc b/parser/tensorflow/tensorflow_fusion_op_parser.cc new file mode 100644 index 0000000..ca37f1d --- /dev/null +++ b/parser/tensorflow/tensorflow_fusion_op_parser.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
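[Editor's note] As with the single-op adapter, this fusion adapter only runs callbacks that were registered beforehand. A hedged sketch of a fusion registration through the register module; every name is invented and the builder methods are assumed from the OpRegistry lookups used above.

#include <vector>
#include "register/register.h"

// Hypothetical callback matching domi::FusionParseParamFunc: one Message* per inside NodeDef.
domi::Status ParseParamsMyFusionOp(const std::vector<const google::protobuf::Message *> insides,
                                   ge::Operator &op) {
  // walk the inside NodeDefs (Const weights, sub-ops, ...) and fill op's attributes here
  (void)insides;
  (void)op;
  return domi::SUCCESS;
}

REGISTER_CUSTOM_OP("MyFusionOp")
    .FrameworkType(domi::TENSORFLOW)
    .OriginOpType("MyFusionScopeType")
    .FusionParseParamsFn(ParseParamsMyFusionOp);   // found via GetFusionParseParamFunc above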
+ */ + +#include "parser/tensorflow/tensorflow_fusion_op_parser.h" +#include +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "omg/omg.h" +#include "parser/common/parser_fp16_t.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "register/tensor_assign.h" + +using domi::tensorflow::DataType; +using domi::tensorflow::NodeDef; + +namespace ge { +#define GET_CONST_VALUE(tensor, param, index, FIELD) \ + do { \ + google::protobuf::RepeatedField val_vec; \ + int32_t val_size = 0; \ + val_vec = tensor.FIELD##_val(); \ + val_size = val_vec.size(); \ + if (index < val_size) { \ + param = val_vec.Get(index); \ + } else if (tensor.has_tensor_shape()) { \ + const std::string tensor_content = tensor.tensor_content(); \ + char *buf = const_cast(tensor_content.data()); \ + FIELD *buf_v = reinterpret_cast(buf); \ + if (static_cast(index) >= tensor_content.length() / sizeof(FIELD)) { \ + GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \ + return domi::PARAM_INVALID; \ + } \ + param = buf_v[index]; \ + } else { \ + GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \ + return domi::PARAM_INVALID; \ + } \ + } while (false) + +Status TensorFlowFusionOpParser::GetTensorFromNode(const NodeDef *node_def, TensorProto &tensor) { + GE_CHECK_NOTNULL(node_def); + + string node_name = node_def->name(); + GELOGI("Convert NodeDef %s.", node_name.c_str()); + + domi::tensorflow::AttrValue attr_value; + // Check that the attribute value must exist and get the value. + if (!TensorFlowUtil::FindAttrValue(node_def, TENSORFLOW_ATTR_VALUE, attr_value)) { + GELOGE(domi::PARAM_INVALID, "NodeDef %s Attr %s is not exist.", node_name.c_str(), TENSORFLOW_ATTR_VALUE.c_str()); + return domi::PARAM_INVALID; + } + // Check that the value attribute must be tensor. 
+ GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR), + "check Attr %s failed", TENSORFLOW_ATTR_VALUE.c_str()); + tensor = attr_value.tensor(); + return SUCCESS; +} + +Status TensorFlowFusionOpParser::ParseParams(const vector &v_input_const, NodePtr &op_dest) { + return SUCCESS; +} + +Status TensorFlowFusionOpParser::ParseParams(const Message *op_src, OpDescPtr &op_dest) { return SUCCESS; } + +Status TensorFlowFusionOpParser::ParseParamFromConst(const NodeDef *node_def, int32_t ¶m) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GetTensorFromNode(node_def, tensor); + GET_CONST_VALUE(tensor, param, 0, int); + return SUCCESS; +} +Status TensorFlowFusionOpParser::ParseParamFromConst(const NodeDef *node_def, int32_t ¶m, int index) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GetTensorFromNode(node_def, tensor); + GET_CONST_VALUE(tensor, param, index, int); + return SUCCESS; +} +Status TensorFlowFusionOpParser::ParseParamFromConst(const NodeDef *node_def, float ¶m) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GetTensorFromNode(node_def, tensor); + GET_CONST_VALUE(tensor, param, 0, float); + return SUCCESS; +} + +Status TensorFlowFusionOpParser::ParseParamFromConst(const NodeDef *node_def, float ¶m, int index) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GetTensorFromNode(node_def, tensor); + GET_CONST_VALUE(tensor, param, index, float); + return SUCCESS; +} + +Status TensorFlowFusionOpParser::ParseHalfFromConst(const NodeDef *node_def, float ¶m, int index) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GetTensorFromNode(node_def, tensor); + if (tensor.half_val().size() > 0) { + auto val_vec = tensor.half_val(); + int32_t val_size = val_vec.size(); + if (index < val_size) { + ge::parser::fp16_t fp16_value = static_cast(val_vec.Get(index)); + param = fp16_value.ToFloat(); + } else { + GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index:%d, not supported.", index); + return domi::PARAM_INVALID; + } + } else { + GELOGE(domi::PARAM_INVALID, "Node %s does not have half value, index:%d.", node_def->name().c_str(), index); + return domi::PARAM_INVALID; + } + return SUCCESS; +} + +Status TensorFlowFusionOpParser::ParseWeightFromConst(const NodeDef *node_def, ge::GeTensorPtr &weight) { + GE_CHECK_NOTNULL(node_def); + TensorProto tensor; + GE_CHK_STATUS_RET(GetTensorFromNode(node_def, tensor), "get tensor failed."); + weight = ge::MakeShared(); + GE_CHECK_NOTNULL(weight); + domi::tensorflow::DataType data_type = tensor.dtype(); + GE_CHK_STATUS_RET( + domi::TensorAssign::SetGeTensorDataType(domi::TensorAssign::ConvertTensorflowDataType(data_type), weight), + "set ge tensor data type fail"); + GE_CHK_STATUS_RET(domi::TensorAssign::SetGeTensor(tensor, weight), "set ge tensor fail"); + return SUCCESS; +} +} // namespace ge diff --git a/parser/tensorflow/tensorflow_fusion_op_parser.h b/parser/tensorflow/tensorflow_fusion_op_parser.h new file mode 100644 index 0000000..251bb4a --- /dev/null +++ b/parser/tensorflow/tensorflow_fusion_op_parser.h @@ -0,0 +1,89 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
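[Editor's note] A short sketch of how a concrete fusion parser is expected to use the protected helpers above, assuming the base ParseParams overload takes a vector of const NodeDef pointers (the template arguments are not visible in this diff) and that the fused scope feeds two Const inputs; every name below is hypothetical.

class MyFusionOpParser : public TensorFlowFusionOpParser {
 public:
  Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &node) {
    GE_CHECK_NOTNULL(node);
    // expect two Const inputs: an int32 "depth" and a float "epsilon"
    if (v_input_const.size() < 2) {
      return domi::PARAM_INVALID;
    }
    int32_t depth = 0;
    float epsilon = 0.0f;
    GE_CHK_STATUS_RET(ParseParamFromConst(v_input_const[0], depth), "parse depth const failed");
    GE_CHK_STATUS_RET(ParseParamFromConst(v_input_const[1], epsilon, 0), "parse epsilon const failed");
    // a real parser would now store depth/epsilon on node->GetOpDesc() via ge::AttrUtils
    return SUCCESS;
  }
};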
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OMG_PARSER_TENSORFLOW_TENSORFLOW_FUSION_OP_PARSER_H_ +#define OMG_PARSER_TENSORFLOW_TENSORFLOW_FUSION_OP_PARSER_H_ + +#include +#include "common/op/attr_value_util.h" +#include "graph/ge_tensor.h" +#include "omg/parser/op_parser.h" +#include "parser/tensorflow/tensorflow_fusionop_util.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "parser/tensorflow/tensorflow_util.h" +#include "proto/tensorflow/graph.pb.h" +#include "proto/tensorflow/node_def.pb.h" + +using std::vector; +using google::protobuf::Message; +using domi::tensorflow::NodeDef; +using domi::tensorflow::TensorProto; + +namespace ge { +/** + * @ingroup domi_omg + * @brief Used to parse TensorFlow operator information + */ +class TensorFlowFusionOpParser : public TensorFlowOpParser { + public: + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] v_input_const Operator parameters to be parsed + * @param [out] op_dest Parsed model data + * @return SUCCESS Parsing success + * @return FAILED Parsing failed + */ + virtual Status ParseParams(const vector &v_input_const, ge::NodePtr &node); + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be parsed + * @param [out] graph Parsed parameter data + * @return SUCCESS Parsing success + * @return FAILED Parsing failed + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) final; + + protected: + /** + * @ingroup domi_omg + * @brief Parse parameters from const op + * @param [in] op_src Model data to be parsed + * @param [out] op_dest Parsed model data + * @return SUCCESS Parsing success + * @return FAILED Parsing failed + * + */ + // template + Status ParseParamFromConst(const NodeDef *input_const, int32_t ¶m); + + Status ParseParamFromConst(const NodeDef *nodeDef, int32_t ¶m, int index); + + Status ParseParamFromConst(const NodeDef *input_const, float ¶m); + + Status ParseParamFromConst(const NodeDef *nodeDef, float ¶m, int index); + + Status GetTensorFromNode(const NodeDef *nodeDef, TensorProto &tensor); + + Status ParseHalfFromConst(const NodeDef *node_def, float ¶m, int index = 0); + + Status ParseWeightFromConst(const NodeDef *node_def, ge::GeTensorPtr &weight); +}; +} // namespace ge + +#endif // OMG_PARSER_TENSORFLOW_TENSORFLOW_FUSION_OP_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_fusionop_util.cc b/parser/tensorflow/tensorflow_fusionop_util.cc new file mode 100644 index 0000000..404f7e4 --- /dev/null +++ b/parser/tensorflow/tensorflow_fusionop_util.cc @@ -0,0 +1,379 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/tensorflow/tensorflow_fusionop_util.h" +#include "common/util/error_manager/error_manager.h" +#include "common/debug/log.h" +#include "common/op/ge_op_utils.h" +#include "framework/common/debug/ge_log.h" +#include "parser/tensorflow/tensorflow_parser.h" +#include "framework/omg/parser/parser_types.h" + +#include +#include +#include + +using domi::tensorflow::NodeDef; + +namespace ge { +// constraint: At present, only a few fixed fusion operators are supported, +// and forward matching method is used for recognition +// eg: in the MaskRCNN network, +// clip_boxes are treated as fusion operators but generate_rpn_proposals/clip_boxes is also fused +// considered to be a child operator of generate_rpn_proposals. +// clip_boxes +// fastrcnn_predictions +// decode_bbox_target +// generate_rpn_proposals +// roi_align +// cond_1/roi_align +namespace { +const char *const kLstmCellKernelFw = "fw/basic_lstm_cell/kernel"; +const char *const kLstmCellKernelBw = "bw/basic_lstm_cell/kernel"; +const char *const kLstmCellBiasFw = "fw/basic_lstm_cell/bias"; +const char *const kLstmCellBiasBw = "bw/basic_lstm_cell/bias"; +const char *const kAttentionDecoderEmbeeding = "embedding_attention_decoder/embedding"; +const char *const kAttentionDecoderAttenW0 = "embedding_attention_decoder/attention_decoder/AttnW_0"; +const char *const kAttentionDecoderAttenVa = "embedding_attention_decoder/attention_decoder/AttnV_0"; +const char *const kAttentionDecoderAttentionDecoderKernel = "embedding_attention_decoder/attention_decoder/kernel"; +const char *const kAttentionDecoderAtteBias = "embedding_attention_decoder/attention_decoder/bias"; +const char *const kAttentionDecoderCell0GatesKernel = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/gates/kernel"; +const char *const kAttentionDecoderCell0GatesBias = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/gates/bias"; +const char *const kAttentionDecoderCell0CandidateKernel = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/candidate/kernel"; +const char *const kAttentionDecoderCell0CandidateBias = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_0/gru_cell/candidate/bias"; +const char *const kAttentionDecoderCell1GatesKernel = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/gates/kernel"; +const char *const kAttentionDecoderCell1GatesBias = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/gates/bias"; +const char *const kAttentionDecoderCell1CandidateKernel = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/candidate/kernel"; +const char *const kAttentionDecoderCell1CandidateBias = + "embedding_attention_decoder/attention_decoder/multi_rnn_cell/cell_1/gru_cell/candidate/bias"; +const char *const kAttentionDecoderAttention0Kernel = + "embedding_attention_decoder/attention_decoder/Attention_0/kernel"; +const char *const kAttentionDecoderAttention0Bias = "embedding_attention_decoder/attention_decoder/Attention_0/bias"; +const char *const kAttentionDecoderAttnOutputProjectionKernel = + "embedding_attention_decoder/attention_decoder/AttnOutputProjection/kernel"; +const char *const kAttentionDecoderAttnOutputProjectionBias = + "embedding_attention_decoder/attention_decoder/AttnOutputProjection/bias"; +const char *const kHuberLossFill = 
"gradients/Fill"; +const char *const kHuberLossConst = "huber_loss/Const"; +const char *const kHuberLossMul2X = "huber_loss/Mul_2/x"; +const char *const kSparseSoftmaxConst = "sparse_softmax_cross_entropy_loss/Const"; +const char *const kDeeplabV3ConfusionMatrix = "Select"; +const char *const kDeeplabV3ConfusionMatrix1 = "ToFloat_1"; +const char *const kConstantFoldingSuffix = "ConstantFolding/"; +} // namespace +vector const_op_update_vec = {kLstmCellKernelFw, + kLstmCellKernelBw, + kLstmCellBiasFw, + kLstmCellBiasBw, + kAttentionDecoderAttenW0, + kAttentionDecoderAttention0Kernel, + kAttentionDecoderAttnOutputProjectionKernel, + kAttentionDecoderAttentionDecoderKernel, + kAttentionDecoderCell0GatesKernel, + kAttentionDecoderCell0CandidateKernel, + kAttentionDecoderCell1GatesKernel, + kAttentionDecoderCell1CandidateKernel, + kAttentionDecoderAttention0Bias, + kAttentionDecoderAttnOutputProjectionBias, + kAttentionDecoderAtteBias, + kAttentionDecoderCell0GatesBias, + kAttentionDecoderCell0CandidateBias, + kAttentionDecoderCell1GatesBias, + kAttentionDecoderCell1CandidateBias, + kAttentionDecoderEmbeeding, + kAttentionDecoderAttenVa, + kHuberLossFill, + kHuberLossConst, + kHuberLossMul2X, + kSparseSoftmaxConst, + kDeeplabV3ConfusionMatrix, + kDeeplabV3ConfusionMatrix1}; + +static map tensorflow_fusionop_map = { +}; + +// +static map> tensorflow_fusionop_children_nums_map = { + {ge::parser::CLIPBOXES, {8}}, + {ge::parser::FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, + {ge::parser::RPNPROPOSALS, {75, 85, 97}}, + {ge::parser::DECODEBBOX, {24, 28}}, + {ge::parser::ROIALIGN, {82, 83, 84}}, + {ge::parser::FUSIONBATCHNORM, {8}}, + {ge::parser::GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the + {ge::parser::HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, +}; + +// +static map> tensorflow_fusionop_children_names_map = { + {ge::parser::FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, + {ge::parser::GETSPAN, {}}, + {ge::parser::HUBERLOSSGRAD, {}}, +}; + +// ----------------------------Index table of input and output of fusion operator-------------- +// The specific operator is the input and output of the whole fusion operator, and the index number is specified +// Considering that an operator may have multiple inputs / outputs, vector is used to save +// search method: new_index=vector(old_index), +// Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored. +// If it is control edge input, the index is graph::kControlSlot(-1). 
+static map>>> tensorflow_fusionop_inputs_map = { + {ge::parser::FUSIONBATCHNORM, + {{"mul_1", {0, kFusionDisableIndex}}, + {"mul", {1, 1}}, + {"sub", {2, kFusionDisableIndex}}, + {"mul_2", {3, kFusionDisableIndex}}, + {"add", {4, kFusionDisableIndex}}}}, + {ge::parser::GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, + {ge::parser::HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, +}; + +static map>>> tensorflow_fusionop_outputs_map = { + {ge::parser::FUSIONBATCHNORM, {{"add_1", {0}}}}, + {ge::parser::GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, + {ge::parser::HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, +}; +map>> tensorflow_fusionop_input_const_weight_index_map = { + {ge::parser::FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, +}; + +// Can a string be converted to an integer +bool TensorFlowFunsionOPUtil::IsIntegerStr(const string &index_str) { + try { + if (std::stoi(index_str) > 0) { + return true; + } + } catch (std::invalid_argument &) { + GELOGE(FAILED, "index_str:%s is invalid", index_str.c_str()); + } catch (std::out_of_range &) { + GELOGE(FAILED, "index_str:%s is out of range", index_str.c_str()); + } catch (...) { + GELOGE(FAILED, "index_str:%s cannot change to int s", index_str.c_str()); + } + return false; +} + +// Get child node name of fusion operator. +// eg: input: fastrcnn_predictions/map/TensorArray_2 output: map/TensorArray_2 +string TensorFlowFunsionOPUtil::GetChildName(const string &node_name, const string &fusion_node_name) { + GE_CHK_BOOL_EXEC_NOLOG( + (node_name.length() - fusion_node_name.length()) > 0, GELOGW("fusion_node_name length not valid."); return "";); + + string child_name; + string sub_name; + + // node_name begin with "ConstantFolding/" + if (node_name.find(kConstantFoldingSuffix) == 0) { + auto length = strlen(kConstantFoldingSuffix); + sub_name = + node_name.substr(fusion_node_name.length() + length, node_name.length() - fusion_node_name.length() - length); + } else { + sub_name = node_name.substr(fusion_node_name.length(), node_name.length() - fusion_node_name.length()); + } + + auto index = sub_name.find('/'); + if (index != string::npos) { + child_name = sub_name.substr(index + 1, sub_name.length() - index - 1); + } + + return child_name; +} + +// Check whether the operator node name can be a fusion operator +bool TensorFlowFunsionOPUtil::MaybeFusionOp(const string &node_name, ScopeFusionOpInfo *info) { + GE_CHK_BOOL_EXEC(info != nullptr, return false, "info is null."); + info->node_name = node_name; + // Direct forward matching + for (auto iter = tensorflow_fusionop_map.begin(); iter != tensorflow_fusionop_map.end(); ++iter) { + const string fop_name = iter->first; + + string node_name_tmp = node_name; + // begin with "ConstantFolding/" + if (node_name_tmp.find(kConstantFoldingSuffix) == 0) { + auto length = strlen(kConstantFoldingSuffix); + node_name_tmp = node_name.substr(length, node_name.length() - length); + } + + // not match + if (node_name_tmp.find(fop_name) != 0) { + continue; + } + + // match,"FusionName/" scene: + if (node_name_tmp.substr(fop_name.length(), 1) == string("/")) { + info->fusion_node_name = fop_name; + info->fusion_op_type = tensorflow_fusionop_map[fop_name]; + info->description = ""; + info->scope_pass = false; + return true; + } + + // match "FusionName_Index/" scene: + // special characters need unified definition + string sub_name = node_name_tmp.substr(fop_name.length(), node_name_tmp.length() - fop_name.length()); + auto index = 
sub_name.find('/'); + if ((sub_name.substr(0, 1) == string("_")) && (index > 1) && IsIntegerStr(sub_name.substr(1, index - 1))) { + info->fusion_node_name = fop_name + sub_name.substr(0, index); + info->fusion_op_type = tensorflow_fusionop_map[fop_name]; + info->description = ""; + info->scope_pass = false; + return true; + } + } + + return false; +} + +// Confirm whether it is a fusion operator +bool TensorFlowFunsionOPUtil::IsFusionOp(const domi::tensorflow::NodeDef *node_def) { + GE_CHK_BOOL_EXEC(node_def != nullptr, return false, "node_def is null."); + string type = node_def->op(); + auto iter = tensorflow_fusionop_children_nums_map.find(type); + return iter != tensorflow_fusionop_children_nums_map.end(); +} + +// Check the validity of fusion operator (all child nodes) +Status TensorFlowFunsionOPUtil::CheckFusionOpChildren(const string &fusion_node_name, + const vector &nodedef_list, + const string &funsion_op_type) { + // Number matching of fusion operators + auto iter_children_nums = tensorflow_fusionop_children_nums_map.find(funsion_op_type); + if (iter_children_nums == tensorflow_fusionop_children_nums_map.end()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E12018", {"opname", "optype"}, {fusion_node_name, funsion_op_type}); + GELOGE(domi::INTERNAL_ERROR, + "Op[%s]'s optype[%s] not a Fusion OP!", fusion_node_name.c_str(), funsion_op_type.c_str()); + return domi::INTERNAL_ERROR; + } + + vector children_nums = iter_children_nums->second; + bool find = false; + int children_num = nodedef_list.size(); + for (uint32_t i = 0; i < children_nums.size(); i++) { + if (children_nums[i] == children_num) { + find = true; + break; + } + } + + if (!find) { + ErrorManager::GetInstance().ATCReportErrMessage("E12019", + {"opname", "optype", "childrennum"}, {fusion_node_name, funsion_op_type, std::to_string(children_num)}); + GELOGE(domi::INTERNAL_ERROR, + "Op[%s]'s optype[%s] children_nums:%d is not the same for define.", + fusion_node_name.c_str(), + funsion_op_type.c_str(), + children_num); + return domi::INTERNAL_ERROR; + } + + // Key children operators matching + auto iter_children_names = tensorflow_fusionop_children_names_map.find(funsion_op_type); + if (iter_children_names != tensorflow_fusionop_children_names_map.end()) { + vector children_names = iter_children_names->second; + if (!children_names.empty()) { + uint32_t count = 0; + for (uint32_t i = 0; i < children_names.size(); i++) { + for (uint32_t j = 0; j < nodedef_list.size(); j++) { + const domi::tensorflow::NodeDef *node_def = nodedef_list[j]; + GE_CHECK_NOTNULL(node_def); + string node_name = node_def->name(); + string child_name = GetChildName(node_name, fusion_node_name); + if (children_names[i] == child_name) { + count++; + break; + } + } + } + + GE_IF_BOOL_EXEC(count != children_names.size(), + ErrorManager::GetInstance().ATCReportErrMessage( + "E12020", {"opname", "optype"}, {fusion_node_name, funsion_op_type}); + GELOGE(domi::INTERNAL_ERROR, "Op[%s]'s optype[%s] has no enough importance child.", fusion_node_name.c_str(), + funsion_op_type.c_str()); + return domi::INTERNAL_ERROR;); + } + } + + return SUCCESS; +} + +// Get the child node of the fusion operator as the input / output index number of the whole fusion operator +Status TensorFlowFunsionOPUtil::GetNodeindex( + const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index, + const map>>> &fusionop_context_map) { + auto iter = fusionop_context_map.find(info.fusion_op_type); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == fusionop_context_map.end(), + 
return domi::INTERNAL_ERROR, + "Op[%s] could not find item of optype[%s] in fusionop_context_map", + info.node_name.c_str(), info.fusion_op_type.c_str()); + + vector>> pairs = iter->second; + + string child_name = GetChildName(info.node_name, info.fusion_node_name); + + GELOGI("GetNodeindex: info.node_name:%s, old_index:%d", info.node_name.c_str(), old_index); + for (const auto &pair : pairs) { + if (pair.first == child_name) { + vector indexs = pair.second; + if (static_cast(indexs.size()) < (old_index + 1)) { + new_index = kFusionDisableIndex; + return SUCCESS; + } + + if (old_index != -1) { + new_index = indexs[old_index]; + return SUCCESS; + } + } + } + + new_index = kFusionDisableIndex; + return SUCCESS; +} + +// Get the input index of the fusion operator +Status TensorFlowFunsionOPUtil::GetInPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, + int32_t &new_index) { + return GetNodeindex(info, old_index, new_index, tensorflow_fusionop_inputs_map); +} + +// Get the output index of the fusion operator +Status TensorFlowFunsionOPUtil::GetOutPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, + int32_t &new_index) { + return GetNodeindex(info, old_index, new_index, tensorflow_fusionop_outputs_map); +} + +bool TensorFlowFunsionOPUtil::FusionOpChildIgnore(const ScopeFusionOpInfo &info) { + // If the small operator is not in the input and output index table of the fusion operator, + // it is unnecessary to establish the edge relationship and can be ignored + int32_t old_index = 0; + int32_t in_new_index = 0; + int32_t out_new_index = 0; + GE_CHK_STATUS(GetInPutIndex(info, old_index, in_new_index), "GetInPutIndex failed"); + GE_CHK_STATUS(GetOutPutIndex(info, old_index, out_new_index), "GetOutPutIndex failed"); + + return (in_new_index == kFusionDisableIndex) && (out_new_index == kFusionDisableIndex); +} +} // namespace ge diff --git a/parser/tensorflow/tensorflow_fusionop_util.h b/parser/tensorflow/tensorflow_fusionop_util.h new file mode 100644 index 0000000..f08ccf9 --- /dev/null +++ b/parser/tensorflow/tensorflow_fusionop_util.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
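[Editor's note] To make the name-matching rules above concrete, two expectations for GetChildName with invented scope names: the optional "ConstantFolding/" prefix is skipped before the fusion scope name is removed. Note also that tensorflow_fusionop_map is checked in empty above, so MaybeFusionOp matches nothing until entries are added to that table.

#include <cassert>
#include "parser/tensorflow/tensorflow_fusionop_util.h"

void GetChildNameExamples() {
  assert(ge::TensorFlowFunsionOPUtil::GetChildName(
             "my_scope/map/TensorArray_2", "my_scope") == "map/TensorArray_2");
  assert(ge::TensorFlowFunsionOPUtil::GetChildName(
             "ConstantFolding/my_scope/strided_slice", "my_scope") == "strided_slice");
}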
+ */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ +#include +#include +#include +#include +#include "common/debug/log.h" +#include "common/string_util.h" +#include "framework/omg/parser/parser_types.h" +#include "common/util.h" +#include "omg/omg_inner_types.h" +#include "proto/tensorflow/graph.pb.h" +#include "external/register/scope/scope_fusion_pass_register.h" +#include "register/scope/scope_graph_impl.h" + +namespace ge { +using std::string; +using std::vector; +extern map>> tensorflow_fusionop_input_const_weight_index_map; +extern vector const_op_update_vec; + +class TensorFlowFunsionOPUtil { + public: + /** + * @ingroup domi_omg + * @brief Check whether the operator can be a fusion operator + * @param [in] node_name operation name + * @return info fusion operator description + * @return true maybe + * @return false maybe not + * @author + */ + static bool MaybeFusionOp(const string &node_name, ScopeFusionOpInfo *info); + + /** + * @ingroup domi_omg + * @brief Confirm whether it is a fusion operator + * @param [in] nodeDef + * @return true + * @return false + * @author + */ + static bool IsFusionOp(const domi::tensorflow::NodeDef *node_def); + + /** + * @ingroup domi_omg + * @brief Check the validity of fusion operator(All child nodes) + * @param [in] fusion_node_name fusion operator name + * @param [in] nodedef_list child nodes list + * @param [in] funsion_op_type fusion operator type + * @return legal/illegal + * @author + */ + static Status CheckFusionOpChildren(const string &fusion_node_name, + const vector &nodedef_list, + const string &funsion_op_type); + + /** + * @ingroup domi_omg + * @brief get inPut index of the fusion operator + * @param [in] info Child node description of fusion operator + * @param [in] old_index Child node original index + * @return old_index As input index of the fusion operator + * @return return code + * @author + */ + static Status GetInPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index); + + /** + * @ingroup domi_omg + * @brief get outPut index of the fusion operator + * @param [in] info Child node description of fusion operator + * @param [in] old_index Child node original index + * @return old_index As output index of the fusion operator + * @return 返回码 + * @author + */ + static Status GetOutPutIndex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index); + + static bool FusionOpChildIgnore(const ScopeFusionOpInfo &info); + /** + * @ingroup domi_omg + * @brief Get child node name of fusion operator eg: input: fastrcnn_predictions/map/TensorArray_2 output + * :map/TensorArray_2 + * @param [in] node_name node name + * @param [in] fusion_node_name fusion node name + * @return Child node name of the fusion node + * @author + */ + static string GetChildName(const string &node_name, const string &fusion_node_name); + + private: + /** + * @ingroup domi_omg + * @brief whether a string can be converted to an integer + * @param [in] indexstr Operator suffix index + * @return true can + * @return false can not + * @author + */ + static bool IsIntegerStr(const string &index_str); + + /** + * @ingroup domi_omg + * @brief Get child node of fusion operator + * @param [in] info Description of fusion operator + * @param [in] old_index original index + * @return new_index Fusion operator index + * @author + */ + static Status GetNodeindex(const ScopeFusionOpInfo &info, const int32_t old_index, int32_t &new_index, + const std::map>>> 
&fusionop_context_map); +}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_FUSIONOP_UTIL_H_ diff --git a/parser/tensorflow/tensorflow_identity_parser.cc b/parser/tensorflow/tensorflow_identity_parser.cc new file mode 100644 index 0000000..50f6277 --- /dev/null +++ b/parser/tensorflow/tensorflow_identity_parser.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/op/ge_op_utils.h" +#include "common/op_def/ir_pb_converter.h" +#include "parser/common/op_parser_factory.h" +#include "framework/omg/parser/parser_types.h" + +#include "parser/tensorflow/tensorflow_identity_parser.h" + +using domi::TENSORFLOW; +using ge::parser::IDENTITY; +using ge::parser::READVARIABLEOP; + +namespace ge { +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowIdentityParser); +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, READVARIABLEOP, TensorFlowIdentityParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_identity_parser.h b/parser/tensorflow/tensorflow_identity_parser.h new file mode 100644 index 0000000..0b4a342 --- /dev/null +++ b/parser/tensorflow/tensorflow_identity_parser.h @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_TENSORFLOW_TENSORFLOW_IDENTITY_H_ +#define GE_PARSER_TENSORFLOW_TENSORFLOW_IDENTITY_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowIdentityParser : public TensorFlowOpParser {}; +} // namespace ge + +#endif // GE_PARSER_TENSORFLOW_TENSORFLOW_IDENTITY_H_ diff --git a/parser/tensorflow/tensorflow_merge_parser.cc b/parser/tensorflow/tensorflow_merge_parser.cc new file mode 100644 index 0000000..6f1dedb --- /dev/null +++ b/parser/tensorflow/tensorflow_merge_parser.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "parser/tensorflow/tensorflow_merge_parser.h"
+
+#include "framework/common/debug/ge_log.h"
+#include "framework/common/util.h"
+#include "graph/debug/ge_attr_define.h"
+#include "parser/common/op_parser_factory.h"
+#include "framework/omg/parser/parser_types.h"
+
+using domi::TENSORFLOW;
+using ge::parser::MERGE;
+
+namespace ge {
+Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) {
+  GE_CHECK_NOTNULL(op_src);
+  GE_CHECK_NOTNULL(op_desc);
+
+  const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
+  domi::tensorflow::AttrValue attr_num;
+  if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) {
+    GELOGW("In NodeDef %s, dynamic attr [%s] does not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str());
+  }
+  // If the attr is missing, attr_num is default-constructed and i() returns 0, so no dynamic input is added.
+  int32_t input_tensor_num = attr_num.i();
+
+  // add dynamic input
+  graphStatus ret = op_desc->AddDynamicInputDesc("x", input_tensor_num);
+  if (ret != GRAPH_SUCCESS) {
+    GELOGE(FAILED, "Add dynamic input:x for node:%s failed.", op_desc->GetName().c_str());
+    return FAILED;
+  }
+  GELOGI("add dynamic input for Merge op [%s], num:%d", op_desc->GetName().c_str(), input_tensor_num);
+
+  return SUCCESS;
+}
+
+REGISTER_OP_PARSER_CREATOR(TENSORFLOW, MERGE, TensorFlowMergeParser);
+}  // namespace ge
diff --git a/parser/tensorflow/tensorflow_merge_parser.h b/parser/tensorflow/tensorflow_merge_parser.h
new file mode 100644
index 0000000..e5f9a35
--- /dev/null
+++ b/parser/tensorflow/tensorflow_merge_parser.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
+#define _DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
+
+#include "parser/tensorflow/tensorflow_op_parser.h"
+
+namespace ge {
+class TensorFlowMergeParser : public TensorFlowOpParser {
+ public:
+  Status ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) override;
+};
+} // namespace ge
+#endif //_DOMI_OMG_PARSER_TENSORFLOW_TENSORFLOW_MERGE_PARSER_H_
diff --git a/parser/tensorflow/tensorflow_no_op_parser.cc b/parser/tensorflow/tensorflow_no_op_parser.cc
new file mode 100644
index 0000000..633e921
--- /dev/null
+++ b/parser/tensorflow/tensorflow_no_op_parser.cc
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "parser/tensorflow/tensorflow_no_op_parser.h" +#include "framework/common/util.h" +#include "framework/common/debug/ge_log.h" +#include "parser/common/op_def/ir_pb_converter.h" +#include "parser/common/op_def/no_op_op.h" +#include "parser/common/op_parser_factory.h" + +using domi::TENSORFLOW; +using namespace ge::parser; + +namespace ge { +Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + const NodeDef *node = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + NoOpOperator op; + op.Name(node->name()); + + return ConvertToOpDesc(op, op_dest); +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, NOOP, TensorFlowNoOpParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_no_op_parser.h b/parser/tensorflow/tensorflow_no_op_parser.h new file mode 100644 index 0000000..de78fbc --- /dev/null +++ b/parser/tensorflow/tensorflow_no_op_parser.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_TENSORFLOW_TENSORFLOW_NO_OP_PARSER_H_ +#define PARSER_TENSORFLOW_TENSORFLOW_NO_OP_PARSER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowNoOpParser : public TensorFlowOpParser { + // AUTO GEN PLEASE DO NOT MODIFY IT + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; +}; +} // namespace ge + +#endif // PARSER_TENSORFLOW_TENSORFLOW_NO_OP_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_op_parser.h b/parser/tensorflow/tensorflow_op_parser.h new file mode 100644 index 0000000..983597b --- /dev/null +++ b/parser/tensorflow/tensorflow_op_parser.h @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef OMG_PARSER_TENSORFLOW_TENSORFLOW_OP_PARSER_H_
+#define OMG_PARSER_TENSORFLOW_TENSORFLOW_OP_PARSER_H_
+
+#include
+#include
+#include "framework/common/op/attr_value_util.h"
+#include "framework/omg/parser/op_parser.h"
+#include "graph/ge_tensor.h"
+#include "graph/node.h"
+#include "register/tensor_assign.h"
+#include "parser/tensorflow/tensorflow_util.h"
+#include "proto/tensorflow/graph.pb.h"
+#include "proto/tensorflow/node_def.pb.h"
+
+using domi::tensorflow::NodeDef;
+using domi::tensorflow::TensorProto;
+using google::protobuf::int32;
+using google::protobuf::int64;
+using google::protobuf::Message;
+using std::string;
+using std::vector;
+using Status = domi::Status;
+using domi::tensorflow::AttrValue;
+using domi::tensorflow::DataType;
+using domi::tensorflow::DT_BOOL;
+using domi::tensorflow::DT_FLOAT;
+using domi::tensorflow::DT_INT32;
+using domi::tensorflow::DT_INT64;
+using domi::tensorflow::DT_INVALID;
+using domi::tensorflow::TensorShapeProto;
+using domi::tensorflow::TensorShapeProto_Dim;
+
+namespace ge {
+/**
+ * @ingroup domi_omg
+ * @brief Used to parse TensorFlow operator information.
+ *        The default implementations below are no-ops; concrete parsers override the variants they need.
+ */
+class TensorFlowOpParser : public OpParser {
+ public:
+  /**
+   * @ingroup domi_omg
+   * @brief parse params
+   * @param [in] op_src op to be parsed
+   * @param [out] op_dest the parsed op
+   * @return SUCCESS parse success
+   * @return FAILED parse failed
+   *
+   */
+  Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override {
+    return domi::SUCCESS;
+  }
+
+  /**
+   * @ingroup domi_omg
+   * @brief parse params
+   * @param [in] op_src op to be parsed
+   * @param [out] op_dest the parsed operator
+   * @return SUCCESS parse success
+   * @return FAILED parse failed
+   *
+   */
+  Status ParseParams(const Message *op_src, ge::Operator &op_dest) override {
+    return domi::SUCCESS;
+  }
+
+  /**
+   * @ingroup domi_omg
+   * @brief parse weights
+   * @param [in] op_src op to be parsed
+   * @param [out] node the parsed node
+   * @return SUCCESS parsing success
+   * @return FAILED parsing failed
+   *
+   */
+  Status ParseWeights(const Message *op_src, ge::NodePtr &node) final {
+    return domi::SUCCESS;
+  }
+};
+} // namespace ge
+
+#endif // OMG_PARSER_TENSORFLOW_TENSORFLOW_OP_PARSER_H_
diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc
new file mode 100644
index 0000000..8c965e0
--- /dev/null
+++ b/parser/tensorflow/tensorflow_parser.cc
@@ -0,0 +1,3722 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "parser/tensorflow/tensorflow_parser.h" +#include +#include +#include "parser/common/convert/pb2json.h" +#include "common/debug/log.h" +#include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" +#include "external/graph/operator_factory.h" +#include "external/parser/tensorflow_parser.h" +#include "external/register/scope/scope_fusion_pass_register.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_api.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/optimize/common/params.h" +#include "graph/passes/variable_format_pass.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" +#include "iterator_fusion_pass.h" +#include "omg/omg.h" +#include "omg/parser/op_parser.h" +#include "omg/parser/parser_factory.h" +#include "parser/common/acl_graph_parser_util.h" +#include "parser/common/model_saver.h" +#include "parser/common/op_map.h" +#include "parser/common/op_parser_factory.h" +#include "parser/common/parser_fp16_t.h" +#include "parser/common/pass_manager.h" +#include "parser/common/pre_checker.h" +#include "parser/common/thread_pool.h" +#include "parser/tensorflow/tensorflow_custom_parser_adapter.h" +#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h" +#include "parser/tensorflow/tensorflow_fusion_op_parser.h" +#include "parser/tensorflow/tensorflow_fusionop_util.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "parser/tensorflow/tensorflow_util.h" +#include "register/op_registry.h" +#include "register/scope/scope_graph_impl.h" +#include "register/scope/scope_pass_registry_impl.h" + +using ge::const_op_update_vec; +using ge::OpParserFactory; +using ge::Pb2Json; +using ge::PreChecker; +using ge::TENSORFLOW_ATTR_DATA_FORMAT; +using ge::TENSORFLOW_ATTR_DTYPE; +using ge::TENSORFLOW_ATTR_SHAPE; +using ge::TENSORFLOW_ATTR_T; +using ge::TENSORFLOW_ATTR_TYPE_STRING; +using ge::TENSORFLOW_ATTR_TYPE_TENSOR; +using ge::TENSORFLOW_ATTR_TYPE_TYPE; +using ge::TENSORFLOW_ATTR_VALUE; +using ge::TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG; +using ge::TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG; +using ge::tensorflow_op_map; +using ge::tensorflow_train_op_map; +using ge::TENSORFLOWF_NODE_OP_CONST; +using ge::TENSORFLOWF_NODE_OP_IDENTITY; +using ge::TENSORFLOWF_NODE_OP_MERGE; +using ge::TENSORFLOWF_NODE_OP_PLACEHOLDER; +using ge::TENSORFLOWF_NODE_OP_SWITCH; +using ge::TENSORFLOWF_NODE_OP_TRANSPOSE; +using ge::TENSORFLOWF_TENSOR_NCHW; +using ge::TENSORFLOWF_TENSOR_NHWC; +using ge::TensorFlowFunsionOPUtil; +using ge::TensorFlowFusionCustomParserAdapter; +using ge::TensorFlowFusionOpParser; +using ge::TensorFlowOpParser; +using ge::ThreadPool; +using ge::parser::fp16_t; +using ge::parser::ModelSaver; + +namespace ge { +graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_file); + GetParserContext().type = domi::TENSORFLOW; + std::map options; + options.insert(std::pair(string(ge::FRAMEWORK_TYPE), to_string(ge::TENSORFLOW))); + + // load custom plugin so and proto + AclGrphParseUtil acl_graph_parse_util; + (void)acl_graph_parse_util.AclParserInitialize(options); + + // Create an empty computegraph + ge::ComputeGraphPtr compute_graph = ge::MakeShared("tmpGraph"); + GE_CHECK_NOTNULL(compute_graph); + + graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); + auto model_parser = 
domi::ModelParserFactory::Instance()->CreateModelParser(domi::TENSORFLOW); + GE_CHECK_NOTNULL(model_parser); + + // parse tensorflow model_file to GE graph + ge::graphStatus ret = model_parser->Parse(model_file, graph); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ge::FAILED; + } + + if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { + GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); + return ge::FAILED; + } + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + return ge::SUCCESS; +} +} // namespace ge + +namespace ge { +namespace { +const int kTransposeInputIdx = 0; +const uint32_t kThreadNum = 16; +const size_t kInputNumUint = 2; +const int kInputNumInt = 2; +const int32_t kControlSlot = -1; +const size_t kSoftmaxMultiple = 2; +const set kTfBlackFields = {"tensor_content"}; +const std::vector kSkipCheckoutInputSizeNodes = {ge::parser::DATA, ge::parser::VARIABLE, + ge::parser::FRAMEWORKOP, ge::parser::LAYERNORM}; +const std::vector kMakeOperatorNotByIr = {ge::parser::ARG, ge::parser::VARIABLE, ge::parser::VARHANDLEOP, + ge::parser::FRAMEWORKOP, ge::parser::DATA}; +const std::map kNeedMarkFormatNodes = { + {"ExtractImagePatches", domi::DOMI_TENSOR_NHWC}, + {"ExtractVolumePatches", domi::DOMI_TENSOR_NHWC}, + {"LogSoftmax", domi::DOMI_TENSOR_NHWC}, + {"ResizeBilinear", domi::DOMI_TENSOR_NHWC}, + {"ResizeBilinearGrad", domi::DOMI_TENSOR_NHWC}, + {"ResizeNearestNeighbor", domi::DOMI_TENSOR_NHWC}, + {"Softmax", domi::DOMI_TENSOR_NHWC}, + {"SoftmaxCrossEntropyWithLogits", domi::DOMI_TENSOR_NHWC}, + {"SoftmaxGrad", domi::DOMI_TENSOR_NHWC}, + {"SpaceToBatch", domi::DOMI_TENSOR_NHWC}}; +const char *const kDpop = "DPOP"; +const char *const kFuncDefLibraryFilePath = "graph_def_library.pbtxt"; +const char *const kAttrNameIsScopeInnerNode = "_is_scope_inner_node"; +struct ParseArg { + const google::protobuf::Message *proto; + std::string function_name; + ge::NodePtr parent_node; + std::string subgraph_name; + ge::ComputeGraphPtr graph; +}; + +Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque &args) { + GELOGI("Gen subgraph parse tasks start"); + for (auto &node : parent_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { + auto i = subgraph_name_to_index.second; + auto subgraph_iname = op_desc->GetSubgraphInstanceName(i); + if (subgraph_iname.empty()) { + GELOGE(PARAM_INVALID, "The subgraph index %u of node %s is empty", i, node->GetName().c_str()); + return PARAM_INVALID; + } + + // A function may be referenced multiple times in TF, change the graph name to ensure it is unique in GE + auto unique_name = node->GetName() + std::to_string(i) + subgraph_iname; + auto subgraph = ge::MakeShared(unique_name); + if (subgraph == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to alloc subgraph %s", subgraph_iname.c_str()); + return OUT_OF_MEMORY; + } + auto ret = ge::NodeUtils::SetSubgraph(*node, i, subgraph); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to set subgraph %s to node %s index %u", subgraph_iname.c_str(), node->GetName().c_str(), + i); + return ret; + } + + GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s", + node->GetName().c_str(), i, subgraph_iname.c_str()); + args.push_back({nullptr, subgraph_iname, node, subgraph_name_to_index.first, subgraph}); + } + } + GELOGI("Gen subgraph parse tasks 
end"); + return SUCCESS; +} + +Status PostOpProcessForSubgraph(const ParseArg &arg) { + if (arg.parent_node == nullptr) { + return SUCCESS; + } + + auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(arg.parent_node->GetType()); + if (post_func == nullptr) { + GELOGW("The subgraph post func for node %s type %s is null", arg.parent_node->GetName().c_str(), + arg.parent_node->GetType().c_str()); + return SUCCESS; + } + + GELOGD("Post process for subgraph %s node %s type %s subgraph name %s", arg.function_name.c_str(), + arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str()); + + // refresh node_name in subgraph + for (const ge::NodePtr &node : arg.graph->GetDirectNode()) { + if ((node->GetOpDesc() == nullptr) || (node->GetType() == "Variable") || (node->GetType() == "VariableV2")) { + continue; + } + node->GetOpDesc()->SetName(node->GetOwnerComputeGraph()->GetName() + "/" + node->GetName()); + } + + auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(arg.graph); + auto ret = post_func(arg.subgraph_name, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s subgraph name %s", arg.function_name.c_str(), + arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str()); + return FAILED; + } + + return SUCCESS; +} +} // namespace + +/** + * @ingroup domi_omg + * @brief Trans common decorate function to PartitionedCall. + * @param [in] node_def: Node of common function. + * @param [out] op: result of PartitionedCall OpDesc. + * @return 0: SUCCESS / Others: FAILED + */ +Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op) { + const string op_name = node_def->name(); + domi::tensorflow::AttrValue attr_call_inference; + if (!ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {node_def->name(), "attr [_disable_call_shape_inference]", "is not exist in nodedef"}); + GELOGE(FAILED, "In NodeDef %s attr [_disable_call_shape_inference] not exist.", op_name.c_str()); + return FAILED; + } + + op = ge::MakeShared(op_name, ge::parser::PARTITIONEDCALL); + GE_CHECK_NOTNULL(op); + + size_t input_tensor_num = 0; + size_t output_tensor_num = 0; + GetInputOutputTensorNum(op, input_tensor_num, output_tensor_num); + + for (size_t i = 0; i < input_tensor_num; ++i) { + ge::GeTensorDesc input_tensor; + if (op->AddInputDesc(input_tensor) != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "op [%s] type[%s] add input(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); + return FAILED; + } + } + + for (size_t i = 0; i < output_tensor_num; ++i) { + ge::GeTensorDesc output_tensor; + if (op->AddOutputDesc(output_tensor) != ge::GRAPH_SUCCESS) { + GELOGE(FAILED, "op [%s] type[%s] add output(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); + return FAILED; + } + } + + GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] have input size: %zu, output size: %zu, disable inference: %d", + op_name.c_str(), op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize(), attr_call_inference.b()); + + (void)op->AddSubgraphName("f"); + (void)op->SetSubgraphInstanceName(0, op_name); + return SUCCESS; +} + +Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, + const string &op_type) { + 
GE_CHECK_NOTNULL(node_def); + string node_name = node_def->name(); + ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name, op_type); + if (op_factory.GetName() != node_name || op_type == ge::parser::DATA) { + if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) { + op = ge::MakeShared(node_name, op_type); + GE_CHECK_NOTNULL(op); + } else if (node_name == op_type) { + // Trans @tensorflow.python.framework.Defun(...) to PartitionedCall. + GE_RETURN_IF_ERROR(DefunToPartitionedCall(node_def, op)); + GE_CHECK_NOTNULL(op); + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E12011", {"opname", "optype"}, {node_name, op_type}); + GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + return FAILED; + } + } else { + op = ge::OpDescUtils::GetOpDescFromOperator(op_factory); + GE_CHECK_NOTNULL(op); + GELOGI("After GetOpDescFromOperator op[%s]: type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + + GE_RETURN_IF_ERROR(AddTensorDescToOpDesc(op, node_def)); + GELOGI("After AddTensorDescToOpDesc op[%s]: type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + } + op_factory.BreakConnect(); + return SUCCESS; +} + +Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, + shared_ptr &op_parser) { + GE_CHECK_NOTNULL(node_def); + GE_CHECK_NOTNULL(op); + GE_CHECK_NOTNULL(op_parser); + + string node_name = node_def->name(); + string node_op = node_def->op(); + + Status status = FAILED; + domi::ParseParamByOpFunc parse_param_by_op_fn = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(node_op); + if (parse_param_by_op_fn == nullptr) { + shared_ptr tensorflow_op_parser = std::dynamic_pointer_cast(op_parser); + GE_CHECK_NOTNULL(tensorflow_op_parser); + status = tensorflow_op_parser->ParseParams(node_def, op); + if (status != SUCCESS) { + GELOGE(status, "Parse params for node[%s] failed", node_name.c_str()); + return status; + } + } else { + ge::Operator op_src(node_def->name(), node_def->op()); + status = domi::AutoMappingFn(node_def, op_src); + if (status != SUCCESS) { + GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str()); + return status; + } + std::shared_ptr tf_custom_op_parser = + std::dynamic_pointer_cast(op_parser); + GE_CHECK_NOTNULL(tf_custom_op_parser); + status = tf_custom_op_parser->ParseParams(op_src, op); + if (status != SUCCESS) { + GELOGE(status, "Parse params for node[%s] failed", op_src.GetName().c_str()); + return status; + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, ge::ComputeGraphPtr &graph, + shared_ptr &scope_graph) { + GE_CHECK_NOTNULL(node_def); + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(scope_graph); + domi::tensorflow::AttrValue attr_value; + if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) { + std::mutex graph_mutex; + return AddScopeInnerNode(this, graph, &graph_mutex, node_def); + } + // node is released in destructor + string node_name = node_def->name(); + string node_op = node_def->op(); + auto type_it = tensorflow_op_map.find(node_op); + if (type_it == tensorflow_op_map.end()) { + GELOGI("Can not find,maybe this node has no plugin node_name is %s, node_op is %s ", 
node_name.c_str(), + node_op.c_str()); + ge::OpDescPtr op_desc; + GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op_desc, node_op)); + + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); + GE_CHK_STATUS(domi::AutoMappingFn(node_def, op)); + op.BreakConnect(); + + ge::NodePtr node = nullptr; + node = graph->AddNode(op_desc); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); + + node_map_[node_name] = node; + + return SUCCESS; + } + + string op_type = type_it->second; + + // The type value is obtained from the definition map set of DaVinci. + ge::OpDescPtr op; + GE_RETURN_IF_ERROR(TransNodeToOpDesc(node_def, op, op_type)); + + bool needFusion = IsFusionOp(scope_graph, node_def); + // The number of inputs and outputs of each operator can be determined after the new IR design model is resolved. + // Add tensordesc to the opdesc object of the operator + // Process change of tensordesc initialization of opdesc, + // Previous process: Tensordesc is constructed according to graph structure in builder stage + // Current process: Tensordesc is determined before the opdesc of the operator is added to the graph + Status status = FAILED; + // create OpParser + shared_ptr factory = OpParserFactory::Instance(domi::TENSORFLOW); + GE_CHECK_NOTNULL(factory); + if (!needFusion) { + shared_ptr op_parser = factory->CreateOpParser(op_type); + // parse op param + status = ParseOpParams(node_def, op, op_parser); + if (status != SUCCESS) { + GELOGE(status, "Parse params for node[%s] failed", node_name.c_str()); + return status; + } + } + GELOGI("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + // checkout op input number with IR + GE_RETURN_IF_ERROR(CheckoutInputNum(op, node_def)); + ge::NodePtr node = graph->AddNode(op); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), DeleteFuisonNodeDef(); return INTERNAL_ERROR, "add node failed."); + + node_map_[node_name] = node; + + if (needFusion) { + shared_ptr fusion_op_parser = factory->CreateFusionOpParser(op_type); + GE_CHECK_NOTNULL(fusion_op_parser); + // Find all children of the fusion operator + auto iter = fusion_op_nodedef_map_.find(node_def->name()); + if (iter == fusion_op_nodedef_map_.end()) { + GELOGE(FAILED, "FusionOp node %s has no children node!", node_name.c_str()); + return INTERNAL_ERROR; + } + vector node_def_v = iter->second; + // parse fusion node param + status = FusionNodeParseParams(fusion_op_parser, node_def, node); + if (status != SUCCESS) { + GELOGE(status, "Parse params for fusion node[%s] failed", node_name.c_str()); + return status; + } + // record original op names + std::vector namesTmp; + for (auto &node_def_iter : node_def_v) { + GE_CHECK_NOTNULL(node_def_iter); + std::string nodeName = node_def_iter->name(); + namesTmp.push_back(nodeName); + } + + ge::GraphUtils::RecordOriginalNames(namesTmp, node); + status = RecordFusionResult(scope_graph, node_def, op); + if (status != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str()); + DeleteFuisonNodeDef(); + return status; + } + } + return SUCCESS; +} + +void TensorFlowModelParser::GetInputOutputTensorNum(ge::OpDescPtr &op_desc, size_t &input_tensor_num, + size_t &output_tensor_num) const { + // The caller guarantees that the pointer is not null + auto iter = op_node_context_map_.find(op_desc->GetName()); + if (iter == 
op_node_context_map_.end()) { + return; + } + const OpNodeContext &op_context = iter->second; + const std::map>> &dest_input_map = op_context.input_map; + // input number + input_tensor_num = 0; + for (auto &input_vec : dest_input_map) { + for (auto &input_v : input_vec.second) { + if (input_v.second != kControlSlot) { + input_tensor_num++; + } + } + } + + // output number + const std::map>> &src_output_map = op_context.output_map; + int32_t max_anchor_index = 0; + for (auto &src_output_iter : src_output_map) { + for (auto &index_output_iter : src_output_iter.second) { + if (index_output_iter.first > max_anchor_index) { + max_anchor_index = index_output_iter.first; + } + } + } + output_tensor_num = max_anchor_index + 1; + + return; +} + +Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(op_desc); + + if (std::find(kSkipCheckoutInputSizeNodes.begin(), kSkipCheckoutInputSizeNodes.end(), op_desc->GetType()) != + kSkipCheckoutInputSizeNodes.end()) { + return SUCCESS; + } + + // get input and output tensor number + size_t input_tensor_num = 0; + size_t output_tensor_num = 0; + GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num); + + // get input and output tensor number from op desc + size_t factory_input_size = op_desc->GetInputsSize(); + if (input_tensor_num != factory_input_size) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {op_desc->GetName(), "input number of tensorflow[" + std::to_string(input_tensor_num) + "]", + "should be equal to factory size[" + std::to_string(factory_input_size) + "]"}); + GELOGE(FAILED, "op [%s], type[%s], The input number of tensorflow[%zu] should be equal to factory size[%zu]", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), input_tensor_num, factory_input_size); + return FAILED; + } + return SUCCESS; +} + +void TensorFlowModelParser::UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector &input_desc, + const size_t input_tensor_num) { + // The caller guarantees that the pointer is not null + for (size_t i = 0; i < input_tensor_num; ++i) { + if (i < input_desc.size()) { + // i is guaranteed to be valid, no check required. + ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_desc[i]); + if (ret != ge::GRAPH_SUCCESS) { + // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser. + GELOGI("op [%s], type[%s], update input(%zu) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str()); + } + } else { + ge::GeTensorDesc input_tensor; + // i is guaranteed to be valid, no check required. + ge::graphStatus ret = op_desc->UpdateInputDesc(op_desc->GetInputNameByIndex(i), input_tensor); + if (ret != ge::GRAPH_SUCCESS) { + // UpdateInputDesc for dynamic intput will be failed, but it will be added in later op parser. + GELOGI("op [%s], type[%s], update input(%zu) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str()); + } + } + } +} + +void TensorFlowModelParser::UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector &output_desc, + size_t output_tensor_num) { + // The caller guarantees that the pointer is not null + for (size_t i = 0; i < output_tensor_num; ++i) { + if (i < output_desc.size()) { + // i is guaranteed to be valid, no check required. 
+ ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_desc[i]); + if (ret != ge::GRAPH_SUCCESS) { + // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser. + GELOGI("op [%s], type[%s], update output(%zu) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str()); + } + } else { + ge::GeTensorDesc output_tensor; + // i is guaranteed to be valid, no check required. + ge::graphStatus ret = op_desc->UpdateOutputDesc(op_desc->GetOutputNameByIndex(i), output_tensor); + if (ret != ge::GRAPH_SUCCESS) { + // UpdateOutputDesc for dynamic output will be failed, but it will be added in later op parser. + GELOGI("op [%s], type[%s], update output(%zu) with name %s failed", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i, op_desc->GetInputNameByIndex(i).c_str()); + } + } + } +} + +Status TensorFlowModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(op_desc); + // get input and output attr from tensorflow + const string type = node->op(); + domi::tensorflow::AttrValue input_attr_value; + domi::tensorflow::AttrValue output_attr_value; + ParserOperator temp_op; + if (ge::TensorFlowUtil::FindAttrValue(node, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { + GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(input_attr_value, &temp_op, + TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG, type), + "trans input_attr_value failed, op: %s", node->name().c_str()); + } else { + GELOGD("Frameworkop has no input tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str()); + } + if (ge::TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { + GE_CHK_STATUS_RET(ge::TensorFlowUtil::TransTensorDescriptor(output_attr_value, &temp_op, + TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG, type), + "trans output_attr_value failed, op: %s", node->name().c_str()); + } else { + GELOGD("Frameworkop has no output tensor desc, name:%s, type:%s.", node->name().c_str(), type.c_str()); + } + + auto iter = op_node_context_map_.find(op_desc->GetName()); + if (iter == op_node_context_map_.end()) { + return SUCCESS; + } + + const std::vector &input_desc = temp_op.GetInputTensorDesc(); + const std::vector &output_desc = temp_op.GetOutputTensorDesc(); + + // get input and output tensor number + size_t input_tensor_num = 0; + size_t output_tensor_num = 0; + GetInputOutputTensorNum(op_desc, input_tensor_num, output_tensor_num); + + // update input + UpdateInputTensor(op_desc, input_desc, input_tensor_num); + + // update output + UpdateOutputTensor(op_desc, output_desc, output_tensor_num); + + return SUCCESS; +} + +Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + for (auto &src_iter : op_node_context_map_) { + string src_op_name = src_iter.first; + OpNodeContext src_op_node_context = src_iter.second; + std::map>> &src_output_map = src_op_node_context.output_map; + // Traverse all output of the op_node + for (auto &src_output_iter : src_output_map) { + string dest_op_name = src_output_iter.first; + auto dest_iter = op_node_context_map_.find(dest_op_name); + if (dest_iter == op_node_context_map_.end()) { + continue; + } + // Find that the output of the source node is equal to the destination node + std::map>> &dest_input_map = dest_iter->second.input_map; + auto input_iter = dest_input_map.find(src_op_name); + // 
Find output and input + if (input_iter == dest_input_map.end()) { + continue; + } + auto iter = node_map_.find(src_op_name); + if (iter == node_map_.end()) { + continue; + } + ge::NodePtr src = iter->second; + GE_CHECK_NOTNULL(src); + auto iter1 = node_map_.find(dest_op_name); + if (iter1 == node_map_.end()) { + continue; + } + // Each pair builds an edge + ge::NodePtr dest = iter1->second; + GE_CHECK_NOTNULL(dest); + if (src_output_iter.second.size() != input_iter->second.size()) { + ErrorManager::GetInstance().ATCReportErrMessage("E12021", {"opname1", "index1", "opname2", "index2"}, + {src_op_name, std::to_string(input_iter->second.size()), + dest_op_name, std::to_string(src_output_iter.second.size())}); + GELOGE(INTERNAL_ERROR, "Input size of op[%s]:%d is not equal to Output size of op[%s]:%d.", src_op_name.c_str(), + input_iter->second.size(), dest_op_name.c_str(), src_output_iter.second.size()); + return INTERNAL_ERROR; + } + for (auto &outputpair : src_output_iter.second) { + // Get control edge properties + bool control = GetEdgesControlInfo(dest_op_name, outputpair.second); + // Graph create new edge + if (!control) { + GELOGD("Start add edge: from %s:%d to %s:%d.", src->GetName().c_str(), outputpair.first, + dest->GetName().c_str(), outputpair.second); + ge::OutDataAnchorPtr out_archor_ptr = src->GetOutDataAnchor(outputpair.first); + GE_CHECK_NOTNULL(out_archor_ptr); + ge::InDataAnchorPtr in_archor_ptr = dest->GetInDataAnchor(outputpair.second); + GE_CHECK_NOTNULL(in_archor_ptr); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage( + "E12014", {"opname1", "opname2"}, {src->GetName(), dest->GetName()}); + return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", + src->GetName().c_str(), dest->GetName().c_str()); + } else { + GELOGD("Start add contorl edge: from %s to %s.", src->GetName().c_str(), dest->GetName().c_str()); + ge::InControlAnchorPtr in_archor_ptr = dest->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_archor_ptr); + GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() != TENSORFLOWF_NODE_OP_SWITCH, + ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, + {src->GetName(), dest->GetName()}); + return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), + dest->GetName().c_str());); + + GE_IF_BOOL_EXEC(nodedef_map_[src_op_name]->op() == TENSORFLOWF_NODE_OP_SWITCH, + ge::OutDataAnchorPtr out_data_archor_ptr = src->GetOutDataAnchor(outputpair.first); + GE_CHECK_NOTNULL(out_data_archor_ptr); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::GraphUtils::AddEdge(out_data_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, + ErrorManager::GetInstance().ATCReportErrMessage("E12014", {"opname1", "opname2"}, + {src->GetName(), dest->GetName()}); + return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), + dest->GetName().c_str());); + } + } + dest_input_map.erase(input_iter); + } + } + + return SUCCESS; +} + +Status TensorFlowModelParser::AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def, + const domi::tensorflow::NodeDef *node_def, + vector &op_node_name_list) { + GE_CHECK_NOTNULL(node_def); + const string &node_name = node_def->name(); + + nodedef_map_[node_name] = 
node_def; + + OpNodeContext op_node_context; + op_node_context_map_[node_name] = op_node_context; + op_node_name_list.push_back(node_name); + + return SUCCESS; +} + +Status TensorFlowModelParser::CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set &dims, + bool &valid) { + GE_CHECK_NOTNULL(node_def); + domi::tensorflow::AttrValue input_attr_value; + bool is_attr_exist = + ge::TensorFlowUtil::FindAttrValue(node_def, ge::parser::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value); + GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS); + GE_CHK_BOOL_EXEC(input_attr_value.has_list(), return PARAM_INVALID, "output attr value vector is empty"); + + // list contain many TensorDescriptors + domi::tensorflow::AttrValue_ListValue a_list = input_attr_value.list(); + for (int32_t i = 0; i < a_list.func_size(); i++) { + ge::GeTensorDesc ge_desc; + int32_t tf_datatype = 0; + GE_CHK_BOOL_RET_STATUS(ge::TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID, + "parse ge_desc failed."); + + for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { + int64_t temp_dim = ge_desc.GetShape().GetDim(j); + GE_IF_BOOL_EXEC(dims.count(temp_dim) > 0, valid = false); + } + } + + return SUCCESS; +} + +Status TensorFlowModelParser::CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type) { + GE_CHECK_NOTNULL(node_def); + bool valid = true; + string node_name = node_def->name(); + + std::map> check_dims = { + {ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, {10}}, + }; + + GE_IF_BOOL_EXEC( + op_type == ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, + GE_CHK_STATUS_RET(CheckOpShapeDim(node_def, check_dims[op_type], valid), "failed to check op shape"); + GE_IF_BOOL_EXEC(!valid, op_type = ge::parser::FRAMEWORKOP; GELOGI("Set op %s to frameworkop", node_name.c_str()); + framework_ops_[node_name] = node_def;);); + + GE_IF_BOOL_EXEC( + op_type == ge::parser::ADD || op_type == ge::parser::MULTIPLY || op_type == ge::parser::MEAN, + for (const string &input_name + : node_def->input()) { + string tmp_input_name; + GE_RETURN_IF_ERROR(CheckInputNodeName(input_name, &tmp_input_name, nullptr, nullptr)); + GELOGD("Add or Mul op %s input name is %s", node_name.c_str(), input_name.c_str()); + GE_IF_BOOL_EXEC(framework_ops_.find(tmp_input_name) != framework_ops_.end(), + GELOGI("Set op %s to frameworkop", node_name.c_str()); + op_type = ge::parser::FRAMEWORKOP;); + }); + + return SUCCESS; +} + +/* + * @ingroup domi_omg + * @brief Mapping TF's datatype to GE's datatype + * @param [in] type, datatype types of operators in TF networks + * @return ge::DataType + */ +ge::DataType TensorFlowModelParser::ConvertToGeDataType(const uint32_t type) { + ge::DataType data_type = domi::TensorAssign::ConvertTensorflowDataType(type); + return data_type; +} + +Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, + std::mutex *graphMutex, shared_ptr &scope_graph, + const domi::tensorflow::NodeDef *node_def) { + // The caller guarantees that the pointer is not null + string node_name = node_def->name(); + string node_op = node_def->op(); + GELOGD("TF op node name = %s, op type= %s", node_name.c_str(), node_op.c_str()); + domi::tensorflow::AttrValue attr_value; + if (ge::TensorFlowUtil::FindAttrValue(node_def, kAttrNameIsScopeInnerNode, attr_value) && attr_value.b()) { + return AddScopeInnerNode(parser, graph, graphMutex, node_def); + } + + auto iterator = parser->adaptedOpTypeMap_.find(node_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iterator 
== parser->adaptedOpTypeMap_.end(), return FAILED, + "get adapted op type failed, node name = %s", node_name.c_str()); + + string op_type = iterator->second; + // Log printing for determining operator type + domi::ImplyType implyType = domi::OpRegistry::Instance()->GetImplyType(op_type); + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::TVM) && (op_type != ge::parser::FRAMEWORKOP), + GELOGD("TBE %s parsering", node_op.c_str());); + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::CCE) && (op_type != ge::parser::FRAMEWORKOP), + GELOGD("CCE %s parsering", node_op.c_str());); + GE_IF_BOOL_EXEC((implyType == domi::ImplyType::HCCL) && (op_type != ge::parser::FRAMEWORKOP), + GELOGD("HCCL %s parsering", node_op.c_str());); + GE_IF_BOOL_EXEC(op_type == ge::parser::FRAMEWORKOP, GELOGD("FRAMEWORKOP %s parsering", node_op.c_str());); + GELOGD("TF op node name = %s, op type= %s, trans to op type %s", node_name.c_str(), node_op.c_str(), op_type.c_str()); + + // Construct operator by IR + ge::OpDescPtr op; + ge::Operator op_factory = ge::OperatorFactory::CreateOperator(node_name, op_type); + if (op_factory.GetName() != node_name) { + if (std::find(kMakeOperatorNotByIr.begin(), kMakeOperatorNotByIr.end(), op_type) != kMakeOperatorNotByIr.end()) { + op = ge::MakeShared(node_name, op_type); + GE_CHECK_NOTNULL(op); + } else if (node_name == op_type) { + GE_RETURN_IF_ERROR(parser->DefunToPartitionedCall(node_def, op)); + GE_CHECK_NOTNULL(op); + ge::Operator op_tmp = ge::OpDescUtils::CreateOperatorFromOpDesc(op); + GE_CHK_STATUS(domi::AutoMappingFn(node_def, op_tmp)); + op_tmp.BreakConnect(); + ge::NodePtr node; + { + std::lock_guard lock(*graphMutex); + node = graph->AddNode(op); + } + GE_CHECK_NOTNULL(node); + { + std::lock_guard lock(parser->nodeMapMutex_); + parser->node_map_[node_name] = node; + } + return SUCCESS; + } else { + GELOGE(INTERNAL_ERROR, "op[%s] type[%s] have no ir factory.]", node_name.c_str(), op_type.c_str()); + return FAILED; + } + } else { + op = ge::OpDescUtils::GetOpDescFromOperator(op_factory); + GE_CHECK_NOTNULL(op); + GELOGD("After GetOpDescFromOperator op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + + GE_RETURN_IF_ERROR(parser->AddTensorDescToOpDesc(op, node_def)); + GELOGD("After AddTensorDescToOpDesc op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + } + GELOGD("TF op node name = %s, outpusize= %zu", node_name.c_str(), op->GetAllOutputsDesc().size()); + op_factory.BreakConnect(); + + // create OpParser + shared_ptr factory = OpParserFactory::Instance(domi::TENSORFLOW); + GE_CHECK_NOTNULL(factory); + bool needFusion = parser->IsFusionOp(scope_graph, node_def); + GELOGD("TF op node name = %s, op type= %s is fusion op(NO: 0; YES: 1)= %d", node_name.c_str(), node_op.c_str(), + needFusion); + + Status status = FAILED; + if (!needFusion) { + shared_ptr op_parser = factory->CreateOpParser(op_type); + status = parser->ParseOpParams(node_def, op, op_parser); + if (status != SUCCESS) { + GELOGE(status, "Parse params for node[%s] failed", node_name.c_str()); + return status; + } + } + GELOGD("After op parser op[%s] type[%s] have input size: %zu, output size: %zu", op->GetName().c_str(), + op->GetType().c_str(), op->GetInputsSize(), op->GetOutputsSize()); + + // checkout op input number with IR + GE_RETURN_IF_ERROR(parser->CheckoutInputNum(op, node_def)); + + if (needFusion) { + status = 
RecordFusionResult(scope_graph, node_def, op); + if (status != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Record fusion result for fusion op: %s failed", op->GetName().c_str()); + return status; + } + } + + ge::NodePtr node; + { + std::lock_guard lock(*graphMutex); + node = graph->AddNode(op); + } + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), return INTERNAL_ERROR, "add node failed."); + + if (needFusion) { + shared_ptr fusion_op_parser = factory->CreateFusionOpParser(op_type); + status = parser->FusionNodeParseParams(fusion_op_parser, node_def, node); + GE_CHK_STATUS_EXEC(status, return status, "Parse Params for node %s failed", node_name.c_str()); + } + + { + std::lock_guard lock(parser->nodeMapMutex_); + parser->node_map_[node_name] = node; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit) { + // The caller guarantees that the pointer is not null + string node_name = node_def->name(); + string node_op = node_def->op(); + string op_type; + if (tensorflow_train_op_map.find(node_op) != tensorflow_train_op_map.end()) { + op_type = tensorflow_train_op_map.at(node_op); + GE_CHK_STATUS_RET(CheckOpType(node_def, op_type), "Failed to check op type"); + } else { + op_type = ge::parser::FRAMEWORKOP; + domi::tensorflow::AttrValue attr_call_inference; + if ((node_name == node_op) && + ge::TensorFlowUtil::FindAttrValue(node_def, "_disable_call_shape_inference", attr_call_inference)) { + op_type = node_op; + } + } + + GE_IF_BOOL_EXEC(isDatasetInit, op_type = ge::parser::FRAMEWORKOP); + adaptedOpTypeMap_[node_name] = op_type; + + return SUCCESS; +} + +Status TensorFlowModelParser::AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr &scope_graph, + vector &op_node_name_list, bool is_dataset_init) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(scope_graph); + + GE_RETURN_IF_ERROR(AddFusionNodeDef(scope_graph, op_node_name_list)); + size_t op_node_list_size = op_node_name_list.size(); + for (size_t i = 0; i < op_node_list_size; ++i) { + const string op_node_name = op_node_name_list[i]; + const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name]; + GE_CHECK_NOTNULL(node_def); + GE_RETURN_IF_ERROR(AdaptOpType(node_def, is_dataset_init)); + } + GELOGD("Add fusion nodedef and Adapt op type success"); + + // Multithreading parallel parsing nodedef + ThreadPool executor(kThreadNum); + std::mutex graphMutex; + std::vector> vectorFuture(op_node_list_size); + ge::ComputeGraphPtr graph_tmp = ge::MakeShared("tmpGraph"); + GE_CHECK_NOTNULL(graph_tmp); + for (size_t j = 0; j < op_node_list_size; j++) { + const string op_node_name = op_node_name_list[j]; + const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name]; + GE_CHECK_NOTNULL(node_def); + std::future f = + executor.commit(TensorFlowModelParser::ParseNodeDef, this, graph_tmp, &graphMutex, scope_graph, node_def); + if (!f.valid()) { + GELOGE(FAILED, "Future is invalid"); + return FAILED; + } + vectorFuture[j] = std::move(f); + } + GELOGD("Parse nodedef success"); + // Wait for the return value of each thread. 
If the thread does not finish processing, it will block here + bool ret_flag = true; + size_t futureSize = vectorFuture.size(); + for (size_t i = 0; i < futureSize; ++i) { + Status retStatus = vectorFuture[i].get(); + if (retStatus != SUCCESS) { + ret_flag = false; + } + } + if (!ret_flag) { + return FAILED; + } + return AddNodeToGraphAndMarkFormat(graph, op_node_name_list); +} + +Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph, + const vector &op_node_name_list) { + // Add ge:: nodeptr to graph in order + size_t op_node_list_size = op_node_name_list.size(); + for (size_t j = 0; j < op_node_list_size; j++) { + const string op_node_name = op_node_name_list[j]; + auto iterator = node_map_.find(op_node_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((iterator == node_map_.end()), return INTERNAL_ERROR, "add node failed."); + GE_CHECK_NOTNULL(iterator->second); + GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); + graph->AddNode(iterator->second); + } + + // mark format with default one explained in tf documents for some nodes + for (auto &node : graph->GetDirectNode()) { + auto nodeType = node->GetType(); + auto iter = kNeedMarkFormatNodes.find(nodeType); + if (iter != kNeedMarkFormatNodes.end()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + for (auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { + input_desc->SetOriginFormat((ge::Format)iter->second); + input_desc->SetFormat((ge::Format)iter->second); + } + for (auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { + output_desc->SetOriginFormat((ge::Format)iter->second); + output_desc->SetFormat((ge::Format)iter->second); + } + } + } + + return SUCCESS; +} + +Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *graph_def, + shared_ptr &scope_graph) { + // Identifying scope fusion operators based on scope rules + GE_CHECK_NOTNULL(graph_def); + ScopePassManager passmanager; + PARSER_TIMESTAMP_START(BuildScopeGraph); + scope_graph = passmanager.BuildScopeGraph(graph_def); + GE_CHECK_NOTNULL(scope_graph); + PARSER_TIMESTAMP_END(BuildScopeGraph, "TensorFlowModelParser::BuildScopeGraph"); + PARSER_TIMESTAMP_START(ScopeGraphPass); + // Validate the non-general scope fusion pass. + // The parameter is set to the name of the fusion rule. + // Multiple names can be set and separated by ",". 
+ std::vector enable_pass_names = + ge::StringUtils::Split(ge::GetParserContext().enable_scope_fusion_passes, ','); + auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_; + if (impl == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized."); + return ge::MEMALLOC_FAILED; + } + + for (size_t i = 0; i < enable_pass_names.size(); ++i) { + if (enable_pass_names[i].empty()) { + continue; + } + if (!impl->SetPassEnableFlag(enable_pass_names[i], true)) { + GELOGW("Failed to set enable flag of scope fusion pass:%s", enable_pass_names[i].c_str()); + } + } + std::vector scope_passes_list = impl->GetAllRegisteredPasses(); + Status ret = RunScopeFusionPass(scope_passes_list, passmanager, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Run scope fusion failed, ret:%u.", ret); + return ret; + } + PARSER_TIMESTAMP_END(ScopeGraphPass, "TensorFlowModelParser::ScopeGraphPass"); + + return SUCCESS; +} + +Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(data); + GE_CHECK_NOTNULL(graph); + + // Store objects parsed from pb files + domi::tensorflow::GraphDef OriDef; + + bool read = ge::parser::ReadProtoFromArray(data, static_cast(size), &OriDef); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); + + domi::tensorflow::GraphDef graph_def; + if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + graph_def = OriDef; + } else { + GELOGI("Before Trim, the Graph Node size is:%d", OriDef.node_size()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(OriDef, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); + GELOGI("After Trim, The graph_def.node_size():%d", graph_def.node_size()); + } + + shared_ptr scope_graph = nullptr; + Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[TF ParseFromMemory] scope fusion failed."); + return ret; + } + GELOGD("[TF ParseFromMemory] scope fusion success"); + + // Add nodedef in the model to prechecker and check the general parameters + for (int i = 0; i < graph_def.node_size(); i++) { + const domi::tensorflow::NodeDef &node = graph_def.node(i); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), + "Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckName(&node), "Check node_def name failed, node name: %s.", + node.name().c_str()); + if (node.op() != TENSORFLOWF_NODE_OP_IDENTITY) { + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().CheckType(&node, true), + "Check node_def type failed, node name: %s.", node.name().c_str()); + } + } + + bool has_error = false; + // save node name + vector op_node_name_list; + for (int i = 0; i < graph_def.node_size(); i++) { + const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i); + + // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_ + GE_IF_BOOL_EXEC(MaybeFusionOp(scope_graph, node_def), + GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str());); + + // Do not exit immediately when there is an error, wait until all errors are collected before exiting + GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true, + "add node failed."); + } + + // Verify the validity of fusionop + GE_RETURN_IF_ERROR(CheckFusionOpValid()); + + // The fusion operator has passed the verification. 
+ // The errors of internal non key operators (which will be ignored later) + // do not affect the transformation of the whole model, + // So clear the error information of non key operators + // This function call affects the return value of prechecker::instance().Haserror() + GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); + + // Check the input validity of the node, the input attribute must have a corresponding node + GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); + + // Building input and input relationships for all OP nodes + GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); + GELOGD("[TF ParseFromMemory] get op nodes context from graph success"); + + // Infer input formats + ge::GetParserContext().format = InferInputFormats(); + GELOGD("[TF ParseFromMemory] infer input formats success"); + + // Building input-output relationship between fusionop and common op + GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, graph_def, op_node_name_list)); + + ret = AddFusionNodeDef(scope_graph, op_node_name_list); + if (ret != SUCCESS) { + GELOGE(ret, "Add fusion NodeDef failed."); + DeleteFuisonNodeDef(); + return ret; + } + GELOGI("TF op node size = %zu.", op_node_name_list.size()); + // Loop analysis of op_nodes and map them to nodes in graph + for (size_t i = 0; i < op_node_name_list.size(); i++) { + GELOGI("TF op node name = %s.", op_node_name_list[i].c_str()); + const string op_node_name = op_node_name_list[i]; + const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]]; + if (node_def == nullptr) { + GELOGE(INTERNAL_ERROR, "Node def is nullptr, name:%s.", op_node_name.c_str()); + DeleteFuisonNodeDef(); + return INTERNAL_ERROR; + } + const string &node_op = node_def->op(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((tensorflow_op_map.find(node_op) == tensorflow_op_map.end()), DeleteFuisonNodeDef(); + return INTERNAL_ERROR, "Unsupport op type %s", node_op.c_str()); + + ret = AddNode(node_def, graph, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Add node failed, name:%s.", op_node_name.c_str()); + DeleteFuisonNodeDef(); + return ret; + } + } + + DeleteFuisonNodeDef(); + + GE_RETURN_IF_ERROR(AddEdges(graph)); + GE_RETURN_IF_ERROR(graph->TopologicalSorting()); + + has_error = has_error || PreChecker::Instance().HasError(); + if (has_error) { + GELOGE(PARAM_INVALID, "Precheck has errors."); + return PARAM_INVALID; + } + GELOGI("[TF ParseFromMemory] Parse from memory success."); + return SUCCESS; +} + +Status TensorFlowModelParser::GetFunctionProto(const string &file, + domi::tensorflow::GraphDefLibrary &graph_def_library) { + int pos = file.rfind('/'); + string graph_def_path = (pos == -1) ? kFuncDefLibraryFilePath : file.substr(0, pos) + "/" + kFuncDefLibraryFilePath; + GELOGI("Function def libraray path is %s.", graph_def_path.c_str()); + + bool read = ge::parser::ReadProtoFromText(graph_def_path.c_str(), &graph_def_library); + if (!read) { + GELOGE(INTERNAL_ERROR, + "Get subgraph library failed. " + "The model contains function operators. 
" + "Need to use the script func2graph.py in the atc package to save the subgraphs to graph_def_library.pbtxt"); + ErrorManager::GetInstance().ATCReportErrMessage("E12029"); + return FAILED; + } + + GELOGI("Get subgraph library success."); + return SUCCESS; +} + +Status TensorFlowModelParser::Parse(const char *model_path, ge::Graph &graph) { + GE_CHECK_NOTNULL(model_path); + ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(root_graph); + + Status ret = Parse(model_path, root_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); + return ret; + } + + GELOGI("Parser graph %s success.", graph.GetName().c_str()); + return SUCCESS; +} + +Status TensorFlowModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &root_graph) { + GE_CHECK_NOTNULL(model_path); + GE_CHECK_NOTNULL(root_graph); + + GELOGI("Parse file %s", model_path); + // Store objects parsed from pb files + domi::tensorflow::GraphDef ori_def; + bool read = ge::parser::ReadProtoFromBinaryFile(model_path, &ori_def); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, return INTERNAL_ERROR, "read_proto_from_binary failed."); + + // Trim graph by user input and output. + domi::tensorflow::GraphDef graph_def; + if (ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + graph_def = ori_def; + } else { + GELOGI("Before Trim, the Graph Node size is:%d", ori_def.node_size()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(TrimGraph(ori_def, &graph_def), return INTERNAL_ERROR, "Trim Graph fail."); + GELOGI("After Trim, The graph_def.node size is:%d", graph_def.node_size()); + } + + // Construct ParseArg for root graph. + google::protobuf::Message *root_proto = &graph_def; + std::deque tasks; + tasks.push_back({root_proto, "root", nullptr, "", root_graph}); + + // Get sub graph from graph_def_library.pbtxt which prepared before and stored in model_path. + std::map function_name_to_graphdef; + + // Parse all root graph and sub graph. 
+ while (!tasks.empty()) { + auto arg = tasks.front(); + tasks.pop_front(); + + if (arg.proto == nullptr) { + if (function_name_to_graphdef.empty() && (ori_def.library().function_size() > 0)) { + GELOGI("Graph has function size: %d ", ori_def.library().function_size()); + domi::tensorflow::GraphDefLibrary graph_def_library; + GE_CHK_STATUS_RET(GetFunctionProto(model_path, graph_def_library)); + for (auto &ge_graph_def : graph_def_library.graph_def()) { + function_name_to_graphdef[ge_graph_def.name()] = ge_graph_def.graph(); + GELOGD("Graph_def name: %s, node size: %d", ge_graph_def.name().c_str(), ge_graph_def.graph().node_size()); + } + } + + auto iter = function_name_to_graphdef.find(arg.function_name); + if (iter == function_name_to_graphdef.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E12013", {"functionname"}, {arg.function_name}); + GELOGE(FAILED, "Failed to get subgraph by function name %s", arg.function_name.c_str()); + return FAILED; + } + arg.proto = &(iter->second); + } + + GELOGI("Begin to parse graph %s", arg.function_name.c_str()); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); + auto ret = model_parser->ParseAllGraph(arg.proto, arg.graph); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(), + arg.graph->GetName().c_str()); + return ret; + } + + ret = PostOpProcessForSubgraph(arg); + if (ret != SUCCESS) { + // the error log has been printed inner the function + return ret; + } + + ret = GenSubgraphParseTasks(arg.graph, tasks); + if (ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E12017", {"graphname"}, {arg.graph->GetName()}); + GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str()); + return ret; + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(proto); + GE_CHECK_NOTNULL(graph); + + const domi::tensorflow::GraphDef *ori_graph = reinterpret_cast(proto); + // Make a copy for operation without modifying the original graph def. 
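+  // Scope fusion and the const-node optimization below mutate this copy; the caller's proto is left untouched.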
+ domi::tensorflow::GraphDef graph_def = *ori_graph; + + shared_ptr scope_graph = nullptr; + Status ret = ExcuteScopeFusionPasses(&graph_def, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[TF Parse] scope fusion failed."); + return ret; + } + GELOGD("[TF Parse] scope fusion success"); + + GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(&graph_def)); + GELOGD("[TF Parse] optimize const nodes for custom op base success"); + + // Add nodedef in the model to prechecker and check the general parameters + // Prevent data residue in multiple calls + PreChecker::Instance().Clear(); + for (int i = 0; i < graph_def.node_size(); i++) { + const domi::tensorflow::NodeDef &node = graph_def.node(i); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().AddOp(&node, node.name(), node.op()), + "Add node_def to PreChecker failed, node name: %s.", node.name().c_str()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckName(&node) != SUCCESS, return FAILED, + "Check op[%s] failed, name repeat in tensorflow pb file.", node.name().c_str()); + GE_CHK_BOOL_EXEC_NOLOG( + node.op() == TENSORFLOWF_NODE_OP_IDENTITY, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(PreChecker::Instance().CheckType(&node, true) != SUCCESS, return FAILED, + "Check op[%s]'s optype failed, type is not supported.", node.name().c_str());) + } + + bool has_error = false; + // save node name + vector op_node_name_list; + for (int i = 0; i < graph_def.node_size(); i++) { + const domi::tensorflow::NodeDef *node_def = graph_def.mutable_node(i); + + // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_ + if (MaybeFusionOp(scope_graph, node_def)) { + GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()); + } + + // Do not exit immediately when there is an error, wait until all errors are collected before exiting + GE_CHK_STATUS_EXEC(AddFmkNodeDefToMap(graph_def, node_def, op_node_name_list), has_error = true); + } + + // Verify the validity of fusionop + GE_RETURN_IF_ERROR(CheckFusionOpValid()); + + // The fusion operator has passed the verification. 
+ // The errors of internal non key operators (which will be ignored later) + // do not affect the transformation of the whole model, + // So clear the error information of non key operators + // This function call affects the return value of prechecker::instance().Haserror() + GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list)); + + // Check the input validity of the node, the input attribute must have a corresponding node + GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def)); + GELOGD("[TF Parse] check graph success"); + + // Building input and input relationships for all OP nodes + GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def)); + GELOGD("[TF Parse] get op nodes context from graph success"); + + // Infer input formats + ge::GetParserContext().format = InferInputFormats(); + GELOGD("[TF Parse] infer input formats success"); + + // Building input-output relationship between fusionop and common op + GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, graph_def, op_node_name_list)); + GELOGD("[TF Parse] update all node op context success"); + + // set user-designate-inputs-order + std::vector user_inputs_order; + for (auto &input : ge::GetParserContext().user_input_dims) { + user_inputs_order.push_back(input.first); + } + graph->SetInputsOrder(user_inputs_order); + + ret = AddFusionNodeDef(scope_graph, op_node_name_list); + if (ret != SUCCESS) { + GELOGE(ret, "Add fusion NodeDef failed."); + DeleteFuisonNodeDef(); + return ret; + } + GELOGI("TF op node size = %zu.", op_node_name_list.size()); + + // Loop analysis of op_nodes and map them to nodes in graph + for (size_t i = 0; i < op_node_name_list.size(); i++) { + GELOGI("TF op node name = %s.", op_node_name_list[i].c_str()); + const string op_node_name = op_node_name_list[i]; + const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name_list[i]]; + if (node_def == nullptr) { + GELOGE(INTERNAL_ERROR, "Cannot find [%s] in nodedef map.", op_node_name_list[i].c_str()); + DeleteFuisonNodeDef(); + return INTERNAL_ERROR; + } + const string &node_op = node_def->op(); + + if (tensorflow_op_map.find(node_op) == tensorflow_op_map.end()) { + GELOGW("%s not found in tensorflow_op_map.", node_op.c_str()); + } + Status ret = AddNode(node_def, graph, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Add op[%s] failed", node_def->name().c_str()); + DeleteFuisonNodeDef(); + return ret; + } + } + + GELOGD("[TF Parse] parse tf node to geop success"); + + DeleteFuisonNodeDef(); + + GE_RETURN_IF_ERROR(AddEdges(graph)); + GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); + GE_RETURN_IF_ERROR(graph->TopologicalSorting()); + + if (has_error) { + GELOGE(PARAM_INVALID, "Precheck has errors."); + return PARAM_INVALID; + } + GELOGI("[TF Parser] Parse proto success."); + return SUCCESS; +} + +Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) { + // Number of data nodes + uint32_t data_node_count = 0; + for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) { + // Check that all input is valid + for (const string &node_name : node_def.input()) { + string tmp_node_name; + GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr)); + + if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"}, + {node_def.name(), node_name}); + GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(), + node_name.c_str()); + return INTERNAL_ERROR; + 
} + } + + if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) { + data_node_count++; + } + } + if (data_node_count == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E12010"); + GELOGE(INTERNAL_ERROR, "Model has no Placeholder node."); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) { + // Build the input relationship first + for (auto &iter : op_node_context_map_) { + map>> input_map; + const string &op_node_name = iter.first; + GE_RETURN_IF_ERROR(GetOpNodeInputMap(op_node_name, input_map)); + + OpNodeContext &op_node_context = iter.second; + op_node_context.input_map = input_map; + } + + // Then build the output relationship + GE_RETURN_IF_ERROR(GetOpNodeOutputMap(graph_def)); + + return SUCCESS; +} + +// Get the input relation of opnode includeing input_op and input_const +Status TensorFlowModelParser::GetOpNodeInputMap(const string &op_node_name, + map>> &input_map) { + // Get the current nodedef according to the node_name + const domi::tensorflow::NodeDef *node_def = nodedef_map_[op_node_name]; + GE_CHECK_NOTNULL(node_def); + int32_t input_index = 0; + int32_t output_index = 0; + for (const string &input_node_name : node_def->input()) { + GELOGD("Get Op InputMap, node_name : %s, input node:%s", node_def->name().c_str(), input_node_name.c_str()); + string tmp_node_name; + bool control = false; + GE_RETURN_IF_ERROR(CheckInputNodeName(input_node_name, &tmp_node_name, &output_index, &control)); + input_map[tmp_node_name].push_back({output_index, control ? kControlSlot : input_index}); + SaveEdgesControlInfo(node_def->name(), control); + input_index = control ? input_index : input_index + 1; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def) { + // Loop through all nodes in graphdef + for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) { + auto currentIter = op_node_context_map_.find(node_def.name()); + if (currentIter != op_node_context_map_.end()) { + OpNodeContext &op_node_context = currentIter->second; + // Find all input nodes of the current node + for (auto &inputIter : op_node_context.input_map) { + auto iter = op_node_context_map_.find(inputIter.first); + if (iter != op_node_context_map_.end()) { + std::vector> inputpairs = inputIter.second; + OpNodeContext &op_node_context1 = iter->second; + op_node_context1.output_map[node_def.name()].assign(inputpairs.begin(), inputpairs.end()); + } + } + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::GeStoi(const string &input_node_name, const string &index_str, int32_t *index) { + try { + int32_t tmp_index = static_cast(std::stoi(index_str.c_str(), nullptr, 10)); + *index = tmp_index; + } catch (std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, + {"input_node_name(" + input_node_name + ")", index_str}); + GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is invalid argument!", input_node_name.c_str(), + index_str.c_str()); + return INTERNAL_ERROR; + } catch (std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, + {"input_node_name(" + input_node_name + ")", index_str}); + GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is out of range!", input_node_name.c_str(), + index_str.c_str()); + return INTERNAL_ERROR; + } catch (...) 
{ + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, + {"input_node_name(" + input_node_name + ")", index_str}); + GELOGE(INTERNAL_ERROR, "stl[stoi] input_node_name[%s] indexstr[%s] is bad argument!", input_node_name.c_str(), + index_str.c_str()); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, + bool *control) { + // Processing scene: input: "^fastrcnn_predictions/map/while/Identity"" + string tmp_input_node_name = input_node_name; + if (tmp_input_node_name.find("^") == 0) { + tmp_input_node_name = tmp_input_node_name.substr(1, tmp_input_node_name.length() - 1); + if (control != nullptr) { + *control = true; + } + } else { + if (control != nullptr) { + *control = false; + } + } + + int32_t tmp_index = 0; + auto find = tmp_input_node_name.find(":"); + if (find == string::npos) { + *node_name = tmp_input_node_name; + + if (index == nullptr) { + return SUCCESS; + } + *index = tmp_index; + + return SUCCESS; + } + + string indexstr = tmp_input_node_name.substr(find + 1, tmp_input_node_name.length() - find - 1); + *node_name = tmp_input_node_name.substr(0, find); + + if (index == nullptr) { + return SUCCESS; + } + + if (GeStoi(input_node_name, indexstr, index) != SUCCESS) { + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::RunScopeFusionPass(const vector &scope_passes_list, + ScopePassManager &pass_manager, + shared_ptr &scope_graph) { + if (scope_passes_list.empty()) { + return SUCCESS; + } + GE_CHECK_NOTNULL(scope_graph); + auto &impl = ge::ScopeFusionPassRegistry::GetInstance().impl_; + if (impl == nullptr) { + GELOGE(ge::MEMALLOC_FAILED, "ScopeFusionPassRegistry is not properly initialized."); + return ge::MEMALLOC_FAILED; + } + + for (auto &pass_name : scope_passes_list) { + auto pass = impl->CreateScopeFusionPass(pass_name); + if (pass == nullptr) { + ErrorManager::GetInstance().ATCReportErrMessage("E12001", {"passname"}, {pass_name}); + GELOGE(INTERNAL_ERROR, "Scope fusion pass[%s] is not registered.", pass_name.c_str()); + return INTERNAL_ERROR; + } + Status ret = pass_manager.AddPass(pass); + if (ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E12002", {"passname"}, {pass_name}); + GELOGE(INTERNAL_ERROR, "Add scope fusion pass[%s] failed.", pass_name.c_str()); + return INTERNAL_ERROR; + } + } + Status ret = pass_manager.Run(scope_graph); + if (ret != SUCCESS && ret != domi::SCOPE_NOT_CHANGED) { + GELOGE(FAILED, "Run scope fusion pass failed, ret:%u.", ret); + return FAILED; + } + return SUCCESS; +} + +bool TensorFlowModelParser::MaybeFusionOp(shared_ptr &scope_graph, + const domi::tensorflow::NodeDef *node_def) { + GE_CHECK_NOTNULL(scope_graph); + GE_CHECK_NOTNULL(node_def); + // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_ + ge::ScopeFusionOpInfo info; + std::vector info_list; + auto &impl = scope_graph->impl_; + if (TensorFlowFunsionOPUtil::MaybeFusionOp(node_def->name(), &info) || + impl->IsFusionOpChild(node_def->name(), info_list)) { + GE_IF_BOOL_EXEC( + info_list.size() > 0, for (size_t i = 0; i < info_list.size(); ++i) { + fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].fusion_op_type); + fusion_op_type_map_[info_list[i].fusion_node_name].push_back(info_list[i].description); + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(node_def); + if (info_list[i].fusion_op_type == "Dropout" && + (node_def->op() == 
"Add" || node_def->op() == "RandomUniform")) { + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(0)]); + } + if (info_list[i].fusion_op_type == "LayerNorm" && node_def->op() == "Mean") { + fusion_op_nodedef_map_[info_list[i].fusion_node_name].push_back(nodedef_map_[node_def->input(1)]); + } + fusion_op_policy_[info_list[i].fusion_node_name] = info_list[i].scope_pass; + fusion_op_children_[node_def->name()] = info_list[i]; + }); + GE_IF_BOOL_EXEC(info_list.size() == 0, fusion_op_type_map_[info.fusion_node_name].push_back(info.fusion_op_type); + fusion_op_type_map_[info.fusion_node_name].push_back(info.description); + fusion_op_nodedef_map_[info.fusion_node_name].push_back(node_def); + fusion_op_policy_[info.fusion_node_name] = info.scope_pass; + fusion_op_children_[node_def->name()] = info); + + return true; + } + + return false; +} + +bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info) { + GE_CHK_BOOL_EXEC(info != nullptr, return false, "fusion info is null."); + // 1.View in full match fusion strategy first + // 2.View in scope fusion policy then + auto iter = fusion_op_children_.find(node_name); + if (iter != fusion_op_children_.end()) { + info->node_name = fusion_op_children_[node_name].node_name; + info->fusion_node_name = fusion_op_children_[node_name].fusion_node_name; + info->fusion_op_type = fusion_op_children_[node_name].fusion_op_type; + info->description = fusion_op_children_[node_name].description; + info->scope_pass = fusion_op_children_[node_name].scope_pass; + + return true; + } + + return false; +} + +bool TensorFlowModelParser::FusionOpChildIgnore(shared_ptr &scope_graph, + const ge::ScopeFusionOpInfo &info) { + GE_CHECK_NOTNULL(scope_graph); + bool ignore = false; + if (info.scope_pass) { + // Scope fusion strategy + auto &impl = scope_graph->impl_; + ignore = impl->FusionOpChildIgnore(info); + } else { + // Full match fusion strategy + ignore = TensorFlowFunsionOPUtil::FusionOpChildIgnore(info); + } + return ignore; +} + +bool TensorFlowModelParser::IsFusionOp(shared_ptr &scope_graph, + const domi::tensorflow::NodeDef *node_def) { + // The caller guarantees that the pointer is not null + auto &impl = scope_graph->impl_; + if (TensorFlowFunsionOPUtil::IsFusionOp(node_def) || impl->IsFusionOp(node_def)) { + return true; + } + + return false; +} +Status TensorFlowModelParser::GetInPutIndex(shared_ptr &scope_graph, const ge::ScopeFusionOpInfo &info, + const int32_t old_index, int32_t &new_index) { + GE_CHECK_NOTNULL(scope_graph); + Status ret; + if (info.scope_pass) { + auto &impl = scope_graph->impl_; + ret = impl->GetInputOrOutputIndex(info, old_index, true, new_index); + } else { + ret = TensorFlowFunsionOPUtil::GetInPutIndex(info, old_index, new_index); + } + + return ret; +} +Status TensorFlowModelParser::GetOutPutIndex(shared_ptr &scope_graph, const ge::ScopeFusionOpInfo &info, + const int32_t old_index, int32_t &new_index) { + GE_CHECK_NOTNULL(scope_graph); + Status ret; + if (info.scope_pass) { + auto &impl = scope_graph->impl_; + ret = impl->GetInputOrOutputIndex(info, old_index, false, new_index); + } else { + ret = TensorFlowFunsionOPUtil::GetOutPutIndex(info, old_index, new_index); + } + + return ret; +} + +Status TensorFlowModelParser::CheckFusionOpValid() { + for (auto &iter : fusion_op_nodedef_map_) { + const string fusion_node_name = iter.first; + vector nodedef_list = iter.second; + vector funsion_op_info = fusion_op_type_map_[fusion_node_name]; + // vecotr index 0 is 
fusion_op_type + const string funsion_op_type = funsion_op_info[0]; + if (!fusion_op_policy_[fusion_node_name]) { + // Check the validity of the fusion_op_nodedef_map children operator + GE_RETURN_IF_ERROR( + TensorFlowFunsionOPUtil::CheckFusionOpChildren(fusion_node_name, nodedef_list, funsion_op_type)); + + // Because there are many scenes in tensorflow graph, + // in order to avoid the problem of omission, the error is returned directly. + // In the future, functions like rollback can be implemented according to the definition of fusion operator + } + } + return SUCCESS; +} + +bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) { + if (nodedef_map_[op_name]->op() != TENSORFLOWF_NODE_OP_CONST) { + // Normal op need to update + return true; + } else { + auto iter = op_node_context_map_.find(op_name); + if (iter != op_node_context_map_.end()) { + ge::ScopeFusionOpInfo info; + auto outmap = iter->second.output_map; + for (auto &out_node : outmap) { + // if the const op output connected to are all fusion ops and the cosnt op is not in the update vector + if (!IsFusionOpChild(out_node.first, &info)) { + return true; + } + } + if (std::find(const_op_update_vec.begin(), const_op_update_vec.end(), op_name) == const_op_update_vec.end()) { + return false; + } + } + return true; + } +} + +Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr &scope_graph, + const domi::tensorflow::GraphDef &graph_def, + vector &op_node_name_list) { + GE_CHECK_NOTNULL(scope_graph); + vector tmp_op_node_name_list; + map tmp_fusion_op_node_context_map; + + for (auto &op_node_name : op_node_name_list) { + auto iter = op_node_context_map_.find(op_node_name); + if (iter != op_node_context_map_.end()) { + ge::ScopeFusionOpInfo info; + if (IsFusionOpChild(op_node_name, &info) && nodedef_map_[op_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { + // This node is a fusion operator + auto fusion_iter = tmp_fusion_op_node_context_map.find(info.fusion_node_name); + if (fusion_iter == tmp_fusion_op_node_context_map.end()) { + OpNodeContext op_node_context; + tmp_fusion_op_node_context_map[info.fusion_node_name] = op_node_context; + tmp_op_node_name_list.push_back(info.fusion_node_name); + } + + OpNodeContext &fusion_op_node_context = tmp_fusion_op_node_context_map[info.fusion_node_name]; + OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name]; + GE_RETURN_IF_ERROR(UpdateFusionOpContext(scope_graph, info, fusion_op_node_context, normal_op_node_context)); + + // Delete fusion operator context + op_node_context_map_.erase(iter); + } else { + // This node is a common operator + OpNodeContext &normal_op_node_context = op_node_context_map_[op_node_name]; + GE_RETURN_IF_ERROR(UpdateNormalOpContext(scope_graph, op_node_name, normal_op_node_context)); + tmp_op_node_name_list.push_back(op_node_name); + } + } + } + + // update op_node_name_list + op_node_name_list.clear(); + op_node_name_list.assign(tmp_op_node_name_list.begin(), tmp_op_node_name_list.end()); + + // update op_node_context_map_ + for (const auto &iter : tmp_fusion_op_node_context_map) { + op_node_context_map_[iter.first] = iter.second; + } + // Normalized context + GE_RETURN_IF_ERROR(NormalizeAllNodeOpContext()); + + return SUCCESS; +} + +Status TensorFlowModelParser::UpdateFusionOpContext(shared_ptr &scope_graph, + const ge::ScopeFusionOpInfo &info, + OpNodeContext &fusion_op_node_context, + OpNodeContext &normal_op_node_context) { + GE_CHECK_NOTNULL(scope_graph); + if (FusionOpChildIgnore(scope_graph, info)) { + // The inner 
children operators of the fusion operator can be ignored directly + // if they do not establish the edge relationship with other outer ordinary / fusion operators + return SUCCESS; + } + + GE_CHK_STATUS_RET(UppdateInputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context), + "UppdateInputMap ret fail"); + GE_CHK_STATUS_RET(UppdateOutputMap(scope_graph, info, fusion_op_node_context, normal_op_node_context), + "UppdateOutputMap ret fail"); + + return SUCCESS; +} + +Status TensorFlowModelParser::UppdateInputMap(shared_ptr &scope_graph, + const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context, + OpNodeContext &normal_op_node_context) { + GE_CHECK_NOTNULL(scope_graph); + for (auto &iter : normal_op_node_context.input_map) { + string input_node_name = iter.first; + std::vector> &pairs = iter.second; + ge::ScopeFusionOpInfo from_info; + int32_t from_index = 0; + int32_t to_index = 0; + if (!ConstOpNeedUpdate(input_node_name)) { + GELOGI("%s is const node connected to a fusion child, ignore", input_node_name.c_str()); + continue; + } + if (IsFusionOpChild(input_node_name, &from_info)) { + if (info.fusion_node_name == from_info.fusion_node_name) { + // Ignore two sub operators in the same fusion operator + continue; + } + + for (auto &pair : pairs) { + GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index), + "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str()); + GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index), + "GetInPutIndex failed ,input_node_name %s.", input_node_name.c_str()); + fusion_op_node_context.input_map[from_info.fusion_node_name].push_back({from_index, to_index}); + UpdateEdgesControlInfo(info); + GELOGD("[Update op context] update fusion input map for fusion input, %s:%d TO %s:%d", + from_info.fusion_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index); + } + } else { + for (auto &pair : pairs) { + from_index = pair.first; + GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, info, pair.second, to_index), + "GetInPutIndex input_node_name %s.", input_node_name.c_str()); + fusion_op_node_context.input_map[input_node_name].push_back({from_index, to_index}); + UpdateEdgesControlInfo(info); + GELOGD("[Update op context] update fusion input map for normal input, %s:%d TO %s:%d", + input_node_name.c_str(), from_index, info.fusion_node_name.c_str(), to_index); + } + } + } + return SUCCESS; +} +Status TensorFlowModelParser::UppdateOutputMap(shared_ptr &scope_graph, + const ge::ScopeFusionOpInfo &info, OpNodeContext &fusion_op_node_context, + OpNodeContext &normal_op_node_context) { + GE_CHECK_NOTNULL(scope_graph); + for (auto &iter : normal_op_node_context.output_map) { + string output_node_name = iter.first; + std::vector> &pairs = iter.second; + ge::ScopeFusionOpInfo to_info; + int32_t from_index = 0; + int32_t to_index = 0; + if (IsFusionOpChild(output_node_name, &to_info)) { + if (info.fusion_node_name == to_info.fusion_node_name) { + // Ignore two sub operators in the same fusion operator + continue; + } + for (auto &pair : pairs) { + GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index), + "fusion GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str()); + GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index), + "fusion GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str()); + 
fusion_op_node_context.output_map[to_info.fusion_node_name].push_back({from_index, to_index}); + GELOGD("[Update op context] update fusion output map for fusion output, %s:%d TO %s:%d", + info.fusion_node_name.c_str(), from_index, to_info.fusion_node_name.c_str(), to_index); + } + } else { + for (auto &pair : pairs) { + to_index = pair.second; + GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, info, pair.first, from_index), + "not fusion,GetOutPutIndex failed ,output_node_name %s.", output_node_name.c_str()); + fusion_op_node_context.output_map[output_node_name].push_back({from_index, to_index}); + GELOGD("[Update op context] update fusion output map for normal output, %s:%d TO %s:%d", + info.fusion_node_name.c_str(), from_index, output_node_name.c_str(), to_index); + } + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::EraseNormalOpOutputIfChild(shared_ptr &scope_graph, + const string &op_node_name, + OpNodeContext &normal_op_node_context) { + std::map>> tmp_output_map; + for (auto iter = normal_op_node_context.output_map.begin(); iter != normal_op_node_context.output_map.end();) { + string output_node_name = iter->first; + ge::ScopeFusionOpInfo to_info; + int32_t from_index = 0; + int32_t to_index = 0; + + if (IsFusionOpChild(output_node_name, &to_info) && + nodedef_map_[output_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { + // Fuse operator, update index + std::vector> &pairs = iter->second; + for (auto &pair : pairs) { + from_index = pair.first; + GE_RETURN_WITH_LOG_IF_ERROR(GetInPutIndex(scope_graph, to_info, pair.second, to_index), + "GetInPutIndex failed ,output_node_name %s.", output_node_name.c_str()); + tmp_output_map[to_info.fusion_node_name].push_back({from_index, to_index}); + GELOGD("[Update op context] update normal output map for fusion output, %s:%d TO %s:%d", op_node_name.c_str(), + from_index, to_info.fusion_node_name.c_str(), to_index); + } + + iter = normal_op_node_context.output_map.erase(iter); + } else { + iter++; + } + } + + for (auto &iter : tmp_output_map) { + normal_op_node_context.output_map[iter.first] = iter.second; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::UpdateNormalOpContext(shared_ptr &scope_graph, const string &op_node_name, + OpNodeContext &normal_op_node_context) { + GE_CHECK_NOTNULL(scope_graph); + std::map>> tmp_input_map; + + for (auto iter = normal_op_node_context.input_map.begin(); iter != normal_op_node_context.input_map.end();) { + string input_node_name = iter->first; + ge::ScopeFusionOpInfo from_info; + int32_t from_index = 0; + int32_t to_index = 0; + + if (IsFusionOpChild(input_node_name, &from_info) && + nodedef_map_[input_node_name]->op() != TENSORFLOWF_NODE_OP_CONST) { + // Fuse operator, update index + std::vector> &pairs = iter->second; + for (auto &pair : pairs) { + to_index = pair.second; + GE_RETURN_WITH_LOG_IF_ERROR(GetOutPutIndex(scope_graph, from_info, pair.first, from_index), + "GetOutPutIndex failed ,input_node_name %s.", input_node_name.c_str()); + tmp_input_map[from_info.fusion_node_name].push_back({from_index, to_index}); + GELOGD("[Update op context] update normal input map for fusion input, %s:%d TO %s:%d", + from_info.fusion_node_name.c_str(), from_index, op_node_name.c_str(), to_index); + } + + iter = normal_op_node_context.input_map.erase(iter); + } else { + iter++; + } + } + + Status ret = EraseNormalOpOutputIfChild(scope_graph, op_node_name, normal_op_node_context); + if (ret != SUCCESS) { + return ret; + } + + for (auto &iter : tmp_input_map) { + 
normal_op_node_context.input_map[iter.first] = iter.second; + } + + return SUCCESS; +} + +Status TensorFlowModelParser::NormalizeAllNodeOpContext() { + for (auto iter = op_node_context_map_.begin(); iter != op_node_context_map_.end();) { + OpNodeContext &context = iter->second; + NormalizeInputOrOutputMap(context.input_map); + NormalizeInputOrOutputMap(context.output_map); + + if ((context.input_map.size() == 0) && (context.output_map.size() == 0)) { + GELOGD("[Update op context] node: %s will be removed at the back.", iter->first.c_str()); + iter = op_node_context_map_.erase(iter); + } else { + iter++; + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::NormalizeInputOrOutputMap( + std::map>> &context_map) { + if (context_map.size() == 0) { + return SUCCESS; + } + + for (auto iter = context_map.begin(); iter != context_map.end();) { + std::vector> &pairs = iter->second; + std::vector> temp_pairs; + std::set compare_set; + + for (auto &pair : pairs) { + if ((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) { + // The edge will be cut off at the back, ignoring + continue; + } + + string name = to_string(pair.first) + ":" + to_string(pair.second); + auto compare_iter = compare_set.find(name); + if (compare_iter != compare_set.end()) { + // pair repeat, ignore + continue; + } + + temp_pairs.push_back(pair); + compare_set.insert(name); + } + + if (temp_pairs.size() == 0) { + // If there is no pair, the context can be deleted + iter = context_map.erase(iter); + continue; + } else { + iter++; + } + + pairs.clear(); + pairs.assign(temp_pairs.begin(), temp_pairs.end()); + } + + return SUCCESS; +} + +void TensorFlowModelParser::DeleteFuisonNodeDef() { + for (auto &fusion_nodedef : fusion_nodedef_list) { + GE_DELETE_NEW_SINGLE(fusion_nodedef); + } +} + +void TensorFlowModelParser::SaveEdgesControlInfo(const string &node_name, const bool control) { + if (control) { + // If the control attribute is true, save the control attribute to edges_control_map + edges_control_map[node_name].push_back(kControlSlot); + } +} + +void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info) { + auto iter = edges_control_map.find(info.node_name); + if (iter != edges_control_map.end()) { + // Delete the original fusion operator node information and add the fusion operator control edge information + edges_control_map.erase(iter); + edges_control_map[info.fusion_node_name].push_back(kControlSlot); + } +} + +bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const int32_t index) { + // If the node name is included, then confirm whether the index is the same + auto iter = edges_control_map.find(node_name); + if (iter != edges_control_map.end()) { + for (auto &i : iter->second) { + if (i == index) { + return true; + } + } + } + + return false; +} + +Status TensorFlowModelParser::ClearFusionOpError(const vector &op_node_name_list) { + for (const auto &name : op_node_name_list) { + ge::ScopeFusionOpInfo info; + if (IsFusionOpChild(name, &info)) { + const NodeDef *node = nodedef_map_[name]; + GE_CHECK_NOTNULL(node); + GE_RETURN_WITH_LOG_IF_ERROR(PreChecker::Instance().Clear(node, "fused and removed."), + "Clear pre-checking for node %s failed.", node->name().c_str()); + } + } + + return SUCCESS; +} + +Status TensorFlowModelParser::ToJson(const char *model_file, const char *json_file) { + GE_CHK_BOOL_RET_STATUS(model_file != nullptr, FAILED, "model_file is nullptr."); + GE_CHK_BOOL_RET_STATUS(json_file != nullptr, FAILED, "json_file is 
nullptr."); + domi::tensorflow::GraphDef graph_def; + nlohmann::json j; + + GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &graph_def), + "ReadProtoFromBinaryFile failed, file:%s.", model_file); + + Pb2Json::Message2Json(graph_def, kTfBlackFields, j, true); + return ModelSaver::SaveJsonToFile(json_file, j); +} + +Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) { + return SUCCESS; +} + +Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; } + +Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { + PARSER_TIMESTAMP_START(ParseProto); + GE_CHECK_NOTNULL(proto); + GE_CHECK_NOTNULL(graph); + ge::GetParserContext().train_flag = true; + + const domi::tensorflow::GraphDef *graph_def_in = reinterpret_cast(proto); + // Make a copy for operation without modifying the original graph def. + domi::tensorflow::GraphDef graph_def_operation = *graph_def_in; + domi::tensorflow::GraphDef *graph_def = &graph_def_operation; + GELOGI("[TF Parser] graph def version:%d", graph_def->version()); + + shared_ptr scope_graph = nullptr; + Status ret = ExcuteScopeFusionPasses(graph_def, scope_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[TF Parser] scope fusion failed."); + return ret; + } + GELOGD("[TF Parser] scope fusion success"); + + bool has_error = false; + + // Graphdef optimizes identity + PARSER_TIMESTAMP_START(GraphDefOptimize); + GE_RETURN_IF_ERROR(GraphDefOptimize(graph_def)); + PARSER_TIMESTAMP_END(GraphDefOptimize, "TensorFlowModelParser::GraphDefOptimize"); + GELOGD("[TF Parser] graph def optimize success"); + + // Optimization for TVM operator + PARSER_TIMESTAMP_START(OptimizeConstNodes4CustomOp); + GE_RETURN_IF_ERROR(OptimizeConstNodes4CustomOp(graph_def)); + PARSER_TIMESTAMP_END(OptimizeConstNodes4CustomOp, "TensorFlowModelParser::OptimizeConstNodes4CustomOp"); + GELOGD("[TF Parser] optimize const nodes for custom op success"); + + GE_RETURN_IF_ERROR(GetTensorflowGraphInOutMap(graph_def)); + GE_RETURN_IF_ERROR(RemoveIsolateNode(graph_def)); + + vector op_node_name_list; + bool isDatasetInit = false; + PARSER_TIMESTAMP_START(AddFmkNodeDefToMap); + for (int i = 0; i < graph_def->node_size(); i++) { + const domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i); + if (node_def->op() == ge::parser::IDENTITY && node_def->input_size() == 0) { + continue; + } + if (node_def->op() == ge::parser::SNAPSHOT && node_def->input_size() == 0) { + continue; + } + GE_IF_BOOL_EXEC(node_def->op() == "MakeIterator", isDatasetInit = true); + + // If it is a fusion operator, put nodedef in the fusion_op_nodedef_map_ + if (MaybeFusionOp(scope_graph, node_def)) { + GELOGI("Node: %s maybe a fusion op.", node_def->name().c_str()); + } + + // Do not exit immediately when there is an error, wait until all errors are collected before exiting + Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list); + GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); + } + PARSER_TIMESTAMP_END(AddFmkNodeDefToMap, "TensorFlowModelParser::AddFmkNodeDefToMap"); + GELOGI("[TF Parser] TF subgraph isDatasetInit: %d.", isDatasetInit); + + // Verify the validity of fusionop + GE_RETURN_IF_ERROR(CheckFusionOpValid()); + + // Build input and output relationships for all OP nodes + PARSER_TIMESTAMP_START(GetOpNodesContextFromGraph); + GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); + 
PARSER_TIMESTAMP_END(GetOpNodesContextFromGraph, "TensorFlowModelParser::GetOpNodesContextFromGraph"); + GELOGD("[TF Parser] Get op nodes context from graph success"); + + // Building input-output relationship between fusionop and common op + GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, *graph_def, op_node_name_list)); + + GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size()); + PARSER_TIMESTAMP_START(AddFmkNode); + // Loop analysis of op_nodes and map them to nodes in graph + ret = AddFmkNode(graph, scope_graph, op_node_name_list, isDatasetInit); + PARSER_TIMESTAMP_END(AddFmkNode, "TensorFlowModelParser::AddFmkNode"); + GE_CHK_STATUS_EXEC(ret, DeleteFuisonNodeDef(); return ret, "AddFmkNode failed"); + GELOGD("[TF Parser] Add framework node success"); + + ret = AddEdges(graph); + DeleteFuisonNodeDef(); + GE_CHK_STATUS_EXEC(ret, return ret, "AddEdges failed"); + GELOGD("[TF Parser] Add edges success"); + + PARSER_TIMESTAMP_START(RemoveIsolateNode); + // Delete isolated nodes + GE_RETURN_IF_ERROR(RemoveIsolateNode(graph)); + + PARSER_TIMESTAMP_END(RemoveIsolateNode, "TensorFlowModelParser::RemoveIsolateNode"); + PARSER_TIMESTAMP_START(TopologicalSorting); + GE_RETURN_IF_ERROR(graph->TopologicalSorting()); + PARSER_TIMESTAMP_END(TopologicalSorting, "TensorFlowModelParser::TopologicalSorting"); + + ge::parser::PassManager iterator_fusion_pass; + try { + (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", + new ge::IteratorFusionPass(ge::TENSORFLOW, false)); + } catch (std::bad_alloc &e) { + GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); + return INTERNAL_ERROR; + } + ret = iterator_fusion_pass.Run(graph); + if (ret != SUCCESS && ret != ge::NOT_CHANGED) { + GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); + return ret; + } + + has_error = has_error || PreChecker::Instance().HasError(); + if (has_error) { + GELOGE(PARAM_INVALID, "Precheck has errors."); + return PARAM_INVALID; + } + GELOGI("[TF Parser] Parse proto success."); + PARSER_TIMESTAMP_END(ParseProto, "TensorFlowModelParser::ParseProto"); + return SUCCESS; +} + +Status TensorFlowModelParser::ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, + domi::GetGraphCallback callback, ge::ComputeGraphPtr &root_graph) { + GE_CHECK_NOTNULL(root_proto); + GE_CHECK_NOTNULL(callback); + GE_CHECK_NOTNULL(root_graph); + + PARSER_TIMESTAMP_START(ParseProtoWithSubgraph); + std::vector> proto_holder; + std::deque tasks; + tasks.push_back({root_proto, "root", nullptr, "", root_graph}); + + while (!tasks.empty()) { + auto arg = tasks.front(); + tasks.pop_front(); + + if (arg.proto == nullptr) { + auto proto = callback(root_proto, arg.function_name); + if (proto == nullptr) { + GELOGE(FAILED, "Failed to get function by name %s", arg.function_name.c_str()); + return FAILED; + } + arg.proto = proto.get(); + proto_holder.emplace_back(std::move(proto)); + } + + GELOGI("Begin to parse graph %s", arg.function_name.c_str()); + auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); + auto ret = model_parser->ParseProto(arg.proto, arg.graph); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to parse graph %s, instance name %s", arg.function_name.c_str(), + arg.graph->GetName().c_str()); + return ret; + } + + ret = PostOpProcessForSubgraph(arg); + if (ret != SUCCESS) { + // the error log has been printed inner the function + return ret; + } + + ret = GenSubgraphParseTasks(arg.graph, tasks); + if (ret 
!= SUCCESS) { + GELOGE(ret, "Failed to gen tasks on graph %s for next iteration", arg.graph->GetName().c_str()); + return ret; + } + } + PARSER_TIMESTAMP_EVENT_END(ParseProtoWithSubgraph, "TensorFlowModelParser::ParseProtoWithSubgraph"); + return SUCCESS; +} + +// For the identity operator whose output is "_retval", optimize it. +Status TensorFlowModelParser::OptimizeIdentityByOutput(map &nodedef_map, + const string &curr_node_name, bool &clear_input_flag) { + auto context_iter = op_node_context_map_.find(curr_node_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((context_iter == op_node_context_map_.end()), return INTERNAL_ERROR, + "Can't find op node context."); + OpNodeContext op_node_context = context_iter->second; + + auto node_def_iter = nodedef_map.find(curr_node_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node_def_iter == nodedef_map.end()), return INTERNAL_ERROR, "Can't find nodedef"); + domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; + GE_CHECK_NOTNULL(curr_node_def); + bool has_out_retval = false; + // For the identity operator whose output is "_retval", optimize it + std::map>> output_map = op_node_context.output_map; + for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { + const string &output_node_name = output_iter->first; + domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; + GE_CHECK_NOTNULL(output_node_def); + if (output_node_def->op() == "_Retval") { + GELOGD("_Retval Identity need optimize."); + output_node_def->set_input(0, curr_node_def->input(0).c_str()); + has_out_retval = true; + GELOGD("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str()); + } + } + + // Deal with non _Retval output operator of Identity. + if (has_out_retval) { + for (auto output_iter = output_map.begin(); output_iter != output_map.end(); ++output_iter) { + const string &output_node_name = output_iter->first; + domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; + GE_CHECK_NOTNULL(output_node_def); + GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue); + for (int k = 0; k < output_node_def->input_size(); ++k) { + GE_IF_BOOL_EXEC( + output_node_def->input(k) == curr_node_name, output_node_def->set_input(k, curr_node_def->input(0).c_str()); + GELOGD("%s op set input(%d):%s.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str());) + } + } + clear_input_flag = true; + } + return SUCCESS; +} + +Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, + map &nodedef_map, + const vector &nodedef_to_optimize) { + GE_CHECK_NOTNULL(graph_def); + if (!nodedef_to_optimize.empty()) { + // Building input and input relationships for all OP nodes + GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); + } else { + return SUCCESS; + } + for (auto &curr_node_def : nodedef_to_optimize) { + GE_CHECK_NOTNULL(curr_node_def); + bool clear_input_flag = false; + const string &curr_node_name = curr_node_def->name(); + GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag)); + if (clear_input_flag) { + curr_node_def->clear_input(); + } + } + GELOGI("GraphDefOptimizeIdentity success."); + return SUCCESS; +} + +Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, + map &nodedef_map, + const std::pair &input_data, + const std::vector &control_list) { + GE_CHECK_NOTNULL(curr_mode_def); + if (curr_mode_def == nullptr) { + GELOGE(FAILED, "input param is 
nullptr."); + return PARAM_INVALID; + } + string curr_node_name = curr_mode_def->name(); + auto context_iter = op_node_context_map_.find(curr_node_name); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((context_iter == op_node_context_map_.end()), return INTERNAL_ERROR, + "Can't find op node context."); + OpNodeContext op_node_context = context_iter->second; + + std::map>> output_map = op_node_context.output_map; + for (auto &output_iter : output_map) { + const string &output_node_name = output_iter.first; + domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; + GE_CHECK_NOTNULL(output_node_def); + auto inputs = output_node_def->mutable_input(); + for (auto &input : *inputs) { + string node_name; + bool is_control = false; + if (CheckInputNodeName(input, &node_name, nullptr, &is_control) != SUCCESS) { + GELOGE(FAILED, "parse node input info failed, node %s, input %s.", output_node_def->name().c_str(), + input.c_str()); + return FAILED; + } + if (node_name == curr_node_name) { + if (is_control) { + input = "^" + input_data.first; + } else if (input_data.second == 0) { + input = input_data.first; + } else { + input = input_data.first + ":" + std::to_string(input_data.second); + } + GELOGD("Optimize Snapshot node, dest:%s, set input:%s.", output_node_name.c_str(), input.c_str()); + + for (auto &item : control_list) { + bool is_exist_input = false; + for (auto &tmp_input : output_node_def->input()) { + string tmp_node_name; + if (CheckInputNodeName(tmp_input, &tmp_node_name, nullptr, nullptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "parse node input info failed, node %s, input %s.", + output_node_def->name().c_str(), tmp_input.c_str()); + return FAILED; + } + if (tmp_node_name == item) { + is_exist_input = true; + break; + } + } + if (!is_exist_input) { + output_node_def->add_input("^" + item); + GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), item.c_str()); + } + } + break; + } + } + } + // Clear the input of snapshot and become an isolated node + curr_mode_def->clear_input(); + return SUCCESS; +} + +Status TensorFlowModelParser::GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, + map &nodedef_map, + const vector &nodedef_to_optimize) { + GE_CHECK_NOTNULL(graph_def); + if (!nodedef_to_optimize.empty()) { + // Building input and input relationships for all OP nodes + GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); + GELOGD("Optimize snapshot num:%zu.", nodedef_to_optimize.size()); + } else { + return SUCCESS; + } + + for (auto &curr_node_def : nodedef_to_optimize) { + GE_CHECK_NOTNULL(curr_node_def); + std::pair input_data; // src node name, src index + vector control_list; + uint32_t data_input_cnt = 0; + for (auto &input : curr_node_def->input()) { + string node_name; + int input_index = 0; + bool is_control = false; + if (CheckInputNodeName(input, &node_name, &input_index, &is_control) != SUCCESS) { + GELOGE(FAILED, "parse SnapShot input info failed, node %s, input %s.", curr_node_def->name().c_str(), + input.c_str()); + return FAILED; + } + if (is_control) { + control_list.push_back(node_name); + } else { + data_input_cnt++; + input_data = std::make_pair(node_name, input_index); + } + } + if (data_input_cnt != 1) { + GELOGE(FAILED, "%s op data input size %u invalid", curr_node_def->name().c_str(), data_input_cnt); + return FAILED; + } + // Optimize Snapshot Node + GE_CHK_STATUS_RET(OptimizeSnapShot(curr_node_def, nodedef_map, input_data, control_list)); + } + GELOGI("GraphDefOptimizeSnapShot success."); + return 
SUCCESS; +} + +void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *nodeCurrent, + bool &clearInputFlag) { + // Internal call to ensure that the parameter is not empty. + GELOGI("DestroyTemporaryVariable optimizing."); + for (int w = 0; w < graph_def->node_size(); w++) { + domi::tensorflow::NodeDef *nodeDst = graph_def->mutable_node(w); + GE_IF_BOOL_EXEC(nodeDst->name() == nodeCurrent->name(), continue); + for (int k = 0; k < nodeDst->input_size(); k++) { + string nodeDstInputName = nodeDst->input(k); + string nodeDstInputNameTmp; + bool isControl = false; + if (CheckInputNodeName(nodeDstInputName, &nodeDstInputNameTmp, nullptr, &isControl) != SUCCESS) { + GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeDstInputName.c_str()); + return; + } + if (nodeDstInputNameTmp == nodeCurrent->name()) { + GELOGI("current node name is %s ", nodeCurrent->name().c_str()); + clearInputFlag = true; + if (isControl) { + string nodeCurrentName = nodeCurrent->input(0); + string nodeCurrentNameTmp; + if (CheckInputNodeName(nodeCurrentName, &nodeCurrentNameTmp, nullptr, nullptr) != SUCCESS) { + GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeCurrentName.c_str()); + return; + } + nodeCurrentNameTmp = "^" + nodeCurrentNameTmp; + GELOGI("set nodeCurrentNameTmp: %s", nodeCurrentNameTmp.c_str()); + nodeDst->set_input(k, nodeCurrentNameTmp); + } else { + nodeDst->set_input(k, nodeCurrent->input(0).c_str()); + GELOGD("%s op set input:%s.", nodeDst->name().c_str(), nodeCurrent->input(0).c_str()); + } + // DestroyTemporaryVariable node have only one input and one output. + // If the number of inputs is greater than 1, all subsequent inputs are + // control edge inputs. Therefore, after deleting DestroyTemporaryVariable, + // these control edge inputs can be directly connected to nodeDst. 
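+        // For example (illustrative): with inputs {"tmp_var", "^ctrl"}, input(0) was forwarded to nodeDst above,
+        // and the remaining control input "^ctrl" is re-attached to nodeDst below.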
+ if (nodeCurrent->input_size() > 1) { + for (int i = 1; i < nodeCurrent->input_size(); ++i) { + nodeDst->add_input(nodeCurrent->input(i)); + } + } + GELOGI("Optimize DestroyTemporaryVariable successful."); + } + } + } +} + +Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *nodeCurrent) { + if (graph_def == nullptr || nodeCurrent == nullptr) { + GELOGE(FAILED, "input param is nullptr."); + return FAILED; + } + if (nodeCurrent->op() != ge::parser::DESTROYTEMPORARYVARIABLE) { + return SUCCESS; + } + + GELOGI("Optimize DestroyTemporaryVariable, node name is :%s.", nodeCurrent->name().c_str()); + bool clearInputFlag = false; + + google::protobuf::Map *attr_map_destroy = nodeCurrent->mutable_attr(); + domi::tensorflow::AttrValue var_name_attr_destroy = (*attr_map_destroy)[ge::VAR_ATTR_NAME]; + + for (int j = 0; j < graph_def->node_size(); j++) { + domi::tensorflow::NodeDef *nodeTmpVar = graph_def->mutable_node(j); + GE_IF_BOOL_EXEC(nodeTmpVar->op() != ge::parser::TEMPORARYVARIABLE, continue); + + google::protobuf::Map *attr_map_tmp = nodeTmpVar->mutable_attr(); + domi::tensorflow::AttrValue var_name_attr_tmp = (*attr_map_tmp)[ge::VAR_ATTR_NAME]; + + if (var_name_attr_destroy.s() != var_name_attr_tmp.s()) { + continue; + } + + // Optimize destroytemporaryvariable operator + OptimizeDestroyTemporaryVariable(graph_def, nodeCurrent, clearInputFlag); + + if (clearInputFlag) { + nodeCurrent->clear_input(); // Clear the destroytemporaryvariable input to become an isolated node + break; + } + } + if (!clearInputFlag) { + GELOGE(INTERNAL_ERROR, "Optimize DestroyTemporaryVariable failed, node name is :%s.", nodeCurrent->name().c_str()); + return FAILED; + } + + return SUCCESS; +} + +struct DelTransposeInfo { + domi::tensorflow::NodeDef *node_def; // transpose + domi::tensorflow::NodeDef *nextNodeDef; // transpose --> [next] + int inputIdx; +}; + +Status GetTransposeInfo(GraphDef *graph_def, std::map &softmaxInfo, + std::map &transposeInfo) { + GE_CHECK_NOTNULL(graph_def); + for (int i = 0; i < graph_def->node_size(); ++i) { + auto node_def = graph_def->mutable_node(i); + if (node_def->op() == ge::parser::TRANSPOSE) { + DelTransposeInfo transpose; + transpose.node_def = node_def; + transposeInfo.insert(std::make_pair(node_def->name(), transpose)); + } else if (node_def->op() == ge::parser::SOFTMAX) { + softmaxInfo.insert(std::make_pair(node_def->name(), node_def->input(0))); + GELOGI("softmax name:%s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str()); + } + } + + for (auto &itTranspose : transposeInfo) { + for (int j = 0; j < graph_def->node_size(); ++j) { + auto nextNodeDef = graph_def->mutable_node(j); + bool bFind = false; + for (int k = 0; k < nextNodeDef->input_size(); ++k) { + if (nextNodeDef->input(k) == itTranspose.first) { + itTranspose.second.nextNodeDef = nextNodeDef; + itTranspose.second.inputIdx = k; + GELOGI("transpose info name:%s, next name:%s, idx:%d", itTranspose.second.node_def->name().c_str(), + nextNodeDef->name().c_str(), k); + bFind = true; + break; + } + } + if (bFind) { + break; + } + } + } + return SUCCESS; +} + +Status EraseTransposeNode(std::map &softmaxInfo, + std::map &transposeInfo) { + auto itTranspose = transposeInfo.begin(); + for (; itTranspose != transposeInfo.end();) { + // transpose --> softmax + bool bErase = true; + if (softmaxInfo.find(itTranspose->second.node_def->input(0)) != softmaxInfo.end() || + softmaxInfo.find(itTranspose->second.nextNodeDef->name()) 
!= softmaxInfo.end()) { + bErase = false; + } + + if (bErase) { + GELOGI("erase node name:%s, input(0):%s", itTranspose->first.c_str(), + itTranspose->second.node_def->input(0).c_str()); + itTranspose = transposeInfo.erase(itTranspose); + } else { + itTranspose++; + } + } + + if ((softmaxInfo.size() <= SIZE_MAX / kSoftmaxMultiple) && + (softmaxInfo.size() * kSoftmaxMultiple != transposeInfo.size())) { + GELOGW("softmax size[%zu], transpose size[%zu]", softmaxInfo.size(), transposeInfo.size()); + return FAILED; + } + + return SUCCESS; +} + +void TensorFlowModelParser::OptimizeTranspose(std::map &transposeInfo) { + for (auto &it : transposeInfo) { + auto transpose = it.second; + transpose.nextNodeDef->set_input(transpose.inputIdx, transpose.node_def->input(kTransposeInputIdx)); + transpose.node_def->clear_input(); + } +} + +void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *graph_def) { + // The caller guarantees that the pointer is not null + for (int i = 0; i < graph_def->node_size(); ++i) { + auto node_def = graph_def->mutable_node(i); + if (node_def->op() == ge::parser::SOFTMAX) { + domi::tensorflow::AttrValue attr_value; + attr_value.set_i(1); + ge::TensorFlowUtil::AddNodeAttr("axis", attr_value, node_def); + GELOGI("SoftmaxAddAttr, name: %s, input name:%s", node_def->name().c_str(), node_def->input(0).c_str()); + } + } +} + +Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph_def) { + GE_CHECK_NOTNULL(graph_def); + map nodedef_map; + vector op_node_name_list; + // Save Identity and ReadVariableOp + vector identity_to_optimize; + // Save Snapshot + vector snapshot_to_optimize; + + for (int i = 0; i < graph_def->node_size(); i++) { + // mutable_node return vale is not empty + domi::tensorflow::NodeDef *node_def = graph_def->mutable_node(i); + const string &node_name = node_def->name(); + Status ret = AddFmkNodeDefToMap(*graph_def, node_def, op_node_name_list); + GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); + if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { + identity_to_optimize.push_back(node_def); + } else if (node_def->op() == ge::parser::SNAPSHOT) { + snapshot_to_optimize.push_back(node_def); + } + nodedef_map[node_name] = node_def; + } + + // Optimize for Identity/ReadVariableOp + GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize)); + // Optimize for Snapshot + GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); + + for (int i = 0; i < graph_def->node_size(); i++) { + domi::tensorflow::NodeDef *nodeCurrent = graph_def->mutable_node(i); + GE_CHK_STATUS_RET(GraphDefOptimizeDestroyTemporaryVariable(graph_def, nodeCurrent)); + } + + // These member variables will be rebuilt later and need to be cleared here. 
+ nodedef_map_.clear(); + op_node_context_map_.clear(); + return SUCCESS; +} + +Status TensorFlowModelParser::RemoveIsolateNode(ge::ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + + auto nodes = graph->GetDirectNode(); + for (auto &n : nodes) { + // get front 4 char + if (n->GetName().substr(0, 4) == "dpop") { + continue; + } + if ((n->GetType() == ge::parser::DATA) || + (ge::GetParserContext().out_nodes_map.find(n->GetName()) != ge::GetParserContext().out_nodes_map.end())) { + GELOGI("Can not remove op [%s] because it is data or out node.", n->GetName().c_str()); + continue; + } + GE_IF_BOOL_EXEC((((n->GetInAllNodes().size() == 0) && (n->GetOutDataNodes().size() == 0)) || + ((n->GetType() == ge::parser::CONSTANTOP || n->GetType() == ge::parser::CONSTANT) && + (n->GetOutDataNodes().size() == 0))), + GE_CHK_STATUS_RET(ge::GraphUtils::IsolateNode(n, {}), "Isolate removed node: %s, type: %s failed", + n->GetName().c_str(), n->GetType().c_str()); + GE_CHK_STATUS_RET(ge::GraphUtils::RemoveNodeWithoutRelink(graph, n), + "Remove node: %s, type: %s without relink failed", n->GetName().c_str(), + n->GetType().c_str());); + } + return SUCCESS; +} + +// The format specified by the command line argument is preferred, +// if not specified, use InferInputFormats to infer, +// and if the inference fails, the default NHWC format is used. +domiTensorFormat_t TensorFlowModelParser::InferInputFormats() { + GE_IF_BOOL_EXEC(ge::GetParserContext().format != DOMI_TENSOR_RESERVED, return ge::GetParserContext().format); + + domiTensorFormat_t global_input_format = DOMI_TENSOR_RESERVED; + set visited_node; + for (auto &node_item : nodedef_map_) { + // Infer format for data node and save it to ge::GetParserContext().format. + domiTensorFormat_t format = DOMI_TENSOR_RESERVED; + const NodeDef *node = node_item.second; + if (node == nullptr) { + return format; + } + auto it = tensorflow_op_map.find(node->op()); + if (it != tensorflow_op_map.end() && it->second == ge::parser::DATA) { + GE_IF_BOOL_EXEC(GetNodeFormat(node, NO_TRANSPOSE, format, visited_node) != SUCCESS, + GELOGW("Cannot infer input format, the NHWC format is used by default, and you can also " + "specify format by command line arguments."); + return domi::DOMI_TENSOR_NHWC); + + GE_IF_BOOL_EXEC(global_input_format == DOMI_TENSOR_RESERVED, global_input_format = format); + + GE_IF_BOOL_EXEC( + format != DOMI_TENSOR_RESERVED && format != global_input_format, + GELOGW("Multiple data ops with different formats are not supported, " + "the NHWC format is used by default, and you can also specify format by command line arguments."); + return domi::DOMI_TENSOR_NHWC); + } + } + + return global_input_format == DOMI_TENSOR_RESERVED ? domi::DOMI_TENSOR_NHWC : global_input_format; +} + +Status TensorFlowModelParser::GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format, + set &visited_node) { + GE_CHECK_NOTNULL(node); + // Avoid repeated visits. + GE_IF_BOOL_EXEC(visited_node.find(node) != visited_node.end(), return SUCCESS); + visited_node.emplace(node); + + GE_IF_BOOL_EXEC(node->op() == TENSORFLOWF_NODE_OP_SWITCH || node->op() == TENSORFLOWF_NODE_OP_MERGE, return SUCCESS); + + // If node has a data_format attribute, format is set according to data_format. 
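+  // Illustrative reading of the checks below (derived from the logic that follows): pred_transpose records
+  // whether the path to this node goes through a format-converting Transpose. If data_format matches the
+  // transpose direction (e.g. data_format NCHW reached through a TO_NCHW transpose), the data ahead of that
+  // transpose is in the opposite layout, so the inferred input format is flipped; if data_format is opposite
+  // to the transpose direction, the information conflicts and format inference fails.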
+ domi::tensorflow::AttrValue attr; + if (ge::TensorFlowUtil::FindAttrValue(node, TENSORFLOW_ATTR_DATA_FORMAT, attr) && node->op() != ge::parser::BIASADD) { + GE_RETURN_IF_ERROR(ge::TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING)); + + format = (attr.s() == TENSORFLOWF_TENSOR_NCHW) ? domi::DOMI_TENSOR_NCHW : domi::DOMI_TENSOR_NHWC; + + GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NCHW, format = domi::DOMI_TENSOR_NHWC); + GE_IF_BOOL_EXEC(format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NHWC, format = domi::DOMI_TENSOR_NCHW); + GE_IF_BOOL_EXEC((format == domi::DOMI_TENSOR_NCHW && pred_transpose == TO_NHWC) || + (format == domi::DOMI_TENSOR_NHWC && pred_transpose == TO_NCHW), + GELOGI("Format conflicts with transpose."); + return FAILED); + + return SUCCESS; + } + + TfTranspose transpose; + GE_RETURN_IF_ERROR(GetFormatTranspose(node, transpose)); + GE_IF_BOOL_EXEC(pred_transpose == transpose && pred_transpose != NO_TRANSPOSE, + GELOGI("Multiple transpose conflicts."); + return FAILED); + + // If node does not have the data_format attribute, format is set according to the output node. + string node_name = node->name(); + GE_IF_BOOL_EXEC(op_node_context_map_.find(node_name) == op_node_context_map_.end(), + GELOGI("node %s not found in op_node_context_map_", node_name.c_str()); + return FAILED); + + domiTensorFormat_t inferred_format = DOMI_TENSOR_RESERVED; + const OpNodeContext &node_ctx = op_node_context_map_.at(node_name); + + for (const auto &output_item : node_ctx.output_map) { + auto node_iter = nodedef_map_.find(output_item.first); + GE_IF_BOOL_EXEC(node_iter == nodedef_map_.end(), + GELOGI("node %s not found in nodedef_map_", output_item.first.c_str()); + return FAILED); + + const NodeDef *output_node = node_iter->second; + GE_CHECK_NOTNULL(output_node); + domiTensorFormat_t output_format = DOMI_TENSOR_RESERVED; + GE_RETURN_IF_ERROR(GetNodeFormat(output_node, transpose, output_format, visited_node)); + + GE_IF_BOOL_EXEC(output_format != DOMI_TENSOR_RESERVED && inferred_format != DOMI_TENSOR_RESERVED && + output_format != inferred_format, + GELOGI("Multiple output formats conflict."); + return FAILED); + + inferred_format = output_format; + } + + format = inferred_format; + + return SUCCESS; +} + +Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) { + GE_CHECK_NOTNULL(transpose_node); + transpose_direc = NO_TRANSPOSE; + + GE_IF_BOOL_EXEC(transpose_node->op() != TENSORFLOWF_NODE_OP_TRANSPOSE, return SUCCESS); + + GE_IF_BOOL_EXEC(transpose_node->input_size() != kInputNumInt, GELOGI("Input size of transpose is not 2."); + return FAILED); + + string perm_node_name = transpose_node->input(1); + auto it = nodedef_map_.find(perm_node_name); + GE_IF_BOOL_EXEC(it == nodedef_map_.end(), GELOGI("Node %s not found in nodedef_map_.", perm_node_name.c_str()); + return FAILED); + + const NodeDef *perm_node = it->second; + GE_CHECK_NOTNULL(perm_node); + domi::tensorflow::AttrValue attr_value; + GE_IF_BOOL_EXEC(perm_node->op() != TENSORFLOWF_NODE_OP_CONST, GELOGI("Input node of transpose is not const."); + return FAILED); + + GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_DTYPE, attr_value), return FAILED); + GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TYPE) != SUCCESS, + return FAILED); + domi::tensorflow::DataType type = attr_value.type(); + GE_IF_BOOL_EXEC(type != domi::tensorflow::DT_INT32 && type != 
domi::tensorflow::DT_INT64, return FAILED); + + GE_IF_BOOL_EXEC(!ge::TensorFlowUtil::FindAttrValue(perm_node, TENSORFLOW_ATTR_VALUE, attr_value), return FAILED); + GE_IF_BOOL_EXEC(ge::TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TENSOR) != SUCCESS, + return FAILED); + const TensorProto &tensor = attr_value.tensor(); + const TensorShapeProto &tensor_shape = tensor.tensor_shape(); + GE_IF_BOOL_EXEC(tensor_shape.dim_size() != 1 || tensor_shape.dim(0).size() != ge::DIM_DEFAULT_SIZE, return SUCCESS); + GE_IF_BOOL_EXEC(tensor.tensor_content().empty(), return SUCCESS); + + vector perm_value; + + GE_IF_BOOL_EXEC( + type == domi::tensorflow::DT_INT32, + const int32_t *data = reinterpret_cast(tensor.tensor_content().data()); + for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); + + GE_IF_BOOL_EXEC( + type == domi::tensorflow::DT_INT64, + const int64_t *data = reinterpret_cast(tensor.tensor_content().data()); + for (int i = 0; i < ge::DIM_DEFAULT_SIZE; i++) { perm_value.push_back(data[i]); }); + + // 0, 1, 2, 3 present dim num. + vector perm_to_nchw = {0, 3, 1, 2}; + vector perm_to_nhwc = {0, 2, 3, 1}; + GE_IF_BOOL_EXEC(perm_value == perm_to_nchw, transpose_direc = TO_NCHW); + GE_IF_BOOL_EXEC(perm_value == perm_to_nhwc, transpose_direc = TO_NHWC); + + return SUCCESS; +} + +Status TensorFlowModelParser::TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, + domi::tensorflow::GraphDef *output_graph_def) { + GE_CHECK_NOTNULL(output_graph_def); + if (!ge::GetParserContext().input_dims.empty() && ge::GetParserContext().out_nodes_map.empty()) { + return TrimGraphByInput(input_graph_def, output_graph_def); + } else { + return TrimGraphByOutput(input_graph_def, output_graph_def); + } +} +Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, + domi::tensorflow::GraphDef *output_graph_def) { + // The caller guarantees that the pointer is not null + std::set delete_nodes; + std::set input_nodes; + for (auto &iter : ge::GetParserContext().input_dims) { + input_nodes.insert(iter.first); + } + std::map node_lookup; + for (const NodeDef &node : input_graph_def.node()) { + node_lookup[node.name()] = &node; + } + std::vector current_inputs; + for (auto &iter : ge::GetParserContext().input_dims) { + current_inputs.push_back(iter.first); + } + while (!current_inputs.empty()) { + std::set next_inputs; + for (const string ¤t_input : current_inputs) { + delete_nodes.insert(current_input); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), ErrorManager::GetInstance().ATCReportErrMessage( + "E12012", {"opname"}, {current_input}); + return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); + const NodeDef *current_node = node_lookup[current_input]; + GE_CHECK_NOTNULL(current_node); + for (const string &input_name : current_node->input()) { + string input_node_name = NodeNameFromInput(input_name); + if (!delete_nodes.count(input_node_name)) { + next_inputs.insert(input_node_name); + } + } + } + current_inputs = std::vector(next_inputs.begin(), next_inputs.end()); + } + domi::tensorflow::GraphDef filtered_graph_def; + filtered_graph_def.mutable_node()->Clear(); + for (const NodeDef &node : input_graph_def.node()) { + if (input_nodes.count(node.name())) { + *(filtered_graph_def.mutable_node()->Add()) = node; + } + if (!delete_nodes.count(node.name())) { + *(filtered_graph_def.mutable_node()->Add()) = node; + } + } + output_graph_def->Clear(); + for (const NodeDef &node : 
filtered_graph_def.node()) { + if (input_nodes.count(node.name())) { + NodeDef placeholder_node; + placeholder_node = node; + placeholder_node.clear_input(); + GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); + domi::tensorflow::AttrValue attr_value; + TensorShapeProto *data_shape = attr_value.mutable_shape(); + GE_CHECK_NOTNULL(data_shape); + const ge::ParserContext &ctx = ge::GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + std::vector designated_dims = input_dims.at(node.name()); + for (int32_t i = 0; i < (int32_t)designated_dims.size(); i++) { + data_shape->add_dim()->set_size(designated_dims[i]); + } + google::protobuf::Map *attr = placeholder_node.mutable_attr(); + (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value; + GE_CHECK_NOTNULL(output_graph_def->mutable_node()); + *(output_graph_def->mutable_node()->Add()) = placeholder_node; + } else { + GE_CHECK_NOTNULL(output_graph_def->mutable_node()); + *(output_graph_def->mutable_node()->Add()) = node; + } + } + return SUCCESS; +} +Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, + domi::tensorflow::GraphDef *output_graph_def) { + // The caller guarantees that the pointer is not null + std::set required_nodes; + std::set input_nodes; + for (auto &iter : ge::GetParserContext().input_dims) { + required_nodes.insert(iter.first); + input_nodes.insert(iter.first); + } + for (auto &iter : ge::GetParserContext().out_nodes_map) { + required_nodes.insert(iter.first); + } + std::map node_lookup; + for (const NodeDef &node : input_graph_def.node()) { + node_lookup[node.name()] = &node; + } + std::vector current_inputs; + for (auto &iter : ge::GetParserContext().out_nodes_map) { + current_inputs.push_back(iter.first); + } + while (!current_inputs.empty()) { + std::set next_inputs; + for (const string ¤t_input : current_inputs) { + required_nodes.insert(current_input); + GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), ErrorManager::GetInstance().ATCReportErrMessage( + "E12012", {"opname"}, {current_input}); + return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); + const NodeDef *current_node = node_lookup[current_input]; + GE_CHECK_NOTNULL(current_node); + for (const string &input_name : current_node->input()) { + string input_node_name = NodeNameFromInput(input_name); + if (!required_nodes.count(input_node_name)) { + next_inputs.insert(input_node_name); + } + } + } + current_inputs = std::vector(next_inputs.begin(), next_inputs.end()); + } + domi::tensorflow::GraphDef filtered_graph_def; + filtered_graph_def.mutable_node()->Clear(); + for (const NodeDef &node : input_graph_def.node()) { + if (required_nodes.count(node.name())) { + *(filtered_graph_def.mutable_node()->Add()) = node; + } + } + output_graph_def->Clear(); + for (const NodeDef &node : filtered_graph_def.node()) { + if (input_nodes.count(node.name())) { + NodeDef placeholder_node; + placeholder_node = node; + placeholder_node.clear_input(); + GE_IF_BOOL_EXEC(node.op() != "Placeholder", placeholder_node.set_op("Placeholder")); + domi::tensorflow::AttrValue attr_value; + TensorShapeProto *data_shape = attr_value.mutable_shape(); + GE_CHECK_NOTNULL(data_shape); + const ge::ParserContext &ctx = ge::GetParserContext(); + std::unordered_map> input_dims = ctx.input_dims; + std::vector designated_dims = input_dims.at(node.name()); + for (int32_t i = 0; i < (int32_t)designated_dims.size(); i++) { 
+ data_shape->add_dim()->set_size(designated_dims[i]); + } + google::protobuf::Map *attr = placeholder_node.mutable_attr(); + (*attr)[TENSORFLOW_ATTR_SHAPE] = attr_value; + GE_CHECK_NOTNULL(output_graph_def->mutable_node()); + *(output_graph_def->mutable_node()->Add()) = placeholder_node; + } else { + GE_CHECK_NOTNULL(output_graph_def->mutable_node()); + *(output_graph_def->mutable_node()->Add()) = node; + } + } + return SUCCESS; +} +string TensorFlowModelParser::NodeNameFromInput(const string &input_name) { + string prefix; + string node_name; + string suffix; + std::vector input_parts = ge::StringUtils::Split(input_name, ':'); + suffix = (input_parts.size() < kInputNumUint) ? "" : (":" + input_parts[1]); + string tmp_name = input_parts[0]; + GE_IF_BOOL_EXEC(input_parts[0].find("^") == 0, tmp_name = tmp_name.substr(1, tmp_name.length() - 1)); + node_name = tmp_name; + return node_name; +} + +Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr &op_parser, + const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node) { + GE_CHECK_NOTNULL(node_def); + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(op_parser); + + GELOGI("FusionNodeParseParams:node name:%s.", node_def->name().c_str()); + + // The fusion operator deals with parseparams separately + shared_ptr tensorflow_fusion_op_parser = + std::dynamic_pointer_cast(op_parser); + GE_IF_BOOL_EXEC(tensorflow_fusion_op_parser == nullptr, + GELOGE(FAILED, "node :%s can not get fusion parser, please check!", node_def->name().c_str()); + return INTERNAL_ERROR); + + // Find all children of the fusion operator + auto iter = fusion_op_nodedef_map_.find(node_def->name()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == fusion_op_nodedef_map_.end(), return INTERNAL_ERROR, + "FusionOp node %s has no children node!", node_def->name().c_str()); + + (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); + vector node_def_v = iter->second; + domi::FusionParseParamByOpFunc parse_param_func = + domi::OpRegistry::Instance()->GetFusionParseParamByOpFunc(node->GetType(), node_def->op()); + Status status = FAILED; + if (parse_param_func == nullptr) { + status = tensorflow_fusion_op_parser->ParseParams(node_def_v, node); + GE_CHK_STATUS_EXEC(status, return status, "Parse Params for fusionop node %s failed", node_def->name().c_str()); + } else { + vector op_src_vec; + for (const auto &node_def_src : node_def_v) { + ge::Operator op_src(node_def_src->name(), node_def_src->op()); + status = domi::AutoMappingFn(node_def_src, op_src); + if (status != SUCCESS) { + GELOGE(status, "Node[%s] auto mapping failed", node_def_src->name().c_str()); + return status; + } + auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_src); + GE_CHECK_NOTNULL(op_desc); + for (int32_t i = 0; i < node_def_src->input_size(); i++) { + ge::GeTensorDesc tensor_desc; + tensor_desc.SetName(node_def_src->input(i)); + if (op_desc->AddInputDesc(tensor_desc) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Op [%s] type[%s] add input(%d) tensor failed.", op_desc->GetName().c_str(), + op_desc->GetType().c_str(), i); + return FAILED; + } + } + op_src_vec.push_back(op_src); + } + shared_ptr tf_custom_fusion_op_paser = + std::dynamic_pointer_cast(tensorflow_fusion_op_parser); + status = tf_custom_fusion_op_paser->ParseParams(op_src_vec, node); + if (status != SUCCESS) { + GELOGE(status, "Parse params for fusionop node %s failed", node_def->name().c_str()); + return status; + } + } + + return SUCCESS; +} + +/** + * @ingroup domi_omg + * @brief Optimizing const nodes 
for custom operators
+ * @param [in] graph_def graph object
+ * @return true optimize successfully
+ * @return false optimize failed
+ *
+ */
+Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) {
+  GE_CHECK_NOTNULL(graph_def);
+  // 1. find all the nodes in the graph and save them to all_nodedef_map
+  map<string, NodeDef *> all_nodedef_map;
+  int graph_node_size = graph_def->node_size();
+  for (int i = 0; i != graph_node_size; ++i) {
+    // mutable_node return value is not empty
+    domi::tensorflow::NodeDef *current_node = graph_def->mutable_node(i);
+    string node_name = current_node->name();
+    all_nodedef_map[node_name] = current_node;
+  }
+  GE_CHK_BOOL_EXEC_INFO(!all_nodedef_map.empty(), return SUCCESS, "all_nodedef_map is empty");
+
+  // 2. move input to attr.
+  for (auto &it_node_map : all_nodedef_map) {
+    domi::tensorflow::NodeDef *current_node = it_node_map.second;
+    GE_CHECK_NOTNULL(current_node);
+    string current_op_name = current_node->op();
+
+    // 2.1 check whether the current op is registered for moving input to attr.
+    const std::vector<domi::RemoveInputConfigure> &move_input_vec =
+        domi::OpRegistry::Instance()->GetRemoveInputConfigure(current_op_name);
+    GE_CHK_BOOL_EXEC_NOLOG(!move_input_vec.empty(), continue);
+    GELOGD("Current op %s is registered for remove input.", current_op_name.c_str());
+
+    // 2.2 check whether the current op is a TVM op.
+    GE_CHK_BOOL_EXEC_INFO(
+        domi::OpRegistry::Instance()->GetImplyTypeByOriOpType(current_op_name) == domi::ImplyType::TVM, continue,
+        "op %s is not TVM op", current_op_name.c_str());
+    GELOGD("handle tvm op %s", current_op_name.c_str());
+
+    // 2.3 copy input to attr
+    set<uint32_t> unused_inputs;
+    for (const auto &it : move_input_vec) {
+      uint32_t move_index;
+      if (it.inputIdx >= 0) {
+        move_index = it.inputIdx;
+      } else {
+        GE_IF_BOOL_EXEC(
+            -it.inputIdx > current_node->input_size(),
+            ErrorManager::GetInstance().ATCReportErrMessage(
+                "E12004", {"opname", "inputIdx", "inputsize"},
+                {current_op_name, std::to_string(-it.inputIdx), std::to_string(current_node->input_size())});
+            GELOGE(INTERNAL_ERROR,
+                   "Op[%s] register failed, inputIdx[-%d] should not be greater than inputsize[%d] when inputIdx < 0.",
+                   current_op_name.c_str(), it.inputIdx, current_node->input_size());
+            return PARAM_INVALID);
+        move_index = current_node->input_size() + it.inputIdx;
+      }
+      // For an isolated node (seen in the DeepLab V3 network) the inputs may already have been cleared,
+      // so skip it to avoid indexing past the current protobuf input size.
+ GE_IF_BOOL_EXEC(current_node->input_size() == 0, GELOGI("Input size is 0, already optimized"); continue); + + if (it.moveType == domi::OMG_REMOVE_TYPE_WITH_COND) { + domi::tensorflow::AttrValue attr_value; + GE_IF_BOOL_EXEC(!(ge::TensorFlowUtil::FindAttrValue(current_node, it.attrName, attr_value)), + ErrorManager::GetInstance().ATCReportErrMessage("E12005", {"attrname"}, {current_op_name}); + GELOGE(INTERNAL_ERROR, "AttrName[%s] has no value!", it.attrName.c_str()); + return PARAM_INVALID); + GE_IF_BOOL_EXEC(attr_value.b() == it.attrValue, unused_inputs.insert(move_index)); + } else if (it.moveType == domi::OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE && it.originalType == current_op_name) { + GELOGD("Input %s:%d will be removed.", current_op_name.c_str(), move_index); + unused_inputs.insert(move_index); + } else if (it.moveType == domi::OMG_INPUT_REORDER) { + auto inputs = current_node->input(); + if (static_cast(inputs.size()) != it.input_order.size()) { + GELOGE(INTERNAL_ERROR, "Size of input is mismatched, new order size is %zu, input size is %d.", + it.input_order.size(), inputs.size()); + return INTERNAL_ERROR; + } + for (size_t i = 0; i < it.input_order.size(); ++i) { + int new_index = it.input_order[i]; + if (new_index < 0 || new_index >= inputs.size()) { + GELOGE(INTERNAL_ERROR, "New order of %s has invalid index %d.", it_node_map.first.c_str(), new_index); + return INTERNAL_ERROR; + } + current_node->set_input(i, inputs[new_index]); + } + GELOGI("The input sequence of the node has been rearranged, node name:%s.", it_node_map.first.c_str()); + } + } + + // 2.4 remove the input const nodes + Status ret = RemoveInputs(current_node, unused_inputs); + if (ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name}); + GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str()); + return ret; + } + } + + return SUCCESS; +} + +/** + * @ingroup domi_omg + * @brief Delete input from nodedef + * @param [in] node_def Nodedef object + * @param [in] remove_index_set Index collection of input nodes to be deleted + * @return true remove successfully + * @return false remove failed + * + */ +Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, const set &remove_index_set) { + GE_CHECK_NOTNULL(node_def); + if (remove_index_set.empty()) { + GELOGI("The size of remove_index_set is zero."); + return SUCCESS; + } + + map> remove_inputs_map; + for (auto &it : remove_index_set) { + const string &input_node_name = node_def->input(it); + remove_inputs_map[input_node_name].emplace_back(it); + GELOGD("Push input:%s, index:%d into remove map.", input_node_name.c_str(), it); + } + + RemoveInputAttr(node_def, remove_inputs_map); + + int index = 0; + auto *inputs = node_def->mutable_input(); + for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) { + // 1.decide whether to remove the input + bool flag = false; + for (auto &remove_input : remove_inputs_map) { + string remove_input_name = remove_input.first; + vector remove_input_indexs = remove_input.second; + if ((*input_it) == remove_input_name && + std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { + GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); + flag = true; + break; + } + } + + if (flag) { + // 2 remove the input + input_it = inputs->erase(input_it); + } else { + ++input_it; + } + } + + return SUCCESS; +} + +void 
TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def, + const map> &remove_inputs_map) { + // The caller guarantees that the pointer is not null + auto *inputs = node_def->mutable_input(); + google::protobuf::Map *attr_map = node_def->mutable_attr(); + const google::protobuf::Map::iterator it = + attr_map->find(ge::ATTR_NAME_INPUT_TENSOR_DESC); + if (it == attr_map->end()) { + GELOGW("Failed to find input desc from tf node_def[%s]", node_def->name().c_str()); + } else { + domi::tensorflow::AttrValue *input_attr_value = &(it->second); + auto tmp_attr = input_attr_value->mutable_list()->mutable_func(); + auto attr_it = tmp_attr->begin(); + int index = 0; + for (auto input_it = inputs->begin(); input_it != inputs->end(); ++input_it, ++index) { + // 1.decide whether to remove the input + bool flag = false; + for (auto &remove_input : remove_inputs_map) { + string remove_input_name = remove_input.first; + vector remove_input_indexs = remove_input.second; + if ((*input_it) == remove_input_name && + std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { + GELOGD("Remove input attr:%s, index:%d", remove_input_name.c_str(), index); + flag = true; + break; + } + } + + if (flag) { + // 2.1 remove the input attr + if (!tmp_attr->empty() && attr_it != tmp_attr->end()) { + attr_it = tmp_attr->erase(attr_it); + } else { + ++attr_it; + } + } else { + ++attr_it; + } + } + } +} + +Status TensorFlowModelParser::GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def) { + GE_CHECK_NOTNULL(graph_def); + for (int i = 0; i < graph_def->node_size(); i++) { + domi::tensorflow::NodeDef *node = graph_def->mutable_node(i); + const string &node_name = node->name(); + node_inputs_outputs_map_.emplace(node_name, std::pair, set>{}); + for (const auto &input : node->input()) { + string input_node_name; + GE_RETURN_IF_ERROR(CheckInputNodeName(input, &input_node_name, nullptr, nullptr)); + node_inputs_outputs_map_[node_name].first.insert(input_node_name); + node_inputs_outputs_map_[input_node_name].second.insert(node_name); + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def) { + GE_CHECK_NOTNULL(graph_def); + set node_to_delete; + for (int i = 0; i < graph_def->node_size(); i++) { + domi::tensorflow::NodeDef *node = graph_def->mutable_node(i); + const string &node_name = node->name(); + if (node_inputs_outputs_map_.find(node_name) == node_inputs_outputs_map_.end()) { + GELOGE(FAILED, "Can not find input output context, node:%s.", node_name.c_str()); + return FAILED; + } + if ((node_inputs_outputs_map_[node_name].first.empty() && node_inputs_outputs_map_[node_name].second.empty() && + node->op() != kDpop) || + (node->op() == ge::parser::CONSTANT && node_inputs_outputs_map_[node_name].second.empty())) { + GELOGI("%s will inset to node_to_delete", node_name.c_str()); + node_to_delete.insert(node_name); + } + } + + // delete isolate nodes + auto nodeList = graph_def->mutable_node(); + for (auto iter = nodeList->begin(); iter != nodeList->end();) { + if (node_to_delete.count(iter->name()) != 0) { + GELOGI("%s has zero input and output, will delete.", iter->name().c_str()); + iter = nodeList->erase(iter); + } else { + iter++; + } + } + return SUCCESS; +} + +Status TensorFlowModelParser::RecordFusionResult(std::shared_ptr &scope_graph, + const domi::tensorflow::NodeDef *node, ge::OpDescPtr &op_desc) { + // The caller guarantees that the pointer is not null + GELOGI("RecordFusionResult 
for %s start.", op_desc->GetName().c_str()); + auto &impl_scope_graph = scope_graph->impl_; + ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(node); + if (fusion_result == nullptr) { + GELOGW("fusion_result is not found."); + return SUCCESS; + } + + std::vector original_names; + auto nodes = fusion_result->Nodes(); + std::transform(nodes.begin(), nodes.end(), std::back_inserter(original_names), + [](ge::OperatorPtr n) -> std::string { return n->GetName(); }); + + GELOGI("Op %s original_names size = %zu.", op_desc->GetName().c_str(), original_names.size()); + bool ret = ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); + if (!ret) { + GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); + } + auto outputs_desc = op_desc->GetAllOutputsDesc(); + auto &impl = fusion_result->impl_; + for (auto &fusion_output : impl->GetOutputs()) { + for (size_t i = 0; i < fusion_output.second.size(); ++i) { + if (fusion_output.second[i] == ge::kFusionDisableIndex) { + continue; + } + + if (fusion_output.second[i] >= static_cast(op_desc->GetOutputsSize())) { + GELOGE(PARAM_INVALID, "fusion output index %d must less than outputs desc size %zu.", fusion_output.second[i], + op_desc->GetOutputsSize()); + return PARAM_INVALID; + } + + ret = ge::AttrUtils::SetStr(op_desc->MutableOutputDesc(fusion_output.second[i]), + ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, fusion_output.first); + if (!ret) { + GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME.c_str(), op_desc->GetName().c_str(), + fusion_output.second[i]); + } + + ret = ge::AttrUtils::SetInt(op_desc->MutableOutputDesc(fusion_output.second[i]), + ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, i); + if (!ret) { + GELOGW("Set %s to %s %d output fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX.c_str(), + op_desc->GetName().c_str(), fusion_output.second[i]); + } + } + } + + return SUCCESS; +} + +Status TensorFlowModelParser::SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, + const std::vector> &inputs, + const std::vector> &outputs) { + int32_t in_index = 0; + for (const auto &in : inputs) { + bool is_ctrl = in.second == kControlSlot; + op_node_context.input_map[in.first].emplace_back(std::make_pair(in.second, is_ctrl ? kControlSlot : in_index)); + SaveEdgesControlInfo(node_def->name(), is_ctrl); + in_index = is_ctrl ? in_index : in_index + 1; + } + int32_t out_index = 0; + for (const auto &out : outputs) { + bool is_ctrl = out.second == kControlSlot; + op_node_context.output_map[out.first].emplace_back(std::make_pair(is_ctrl ? kControlSlot : out_index, out.second)); + out_index = is_ctrl ? 
out_index : out_index + 1; + } + return SUCCESS; +} + +void TensorFlowModelParser::GetFusionInputInfo( + const string &fusion_op_name, OpNodeContext &fusion_context, + std::map>> &remap_data_input, + std::map> &remap_ctrl_input, std::set &fusion_input_nodes) { + for (const auto &fusion_input : fusion_context.input_map) { + string fusion_src_name = fusion_input.first; + for (const auto &fusion_idx_pair : fusion_input.second) { + string key = fusion_op_name + std::to_string(fusion_idx_pair.second); + if (fusion_idx_pair.second != kControlSlot) { + remap_data_input[key] = {fusion_src_name, {fusion_idx_pair.first, fusion_idx_pair.second}}; + } else { + remap_ctrl_input[key].emplace_back(fusion_src_name); + } + } + fusion_input_nodes.insert(fusion_src_name); + } +} + +void TensorFlowModelParser::GetFusionOutputInfo( + const string &fusion_op_name, OpNodeContext &fusion_context, + std::map>>> &remap_data_output, + std::map> &remap_ctrl_output, std::set &fusion_output_nodes) { + for (const auto &fusion_output : fusion_context.output_map) { + string fusion_dst_name = fusion_output.first; + for (const auto &fusion_idx_pair : fusion_output.second) { + string key = fusion_op_name + std::to_string(fusion_idx_pair.first); + if (fusion_idx_pair.first != kControlSlot) { + remap_data_output[key].emplace_back( + std::make_pair(fusion_dst_name, std::make_pair(fusion_idx_pair.first, fusion_idx_pair.second))); + } else { + remap_ctrl_output[key].emplace_back(fusion_dst_name); + } + } + fusion_output_nodes.insert(fusion_dst_name); + } +} + +void TensorFlowModelParser::UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context, + const std::vector &inner_nodes_name, + std::set &fusion_input_nodes) { + std::map>> remap_data_input; + std::map> remap_ctrl_input; + GetFusionInputInfo(fusion_op_name, fusion_context, remap_data_input, remap_ctrl_input, fusion_input_nodes); + + for (const auto &node_name : inner_nodes_name) { + auto context_iter = op_node_context_map_.find(node_name); + if (context_iter != op_node_context_map_.end()) { + OpNodeContext &op_node_context = context_iter->second; + // update input map of inner node + std::map>> tmp_input_map; + for (auto iter = op_node_context.input_map.begin(); iter != op_node_context.input_map.end();) { + string src_name = iter->first; + std::vector> &input_idx = iter->second; + if (src_name == ge::kInputFromFusionScope) { + for (const auto &in_pair : input_idx) { + if (in_pair.second != kControlSlot) { + auto data = remap_data_input[fusion_op_name + std::to_string(in_pair.first)]; + tmp_input_map[data.first].emplace_back(std::make_pair(data.second.first, in_pair.second)); + GELOGI("Update inner input, src:%s, idx:%u->%u", data.first.c_str(), data.second.first, in_pair.second); + } + } + auto ctrl = remap_ctrl_input[fusion_op_name + std::to_string(kControlSlot)]; + for (const auto &ctrl_in : ctrl) { + tmp_input_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot)); + SaveEdgesControlInfo(node_name, kControlSlot); + } + iter = op_node_context.input_map.erase(iter); + } else { + ++iter; + } + } + op_node_context.input_map.insert(tmp_input_map.begin(), tmp_input_map.end()); + // update output map of pre node + for (const auto &in_iter : op_node_context.input_map) { + auto src_iter = op_node_context_map_.find(in_iter.first); + if (src_iter != op_node_context_map_.end()) { + std::vector> input_pairs = in_iter.second; + OpNodeContext &src_context = src_iter->second; + src_context.output_map[node_name].assign(input_pairs.begin(), 
input_pairs.end()); + } + } + } + } +} + +void TensorFlowModelParser::UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context, + const std::vector &inner_nodes_name, + std::set &fusion_output_nodes) { + std::map>>> remap_data_output; + std::map> remap_ctrl_output; + GetFusionOutputInfo(fusion_op_name, fusion_context, remap_data_output, remap_ctrl_output, fusion_output_nodes); + for (const auto &node_name : inner_nodes_name) { + auto context_iter = op_node_context_map_.find(node_name); + if (context_iter != op_node_context_map_.end()) { + OpNodeContext &op_node_context = context_iter->second; + // update output map of inner node + std::map>> tmp_output_map; + for (auto iter = op_node_context.output_map.begin(); iter != op_node_context.output_map.end();) { + string dst_name = iter->first; + std::vector> &output_idx = iter->second; + if (dst_name == ge::kOutputToFusionScope) { + for (const auto &out_pair : output_idx) { + if (out_pair.second != kControlSlot) { + auto data_outputs = remap_data_output[fusion_op_name + std::to_string(out_pair.second)]; + for (const auto &data : data_outputs) { + tmp_output_map[data.first].emplace_back(std::make_pair(out_pair.first, data.second.second)); + GELOGI("Update inner output, dst:%s, idx:%u->%u.", data.first.c_str(), out_pair.first, + data.second.second); + } + } + } + auto ctrl = remap_ctrl_output[fusion_op_name + std::to_string(kControlSlot)]; + for (const auto &ctrl_in : ctrl) { + tmp_output_map[ctrl_in].emplace_back(std::make_pair(kControlSlot, kControlSlot)); + } + iter = op_node_context.output_map.erase(iter); + } else { + ++iter; + } + } + op_node_context.output_map.insert(tmp_output_map.begin(), tmp_output_map.end()); + // update input map of pre node + for (const auto &out_iter : op_node_context.output_map) { + auto dst_iter = op_node_context_map_.find(out_iter.first); + if (dst_iter != op_node_context_map_.end()) { + std::vector> output_pairs = out_iter.second; + OpNodeContext &dst_context = dst_iter->second; + dst_context.input_map[node_name].assign(output_pairs.begin(), output_pairs.end()); + } + } + } + } +} + +Status TensorFlowModelParser::UpdateInnerNodeContext(const string &fusion_op_name, + const std::vector &inner_nodes_name) { + auto fusion_iter = op_node_context_map_.find(fusion_op_name); + if (fusion_iter == op_node_context_map_.end()) { + GELOGE(INTERNAL_ERROR, "Can't find context for fusion node %s.", fusion_op_name.c_str()); + return INTERNAL_ERROR; + } + OpNodeContext &fusion_context = fusion_iter->second; + std::set fusion_input_nodes; + std::set fusion_output_nodes; + UpdateInnerInputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_input_nodes); + UpdateInnerOutputMap(fusion_op_name, fusion_context, inner_nodes_name, fusion_output_nodes); + for (const auto &in_name : fusion_input_nodes) { + auto fusion_in = op_node_context_map_.find(in_name); + if (fusion_in != op_node_context_map_.end()) { + OpNodeContext &fusion_in_context = fusion_in->second; + fusion_in_context.output_map.erase(fusion_op_name); + } + } + for (const auto &out_name : fusion_output_nodes) { + auto fusion_out = op_node_context_map_.find(out_name); + if (fusion_out != op_node_context_map_.end()) { + OpNodeContext &fusion_out_context = fusion_out->second; + fusion_out_context.input_map.erase(fusion_op_name); + } + } + op_node_context_map_.erase(fusion_op_name); + return SUCCESS; +} + +Status TensorFlowModelParser::AddFusionInnerNodeDef(shared_ptr &scope_graph, + const string &fusion_op_name, vector &node_name_list) { + auto 
&impl_scope_graph = scope_graph->impl_; + GE_CHECK_NOTNULL(impl_scope_graph); + ge::FusionScopesResult *fusion_result = impl_scope_graph->GetFusionScopesResults(fusion_op_name); + GE_CHECK_NOTNULL(fusion_result); + auto &impl_fusion_rlt = fusion_result->impl_; + GE_CHECK_NOTNULL(impl_fusion_rlt); + ge::FusionInnerNodesInfo inner_nodes_info = impl_fusion_rlt->GetInnerNodesInfo(); + vector inner_nodes_name; + for (const auto &info : inner_nodes_info) { + string node_name; + string type; + std::vector> inputs; + std::vector> outputs; + const ge::Operator *op = nullptr; + std::tie(node_name, type, inputs, outputs, op) = info; + NodeDef *node_def = new (std::nothrow) NodeDef(); + GE_CHECK_NOTNULL(node_def); + node_def->set_name(node_name); + node_def->set_op(type); + nodedef_map_[node_name] = node_def; + fusion_nodedef_list.push_back(node_def); + for (const auto &in : inputs) { + // The input value is not used in the subsequent process. The value is added only for placeholders. + node_def->add_input(in.first); + } + domi::tensorflow::AttrValue attr_value; + attr_value.set_b(true); + ge::TensorFlowUtil::AddNodeAttr(kAttrNameIsScopeInnerNode, attr_value, node_def); + OpNodeContext &op_node_context = op_node_context_map_[node_name]; + Status ret = SetOriginNodeContext(node_def, op_node_context, inputs, outputs); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to add context and attrs, node:%s.", node_name.c_str()); + return ret; + } + scope_inner_node_map_.insert({node_name, op}); + node_name_list.emplace_back(node_name); + inner_nodes_name.emplace_back(node_name); + GELOGI("Add fusion inner node def, name:%s, type:%s.", node_name.c_str(), type.c_str()); + } + Status ret = UpdateInnerNodeContext(fusion_op_name, inner_nodes_name); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to update inner node context, fusion_op_name:%s.", fusion_op_name.c_str()); + return ret; + } + return SUCCESS; +} + +Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr &scope_graph, + vector &node_name_list) { + vector node_name_list_new; + size_t op_node_list_size = node_name_list.size(); + DumpAllNodeContext("BeforeAddFusionNodeDef"); + for (size_t i = 0; i < op_node_list_size; ++i) { + const string op_node_name = node_name_list[i]; + auto iter = fusion_op_nodedef_map_.find(op_node_name); + if (iter != fusion_op_nodedef_map_.end()) { + vector fusion_op_info = fusion_op_type_map_[op_node_name]; + if (fusion_op_info[0] != ge::kScopeToMultiNodes) { + NodeDef *node_def = new (std::nothrow) NodeDef(); + GE_CHECK_NOTNULL(node_def); + node_def->set_name(op_node_name); + node_def->set_op(fusion_op_info[0]); + nodedef_map_[op_node_name] = node_def; + fusion_nodedef_list.push_back(node_def); + OpNodeContext &node_context = op_node_context_map_[node_def->name()]; + for (const auto &input : node_context.input_map) { + // The input value is not used in the subsequent process. The value is added only for placeholders. 
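+          // Note: only the producer node name is recorded as a placeholder here; the real connections are
+          // established later when edges are added from the saved node contexts.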
+ node_def->add_input(input.first); + } + node_name_list_new.emplace_back(op_node_name); + GELOGI("Add Fusion node def, name:%s, type:%s.", node_def->name().c_str(), node_def->op().c_str()); + } else { + Status ret = AddFusionInnerNodeDef(scope_graph, op_node_name, node_name_list_new); + if (ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E12028", {"opname"}, {op_node_name}); + GELOGE(ret, "Failed to add fusion inner node, fusion_op_name:%s.", op_node_name.c_str()); + return ret; + } + GELOGI("Add fusion inner nodes successfully, fusion name:%s.", op_node_name.c_str()); + op_node_context_map_.erase(op_node_name); + } + } else { + node_name_list_new.emplace_back(op_node_name); + } + } + node_name_list.clear(); + node_name_list.assign(node_name_list_new.begin(), node_name_list_new.end()); + DumpAllNodeContext("AfterAddFusionNodeDef"); + return SUCCESS; +} + +Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, + std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def) { + // This is an internal function. The pointer input parameter is not empty when this function is invoked. + string node_name = node_def->name(); + string node_op = node_def->op(); + auto iter = parser->scope_inner_node_map_.find(node_name); + if (iter == parser->scope_inner_node_map_.end()) { + GELOGE(PARAM_INVALID, "Failed to find scope inner node:%s, type:%s.", node_name.c_str(), node_op.c_str()); + return PARAM_INVALID; + } + const ge::Operator *op = iter->second; + ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op); + GE_CHECK_NOTNULL(op_desc); + ge::NodePtr node; + { + std::lock_guard lock(*graph_mutex); + node = graph->AddNode(op_desc); + } + if (node == nullptr) { + GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return INTERNAL_ERROR; + } + { + std::lock_guard lock(parser->nodeMapMutex_); + parser->node_map_[node_name] = node; + } + GELOGI("Add scope inner node successfully, node name:%s, type:%s.", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); + return SUCCESS; +} + +void TensorFlowModelParser::DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase) { + GELOGD("phase:%s === Begin to dump context for node:%s ===", phase.c_str(), node_name.c_str()); + for (const auto &input : ctx.input_map) { + for (const auto &input_idx : input.second) { + GELOGD(" Input info: %s:%d --> in_idx %d.", input.first.c_str(), input_idx.first, input_idx.second); + } + } + for (const auto &output : ctx.output_map) { + for (const auto &output_idx : output.second) { + GELOGD(" Output info: out_idx %d --> %s:%d.", output_idx.first, output.first.c_str(), output_idx.second); + } + } + GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str()); +} + +void TensorFlowModelParser::DumpAllNodeContext(const string &phase) { + if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { + return; + } + for (const auto &iter : op_node_context_map_) { + DumpNodeContext(iter.first, iter.second, phase); + } +} +} // namespace ge + +namespace domi { +REGISTER_MODEL_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowModelParser); +REGISTER_WEIGHTS_PARSER_CREATOR(TENSORFLOW, ge::TensorFlowWeightsParser); +} // namespace domi diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h new file mode 100644 index 0000000..ac313fc --- /dev/null +++ b/parser/tensorflow/tensorflow_parser.h @@ -0,0 
+1,681 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ +#define PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "common/op/ge_op_utils.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/range_vistor.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/tensor_utils.h" +#include "omg/parser/model_parser.h" +#include "omg/parser/op_parser.h" +#include "omg/parser/weights_parser.h" +#include "parser/tensorflow/tensorflow_fusion_op_parser.h" +#include "parser/tensorflow/tensorflow_fusionop_util.h" +#include "parser/tensorflow/tensorflow_util.h" +#include "proto/om.pb.h" +#include "proto/tensorflow/graph.pb.h" +#include "proto/tensorflow/node_def.pb.h" +#include "proto/tensorflow/graph_library.pb.h" +#include "external/register/scope/scope_fusion_pass_register.h" +#include "scope/scope_pass_manager.h" + +using ge::ScopePassManager; +using domi::tensorflow::GraphDef; +using domi::tensorflow::DT_HALF; +using domi::tensorflow::NodeDef; +using domi::tensorflow::GraphDef; +using domi::tensorflow::AttrValue; +using domi::tensorflow::DataType; +using ge::OpParser; + +namespace ge { +using std::string; +using std::vector; +using std::set; +using std::map; +using std::unordered_map; +using std::mutex; +using std::shared_ptr; + +enum TfTranspose { TO_NCHW, TO_NHWC, NO_TRANSPOSE }; + +struct OpNodeContext { + // save for input + std::map>> input_map; + // save for output + std::map>> output_map; +}; + +struct DelTransposeInfo; +class TensorFlowModelParser : public domi::ModelParser { + public: + TensorFlowModelParser() {} + virtual ~TensorFlowModelParser() {} + + /** + * @ingroup domi_omg + * @brief Parse the relevant data from the model file and save it to graph + * @param [in] file Path of the model file + * @param [in|out] graph save model information after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + + */ + Status Parse(const char *file, ge::Graph &graph) override; + + Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; + + /** + * @ingroup domi_omg + * @brief Convert model files to JSON format + * @param [in] model_file Model file path to be converted + * @param [out] json_file Converted JSON file path + * @return SUCCESS parse successfully + * @return others parse failed + */ + Status ToJson(const char *model_file, const char *json_file) override; + + /** + * @ingroup domi_omg + * @brief Parse the relevant data from the model file and save it to graph + * @param [in] graph_def input tensorflow model + * @param [in|out] graph save model informati:on after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + */ + Status ParseProto(const google::protobuf::Message *proto, 
ge::ComputeGraphPtr &graph) override; + + Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, + domi::GetGraphCallback callback, + ge::ComputeGraphPtr &graph) override; + + /* + * @ingroup domi_omg + * @brief Mapping TF's datatype to GE's datatype + * @param [in] type, datatype types of operators in TF networks + * @return ge::DataType + */ + ge::DataType ConvertToGeDataType(const uint32_t type) override; + + Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; + + private: + Status Parse(const char *file, ge::ComputeGraphPtr &graph); + + /** + * @ingroup domi_omg + * @brief Add node information to graph + * @param [in|out] op_node_name_list + * @param [in|out] graph save model information after parsing + * @return SUCCESS add successfully + * @return FAILED add failed + + */ + Status AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr &scope_graph, + vector &op_node_name_list, bool is_dataset_init = false); + + Status AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph, const vector &op_node_name_list); + + /** + * @ingroup domi_omg + * @brief Add node def into node map + * @param NodeDef* + * @return SUCCESS add successfully + * @return FAILED add failed + + */ + Status AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def, const domi::tensorflow::NodeDef *node_def, + vector &op_node_name_list); + + /** + * @ingroup domi_omg + * @brief Add node information to graph + * @param [in] layer layer infomation + * @param [in|out] graph save model information after parsing + * @return SUCCESS add successfully + * @return FAILED add failed + + */ + Status AddNode(const domi::tensorflow::NodeDef *node_def, + ge::ComputeGraphPtr &graph, + shared_ptr &scope_graph); + /** + * @ingroup domi_omg + * @brief Add edge information to graph + * @param [in|out] graph save model information after parsing + * @return SUCCESS add successfully + * @return FAILED add failed + + */ + Status AddEdges(ge::ComputeGraphPtr &graph); + + /** + * @ingroup domi_omg + * @brief get op context from the parsed graph + */ + Status GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def); + + /** + * @ingroup domi_omg + * @brief get input,include opNode and constNode + * @param [in] op_node_name op name + * @param [out] input_map input node and index + * @return SUCCESS get successfully + * @return FAILED get failed + */ + Status GetOpNodeInputMap(const string &op_node_name, + map>> &input_map); + + /** + * @ingroup domi_omg + * @brief get output of node + * @param [in] graph_def graph + * @return SUCCESS get successfully + * @return FAILED get failed + */ + Status GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def); + + /** + * @ingroup domi_omg + * @brief Verifying the validity of graphdef object parsed by pb + * @param [in] graph_def Parsed tensorflow:: graphdef object + * @return SUCCESS check successfully + * @return FAILED check failed + */ + Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def); + + /** + * @ingroup domi_omg + * @brief whether const OP need to update context + * @param const op name + * @return true or false + */ + bool ConstOpNeedUpdate(const string &op_name); + + + Status ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *graph_def, shared_ptr &scope_graph); + /** + * @ingroup domi_omg + * @brief Run the scope fusion optimizer in list scope_passes_list + * @param [in] scope_passes_list optimizer list + * @param [in/out] pass_manager an object to manager the optimizers + 
 * @param [in/out] scope_graph Save the result of scope fusion
+   * @return SUCCESS Run successfully
+   * @return others Run failed
+   */
+  Status RunScopeFusionPass(const vector<string> &scope_passes_list,
+                            ScopePassManager &pass_manager,
+                            shared_ptr<ScopeGraph> &scope_graph);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Check whether the nodedef parsed from pb is a fusion operator; if so, put the NodeDef into fusion_op_nodedef_map_
+   * @param [in] graph_def Parsed tensorflow::GraphDef object
+   * @return whether it may be a fusion operator
+   */
+  bool MaybeFusionOp(shared_ptr<ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
+
+  /**
+   * @brief Confirm whether it is a child operator of the fusion operator
+   */
+  bool IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info);
+
+  /**
+   * @brief Whether an inner child operator of the fusion operator should be ignored
+   */
+  bool FusionOpChildIgnore(shared_ptr<ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info);
+
+  // Is it a fusion operator
+  bool IsFusionOp(shared_ptr<ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
+
+  /**
+   * @brief get input index of the fusion operator
+   * @param [in] info Child node description of the fusion operator
+   * @param [in] old_index Child node original index
+   * @param [out] new_index old_index mapped to the input index of the fusion operator
+   * @return return code
+   */
+  static Status GetInPutIndex(shared_ptr<ScopeGraph> &scope_graph,
+                              const ge::ScopeFusionOpInfo &info,
+                              const int32_t old_index,
+                              int32_t &new_index);
+
+  /**
+   * @brief get output index of the fusion operator
+   * @param [in] info Child node description of the fusion operator
+   * @param [in] old_index Child node original index
+   * @param [out] new_index old_index mapped to the output index of the fusion operator
+   * @return return code
+   */
+  static Status GetOutPutIndex(shared_ptr<ScopeGraph> &scope_graph,
+                               const ge::ScopeFusionOpInfo &info,
+                               const int32_t old_index,
+                               int32_t &new_index);
+  /**
+   * @ingroup domi_omg
+   * @brief Check the validity of fusion ops; put an op back into op_node_name_list if it was misjudged
+   * @param op_node_name_list
+   * @return SUCCESS check successfully
+   * @return FAILED check failed
+   */
+  Status CheckFusionOpValid();
+
+  /**
+   * @ingroup domi_omg
+   * @brief Update input-output relationships of all operators
+   * @param graph_def and op_node_name_list
+   * @return SUCCESS
+   * @return FAILED
+   */
+  Status UpdateAllNodeOpContext(shared_ptr<ScopeGraph> &scope_graph, const domi::tensorflow::GraphDef &graph_def,
+                                vector<string> &op_node_name_list);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Update the input-output relationship of fusion operators
+   * @param info Description of the fusion operator
+   * @param fusion_op_node_context Input-output relationship of the fusion operator
+   * @param normal_op_node_context Input-output relationship of normal operators
+   * @return SUCCESS
+   * @return FAILED
+   */
+  Status UpdateFusionOpContext(shared_ptr<ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
+                               OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Update the input-output relationship of normal operators
+   * @param normal_op_node_context Input-output relationship of normal operators
+   * @return SUCCESS
+   * @return FAILED
+   */
+  Status UpdateNormalOpContext(shared_ptr<ScopeGraph> &scope_graph, const string &op_node_name,
+                               OpNodeContext &normal_op_node_context);
+
+  Status EraseNormalOpOutputIfChild(shared_ptr<ScopeGraph> &scope_graph, const string &op_node_name,
+                                    OpNodeContext &normal_op_node_context);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Normalize I/O relationships: de-duplicate and remove outliers
+   */
+  Status NormalizeAllNodeOpContext();
+
+  /**
+   * @ingroup domi_omg
+   * @brief Normalize an I/O relationship map: de-duplicate and remove outliers according to the context map
+   */
+  Status NormalizeInputOrOutputMap(std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Delete fusion NodeDefs
+   */
+  void DeleteFuisonNodeDef();
+
+  /**
+   * @ingroup domi_omg
+   * @brief Save the control attribute to the edges control map
+   */
+  void SaveEdgesControlInfo(const string &node_name, const bool control);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Update the control property in the edges control map
+   */
+  void UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Get control information
+   */
+  bool GetEdgesControlInfo(const string &node_name, const int32_t index);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Check the validity of input_name
+   * @param input_node_name Input name; handles the "input:n" scenario
+   * @param index Returned index: "input" returns 0, "input:n" returns n
+   * @param control Control edge flag, e.g. input "^cond/switch_t"
+   * @return SUCCESS
+   * @return FAILED
+   */
+  Status CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, bool *control);
+
+  /**
+   * @ingroup domi_omg
+   * @brief ge wrapper of stoi
+   * @param input_node_name Input name; handles the "input:n" scenario
+   * @param index_str String to be converted
+   * @param index Returned index: "input" returns 0, "input:n" returns n
+   * @return SUCCESS
+   * @return FAILED
+   */
+  Status GeStoi(const string &input_node_name, const string &index_str, int32_t *index);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Clear the error information of non-key operators in fusion operators
+   */
+  Status ClearFusionOpError(const vector<string> &op_node_name_list);
+
+  /**
+   * @ingroup domi_omg
+   * @brief Delete the connection relationship of the Identity operator connecting the Arg node in graphdef
+   */
+  Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def);
+  /**
+   * @ingroup domi_omg
+   * @brief Optimize for Identity/ReadVariableOp operator
+   * @param [in] graph_def GraphDef to be optimized
+   * @param [in] nodedef_map Map of all nodes in graph
+   * @param [in] nodedef_to_optimize vector of NodeDef to be optimized
+   * @return SUCCESS optimize successfully
+   * @return others failed
+   */
+  Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
+                                  const vector<NodeDef *> &nodedef_to_optimize);
+  /**
+   * @ingroup domi_omg
+   * @brief For the identity operator whose output is "_retval", optimize it.
+ * @param [in] nodedef_map Map of all nodes in graph + * @param [in] curr_node_name Name of node to be optimized + * @param [in] clear_input_flag Flag of whether to clear the input of the current node + * @return SUCCESS optimize successfully + * @return others failed + */ + Status OptimizeIdentityByOutput(map &nodedef_map, const string &curr_node_name, + bool &clear_input_flag); + Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map &nodedef_map, + const vector &nodedef_to_optimize); + Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *nodeCurrent); + Status OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, map &nodedef_map, + const std::pair &input_data, const std::vector &control_list); + void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *nodeCurrent, + bool &clearInputFlag); + void OptimizeTranspose(std::map &transposeInfo); + void SoftmaxAddAttr(GraphDef *graph_def); + + /** + * @ingroup domi_omg + * @brief Delete isolated nodes in graph + */ + Status RemoveIsolateNode(ge::ComputeGraphPtr &graph); + + /** + * @ingroup domi_omg + * @brief Infer format for input ops. + + */ + domiTensorFormat_t InferInputFormats(); + + /** + * @ingroup domi_omg + * @brief Get node format. + + */ + Status GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format, + set &visited_node); + + /** + * @ingroup domi_omg + * @brief Get format transpose. + + */ + Status GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc); + Status TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, domi::tensorflow::GraphDef *output_graph_def); + Status TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, + domi::tensorflow::GraphDef *output_graph_def); + Status TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, + domi::tensorflow::GraphDef *output_graph_def); + string NodeNameFromInput(const string &input_name); + + Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); + Status CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); + void UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector &input_desc, + const size_t input_tensor_num); + void UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector &output_desc, + size_t output_tensor_num); + Status TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, const string &op_type); + + Status UppdateInputMap(shared_ptr &scope_graph, const ge::ScopeFusionOpInfo &info, + OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); + Status UppdateOutputMap(shared_ptr &scope_graph, const ge::ScopeFusionOpInfo &info, + OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); + void GetInputOutputTensorNum (ge::OpDescPtr &op_desc, size_t &input_tensor_num, size_t &output_tensor_num) const; + Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set &dims, bool &valid); + Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type); + + /** + * @ingroup domi_omg + * @brief Trans common decorate function to PartitionedCall. + * @param [in] node_def: Node of common function. + * @param [out] op: result of PartitionedCall OpDesc. 
+ * @return 0: SUCCESS / Others: FAILED
+ */
+ Status DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Call the ParseParams method of the fusion operator
+ * @param op_parser Op parser of the fusion operator
+ * @return SUCCESS
+ * @return FAILED
+
+ */
+ Status FusionNodeParseParams(shared_ptr &op_parser,
+ const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Optimize const nodes for custom operators
+ * @param [in] graph_def graph object
+ * @return SUCCESS optimize successfully
+ * @return others optimize failed
+ *
+ */
+ Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Delete inputs from a NodeDef
+ * @param [in] node_def NodeDef object
+ * @param [in] remove_index_set Index collection of input nodes to be deleted
+ * @return SUCCESS remove successfully
+ * @return others remove failed
+ *
+ */
+ Status RemoveInputs(domi::tensorflow::NodeDef *node_def, const set &remove_index_set);
+
+ void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map> &remove_inputs_map);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Parse the parameters in the NodeDef and construct the GE node.
+ * This function is a thread function; NodeDefs in the TensorFlow graph are parsed in parallel.
+ * Member variables modified in this function must be protected by locks.
+ * @param [in] parser TensorFlowModelParser
+ * @param [in] graph ge graph
+ * @param [in] graphMutex ge graph lock
+ * @param [in] scope_graph
+ * @param [in] node_def Nodedef
+ * @return SUCCESS
+ * @return FAILED
+ *
+ */
+ static Status ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, std::mutex *graphMutex,
+ shared_ptr &scope_graph, const domi::tensorflow::NodeDef *node_def);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Adapt op type
+ * @param [in] node_def Nodedef
+ * @param [in] isDatasetInit
+ * @return SUCCESS adapt successfully
+ * @return others adapt failed
+ *
+ */
+ Status AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit);
+
+ Status GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def);
+ Status RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def);
+ static Status RecordFusionResult(std::shared_ptr &scope_graph,
+ const domi::tensorflow::NodeDef *node,
+ ge::OpDescPtr &op_def);
+
+ Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library);
+
+ Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context,
+ const std::vector> &inputs,
+ const std::vector> &outputs);
+
+ void GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
+ std::map>> &remap_data_input,
+ std::map> &remap_ctrl_input,
+ std::set &fusion_input_nodes);
+
+ void GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
+ std::map>>> &remap_data_output,
+ std::map> &remap_ctrl_output,
+ std::set &fusion_output_nodes);
+
+ void UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
+ const std::vector &inner_nodes_name,
+ std::set &fusion_input_nodes);
+
+ void UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
+ const std::vector &inner_nodes_name,
+ std::set &fusion_output_nodes);
+
+ Status UpdateInnerNodeContext(const string &fusion_op_name, const std::vector &inner_nodes_name);
+
+ Status AddFusionInnerNodeDef(shared_ptr &scope_graph,
+ const string &fusion_op_name,
+
vector &node_name_list); + + Status AddFusionNodeDef(shared_ptr &scope_graph, vector &node_name_list); + + static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, + std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def); + + void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); + void DumpAllNodeContext(const string &phase); + + Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr &op_parser); + + /** + * save + */ + unordered_map nodedef_map_; + + /** + * context, Input output relationship + */ + unordered_map op_node_context_map_; + + /** + * Name of node of OP type, corresponding to node of DaVinci + */ + std::unordered_map node_map_; + + /** + * node_map_ Multithreaded write operation is involved, requiring lock protection + */ + std::mutex nodeMapMutex_; + + /** + * save + */ + unordered_map> fusion_op_nodedef_map_; + // Policy types of fusion operators,true:scope_pass match,false:prefix match + unordered_map fusion_op_policy_; + // The names of all children operators and the description of fusion operators + unordered_map fusion_op_children_; + /** + * save + */ + unordered_map> fusion_op_type_map_; + /** + * save nodedef of the fusion operator + */ + vector fusion_nodedef_list; + /** + * control edge,{Key=NodeName,Value=index} + */ + unordered_map> edges_control_map; + + unordered_map framework_ops_; + + /** + * save + */ + unordered_map adaptedOpTypeMap_; + + // { node_name <{input_node_name}, {output_node_name}> } + unordered_map, set>> node_inputs_outputs_map_; + + unordered_map scope_inner_node_map_; +}; + +/** + * @ingroup domi_omg + * @brief Tensorflow weight parse + */ +class TensorFlowWeightsParser : public domi::WeightsParser { + public: + /** + * @ingroup domi_omg + * @brief Parse weight data from file and save to graph + * @param [in] file Path of weight file after training + * @param [in|out] graph Save weight information after analysis + * @return SUCCESS parse successfully + * @return PARAM_INVALID param invalid + * @return PARSE_WEIGHTS_FAILED parse failed + */ + Status Parse(const char *file, ge::Graph &graph) override; + + Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override; +}; +} // namespace domi +#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_parser_register.h b/parser/tensorflow/tensorflow_parser_register.h new file mode 100644 index 0000000..6ff0e2e --- /dev/null +++ b/parser/tensorflow/tensorflow_parser_register.h @@ -0,0 +1,132 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Copyright (c) <2018>, +#ifndef PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_ +#define PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_ + +#include +#include +#include +#include "framework/common/util.h" +#include "framework/omg/parser/op_parser.h" +#include "parser/common/op_def/ir_pb_converter.h" +#include "parser/common/op_def/operator.h" +#include "common/ge/ge_util.h" +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "proto/tensorflow/node_def.pb.h" + +using domi::tensorflow::NodeDef; + +namespace ge { +class TensorflowFinalizeable { + public: + virtual bool Finalize() = 0; + virtual ~TensorflowFinalizeable() {} +}; + +class TensorflowReceiver { + public: + TensorflowReceiver(TensorflowFinalizeable &f) { f.Finalize(); } + ~TensorflowReceiver() {} +}; + +namespace tensorflow_parser { +template +class TensorflowParserBuilder; + +class TensorflowWeightParserBuilder : public TensorflowFinalizeable { + public: + virtual ~TensorflowWeightParserBuilder() {} +}; + +template +class TensorflowOpParserAdapter; + +template +class TensorflowParserBuilder : public TensorflowWeightParserBuilder { + public: + using ParseParamsFn = std::function; + + explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {} + + ~TensorflowParserBuilder() {} + + TensorflowParserBuilder &SetParseParamsFn(ParseParamsFn parse_params_fn) { + parse_params_fn_ = parse_params_fn; + return *this; + } + + bool Finalize() override { + auto op_parser_adapter = ge::MakeShared>(*this); + if (op_parser_adapter == nullptr) { + GELOGE(FAILED, "Op parser adapter is null."); + } + // register to OpParserFactory + OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( + domi::TENSORFLOW, davinci_optype_, [=] { return std::shared_ptr(op_parser_adapter); }); + return true; + } + + private: + std::string davinci_optype_; // op type in davinci model + + ParseParamsFn parse_params_fn_; + + friend class TensorflowOpParserAdapter; +}; + +template +class TensorflowOpParserAdapter : public TensorFlowOpParser { + using ParseParamsFn = std::function; + + public: + TensorflowOpParserAdapter(TensorflowParserBuilder builder) { parse_params_fn_ = builder.parse_params_fn_; } + + ~TensorflowOpParserAdapter() {} + + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { + const domi::tensorflow::NodeDef *node = static_cast(op_src); + GE_CHECK_NOTNULL(node); + std::shared_ptr param = ge::MakeShared(); + if (param == nullptr) { + GELOGE(domi::FAILED, "Param is null"); + return domi::FAILED; + } + GE_RETURN_IF_ERROR(parse_params_fn_(node, param.get())); + param.get()->Name(node->name()); + std::shared_ptr op_param = std::static_pointer_cast(param); + ConvertToOpDesc(*op_param, op_dest); + + return domi::SUCCESS; + } + + private: + ParseParamsFn parse_params_fn_; +}; +} // namespace tensorflow_parser + +#define DOMI_REGISTER_TENSORFLOW_PARSER(name, param_clazz) \ + DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ_HELPER(__COUNTER__, name, param_clazz) +#define DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ_HELPER(ctr, name, param_clazz) \ + DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ(ctr, name, param_clazz) +#define DOMI_REGISTER_TENSORFLOW_PARSER_UNIQ(ctr, name, param_clazz) \ + static TensorflowReceiver register_tensorflow_parser##ctr __attribute__((unused)) = \ + tensorflow_parser::TensorflowParserBuilder(name) +} // namespace ge + +#endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_REGISTER_H_ diff --git 
a/parser/tensorflow/tensorflow_ref_switch_parser.cc b/parser/tensorflow/tensorflow_ref_switch_parser.cc new file mode 100644 index 0000000..08b7a30 --- /dev/null +++ b/parser/tensorflow/tensorflow_ref_switch_parser.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/tensorflow/tensorflow_ref_switch_parser.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/op/ge_op_utils.h" +#include "parser/common/op_def/ir_pb_converter.h" +#include "parser/common/op_def/ref_switch_op.h" +#include "parser/common/op_parser_factory.h" + +using domi::tensorflow::DataType; +using domi::tensorflow::DT_FLOAT; +using domi::tensorflow::AttrValue; +using domi::tensorflow::NodeDef; +using domi::TENSORFLOW; +using namespace ge::parser; + +namespace ge { +// AUTO GEN PLEASE DO NOT MODIFY IT +Status TensorFlowRefSwitchParser::ParseT(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) { + // The upper caller guarantees node is not empty. + domi::tensorflow::AttrValue attr; + + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_T, attr), + op->T(domi::TensorAssign::ConvertTensorflowDataType(DT_FLOAT)); + return SUCCESS); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed"); + + domi::tensorflow::DataType tfType = attr.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tfType); + CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, GELOGE(FAILED, "Data type %s of node %s is not supported.", + DataType_Name(tfType).c_str(), node->name().c_str()); + return PARAM_INVALID); + + op->T(type); + + return SUCCESS; +} + +Status TensorFlowRefSwitchParser::ParseParams(const Message *opSrc, ge::OpDescPtr &opDest) { + GE_CHECK_NOTNULL(opSrc); + const NodeDef *node = DOMI_DYNAMIC_CAST(opSrc); + GE_CHECK_NOTNULL(node); + + RefSwitchOperator op; + op.Name(node->name()); + + GELOGI("RefSwitch Op %s ParseParams Begin.", node->name().c_str()); + GE_RETURN_IF_ERROR(PreParseParams(node, &op)); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseT(node, &op), "Parse T for node %s failed.", node->name().c_str()); + + GE_RETURN_IF_ERROR(PostParseParams(node, &op)); + + Status status = ConvertToOpDesc(op, opDest); + + return status; +} + +// AUTO GEN PLEASE DO NOT MODIFY IT +Status TensorFlowRefSwitchParser::PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) { + return SUCCESS; +} + +Status TensorFlowRefSwitchParser::PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op) { + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFSWITCH, TensorFlowRefSwitchParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_ref_switch_parser.h b/parser/tensorflow/tensorflow_ref_switch_parser.h new file mode 100644 index 0000000..723ebfc --- /dev/null +++ b/parser/tensorflow/tensorflow_ref_switch_parser.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_
+#define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_
+
+#include "common/op_def/ref_switch_op.h"
+#include "parser/tensorflow/tensorflow_op_parser.h"
+
+using domi::tensorflow::NodeDef;
+
+namespace ge {
+class TensorFlowRefSwitchParser : public TensorFlowOpParser {
+ // AUTO GEN PLEASE DO NOT MODIFY IT
+ public:
+ /**
+ * @ingroup domi_omg
+ * @brief Parse model file information
+ * @param [in] v_input_const Model data to be parsed
+ * @param [out] node Parsed model data
+ * @return SUCCESS parsing succeeded
+ * @return FAILED parsing failed
+ */
+ Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
+
+ protected:
+ /**
+ * @ingroup domi_omg
+ * @brief Parse model file information
+ * @param [in] v_input_const Model data to be parsed
+ * @param [out] node Parsed model data
+ * @return SUCCESS parsing succeeded
+ * @return FAILED parsing failed
+ */
+ Status PreParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Parse model file information
+ * @param [in] v_input_const Model data to be parsed
+ * @param [out] node Parsed model data
+ * @return SUCCESS parsing succeeded
+ * @return FAILED parsing failed
+ */
+ Status PostParseParams(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op);
+
+ /**
+ * @ingroup domi_omg
+ * @brief Parse model file information
+ * @param [in] v_input_const Model data to be parsed
+ * @param [out] node Parsed model data
+ * @return SUCCESS parsing succeeded
+ * @return FAILED parsing failed
+ */
+ Status ParseT(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op);
+
+ // AUTO GEN PLEASE DO NOT MODIFY IT
+};
+} // namespace ge
+
+#endif // DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_REF_SWITCH_H_
diff --git a/parser/tensorflow/tensorflow_reshape_parser.cc b/parser/tensorflow/tensorflow_reshape_parser.cc
new file mode 100644
index 0000000..d579f72
--- /dev/null
+++ b/parser/tensorflow/tensorflow_reshape_parser.cc
@@ -0,0 +1,95 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "parser/tensorflow/tensorflow_reshape_parser.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/omg/omg.h" +#include "graph/utils/type_utils.h" +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_util.h" +#include "common/math/math_util.h" + +using domi::TENSORFLOW; +using namespace ge::parser; + +namespace ge { +Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { + int32_t tf_datatype = 0; + auto a_list = attr_value.list(); + GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), PARAM_INVALID, + "parse ge_desc failed."); + uint32_t size_type = 1; + int64_t real_size = 1; + int64_t tmp_dim = 0; + + auto data_type = ge_desc.GetDataType(); + bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type); + GE_IF_BOOL_EXEC(!type_ret, GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s", + ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return PARAM_INVALID); + // calculate size + for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { + tmp_dim = ge_desc.GetShape().GetDim(j); + GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); + real_size *= tmp_dim; + } + FMK_INT64_MULCHECK(real_size, size_type); + ge::TensorUtils::SetSize(ge_desc, real_size * size_type); + ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); + GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", + ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(), + ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type); + return SUCCESS; +} + +Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op); + + const NodeDef *node_src = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node_src); + GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); + domi::tensorflow::AttrValue input_attr_value; + domi::tensorflow::AttrValue output_attr_value; + + GE_IF_BOOL_EXEC( + GetParserContext().train_flag == true, + + ge::GeTensorDesc input_desc; + ge::GeTensorDesc output_desc; + + if (TensorFlowUtil::FindAttrValue(node_src, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { + GE_CHK_BOOL_RET_STATUS(SUCCESS == ParseDesc(input_attr_value, input_desc), FAILED, "parse input desc failed"); + } + + if (TensorFlowUtil::FindAttrValue(node_src, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { + GE_CHK_BOOL_RET_STATUS(SUCCESS == ParseDesc(output_attr_value, output_desc), FAILED, + "parse output desc failed"); + } + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED, + "set input desc failed"); + + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED, + "set output desc failed");); + + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, RESHAPE, TensorFlowReshapeParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_reshape_parser.h b/parser/tensorflow/tensorflow_reshape_parser.h new file mode 100644 index 0000000..975982a --- /dev/null +++ b/parser/tensorflow/tensorflow_reshape_parser.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies 
Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_TENSORFLOW_TENSORFLOW_RESHAPE_PARSER_H_ +#define PARSER_TENSORFLOW_TENSORFLOW_RESHAPE_PARSER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowReshapeParser : public TensorFlowOpParser { + private: + Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); + + public: + /** + * @ingroup domi_omg + * @brief parse weight information + * @param [in] v_input_const weight data to be parsed + * @param [out] op_dest weight data after parsing + * @return SUCCESS parse successfully + * @return FAILED parse failed + * @author + */ + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; +}; +} // namespace ge + +#endif // PARSER_TENSORFLOW_TENSORFLOW_RESHAPE_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_shape_n_parser.cc b/parser/tensorflow/tensorflow_shape_n_parser.cc new file mode 100644 index 0000000..e8c0e9c --- /dev/null +++ b/parser/tensorflow/tensorflow_shape_n_parser.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/tensorflow/tensorflow_shape_n_parser.h" +#include "parser/common/op_def/ir_pb_converter.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/op/ge_op_utils.h" +#include "parser/common/op_parser_factory.h" +#include "parser/common/op_def/shape_n_op.h" + +using domi::TENSORFLOW; +using domi::tensorflow::AttrValue; +using domi::tensorflow::DataType; +using domi::tensorflow::DT_FLOAT; +using domi::tensorflow::DT_INT32; +using namespace ge::parser; + +namespace { + const std::string kShapeAttrDtype = "out_type"; +} // namespace + +namespace ge { +// AUTO GEN PLEASE DO NOT MODIFY IT + +Status TensorFlowShapeNParser::ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { + // The upper caller guarantees the input params is not empty. 
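+ // Note (illustrative): if the ge::ATTR_NAME_T attr is absent from the NodeDef, the code below falls back to
+ // DT_FLOAT via ConvertTensorflowDataType; an unsupported TensorFlow type results in PARAM_INVALID.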
+ domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_T, attr), + op->InType(domi::TensorAssign::ConvertTensorflowDataType(DT_FLOAT)); + return SUCCESS); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed"); + + domi::tensorflow::DataType tf_type = attr.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type); + CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, GELOGE(FAILED, "Data type %s of node %s is not supported.", + DataType_Name(tf_type).c_str(), node->name().c_str()); + return PARAM_INVALID); + + op->InType(type); + + return SUCCESS; +} + +Status TensorFlowShapeNParser::ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { + // The upper caller guarantees the input params is not empty. + domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, kShapeAttrDtype, attr), + op->OutType(domi::TensorAssign::ConvertTensorflowDataType(DT_INT32)); + return SUCCESS); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed"); + + domi::tensorflow::DataType tf_type = attr.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type); + CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, GELOGE(FAILED, "Data type %s of node %s is not supported.", + DataType_Name(tf_type).c_str(), node->name().c_str()); + return PARAM_INVALID); + + op->OutType(type); + + return SUCCESS; +} + +Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { + // The upper caller guarantees the input params is not empty. + domi::tensorflow::AttrValue attr; + const int64_t attr_n = 2; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n); return SUCCESS); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "int"), "check Attr N failed"); + + op->N(attr.i()); + + return SUCCESS; +} + +Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { + GE_CHECK_NOTNULL(op_dest); + const NodeDef *node = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node); + ShapeNOperator op; + op.Name(node->name()); + + GE_RETURN_IF_ERROR(PreParseParams(node, &op)); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseInType(node, &op), "Parse in type for node %s failed.", node->name().c_str()); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseN(node, &op), "Parse N for node %s failed.", node->name().c_str()); + + GE_RETURN_WITH_LOG_IF_ERROR(ParseOutType(node, &op), "Parse out type for node %s failed.", node->name().c_str()); + + GE_RETURN_IF_ERROR(PostParseParams(node, &op)); + + // add dynamic input/output + domi::tensorflow::AttrValue attr_num; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr_num), + GELOGE(FAILED, "Get Attr N failed in Node %s.", node->name().c_str()); + return PARAM_INVALID); + int32_t dynamic_tensor_num = attr_num.i(); + + Status ret; + domi::tensorflow::AttrValue output_attr_value; + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { + GE_CHK_STATUS_RET( + TensorFlowUtil::TransTensorDescriptor(output_attr_value, &op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG), + "trans output_attr_value failed, op: %s", node->name().c_str()); + ret = ConvertToOpDesc(op, op_dest); + if (ret != SUCCESS) { + return ret; + } + } else { + ret = ConvertToOpDesc(op, op_dest); + if (ret != SUCCESS) { + return ret; + } + graphStatus 
status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add dynamic output:y for node:%s failed.", op_dest->GetName().c_str()); + return FAILED; + } + } + graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add dynamic input:x for node:%s failed.", op_dest->GetName().c_str()); + return FAILED; + } + GELOGI("add dynamic input and output for op [%s], type[%s], name: %s, number:%d", op_dest->GetName().c_str(), + op_dest->GetType().c_str(), SHAPEN_ATTR_N.c_str(), dynamic_tensor_num); + return SUCCESS; +} + +// AUTO GEN PLEASE DO NOT MODIFY IT +Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { + return SUCCESS; +} + +Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) { + return SUCCESS; +} + +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPEN, TensorFlowShapeNParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_shape_n_parser.h b/parser/tensorflow/tensorflow_shape_n_parser.h new file mode 100644 index 0000000..0447262 --- /dev/null +++ b/parser/tensorflow/tensorflow_shape_n_parser.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ +#define DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ + +#include "common/op_def/shape_n_op.h" +#include "parser/tensorflow/tensorflow_op_parser.h" + +using domi::tensorflow::NodeDef; + +namespace ge { +class TensorFlowShapeNParser : public TensorFlowOpParser { + // AUTO GEN PLEASE DO NOT MODIFY IT + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + protected: + Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); + Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); + + Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); + Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); + Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); + + // AUTO GEN PLEASE DO NOT MODIFY IT +}; +} // namespace ge + +#endif // DOMI_OMG_PARSER_OP_PARSER_TENSORFLOW_SHAPE_N_H_ diff --git a/parser/tensorflow/tensorflow_squeeze_parser.cc b/parser/tensorflow/tensorflow_squeeze_parser.cc new file mode 100644 index 0000000..6bc0408 --- /dev/null +++ b/parser/tensorflow/tensorflow_squeeze_parser.cc @@ -0,0 +1,136 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "parser/tensorflow/tensorflow_squeeze_parser.h" +#include +#include +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/attr_value_util.h" +#include "framework/common/op/op_parser_util.h" +#include "framework/common/util.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/utils/type_utils.h" +#include "parser/common/op_parser_factory.h" +#include "common/math/math_util.h" + +using domi::tensorflow::AttrValue; +using std::vector; +using std::shared_ptr; +using domi::TENSORFLOW; +using namespace ge::parser; + +namespace ge { +Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { + int32_t tf_datatype = 0; + auto a_list = attr_value.list(); + GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID, + "parse ge_desc failed."); + uint32_t size_type; + int64_t real_size = 1; + int64_t tmp_dim = 0; + + auto data_type = ge_desc.GetDataType(); + bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type); + GE_IF_BOOL_EXEC(!type_ret, GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s", + ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); + return domi::PARAM_INVALID); + // calculate size + for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { + tmp_dim = ge_desc.GetShape().GetDim(j); + GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); + FMK_INT64_MULCHECK(real_size, tmp_dim); + real_size *= tmp_dim; + } + FMK_INT64_MULCHECK(real_size, size_type); + ge::TensorUtils::SetSize(ge_desc, real_size * size_type); + ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); + GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", + ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(), + ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type); + return SUCCESS; +} + +Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { + GE_CHECK_NOTNULL(op_src); + GE_CHECK_NOTNULL(op); + + const NodeDef *node = DOMI_DYNAMIC_CAST(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + bool has_axis = true; + bool has_dims = true; + + domi::tensorflow::AttrValue axis; + domi::tensorflow::AttrValue dims; + if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) { + has_axis = false; + } + if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) { + has_dims = false; + } + if (!has_axis && !has_dims) { + return SUCCESS; + } + if (has_axis && has_dims) { + GELOGE(FAILED, "In NodeDef %s dim and axis is error.", node->name().c_str()); + return domi::PARAM_INVALID; + } + + domi::tensorflow::AttrValue_ListValue values; + if (has_axis) { + values = axis.list(); + } else { + values = dims.list(); + } + int i = 0; + int size = values.i_size(); + vector v_result; + for (i = 0; 
i < size; i++) { + int32_t result = values.i(i); + v_result.push_back(result); + } + if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) { + GELOGE(FAILED, "Set squeeze axis attr failed"); + return FAILED; + } + + domi::tensorflow::AttrValue input_attr_value; + domi::tensorflow::AttrValue output_attr_value; + + GE_IF_BOOL_EXEC( + GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; + + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { + GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); + } + + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { + GE_CHK_BOOL_RET_STATUS(ParseDesc(output_attr_value, output_desc) == SUCCESS, FAILED, + "parse output desc failed"); + } + + if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { + GELOGE(FAILED, "Set input desc failed"); + return FAILED; + } if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { + GELOGE(FAILED, "Set output desc failed"); + return FAILED; + }) + return SUCCESS; +} +REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SQUEEZE, TensorFlowSqueezeParser); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_squeeze_parser.h b/parser/tensorflow/tensorflow_squeeze_parser.h new file mode 100644 index 0000000..a621c5e --- /dev/null +++ b/parser/tensorflow/tensorflow_squeeze_parser.h @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_TENSORFLOW_TENSORFLOW_SQUEEZE_PARSER_H_ +#define PARSER_TENSORFLOW_TENSORFLOW_SQUEEZE_PARSER_H_ + +#include "parser/tensorflow/tensorflow_op_parser.h" + +namespace ge { +class TensorFlowSqueezeParser : public TensorFlowOpParser { + public: + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; + + private: + Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); +}; +} // namespace ge + +#endif // PARSER_TENSORFLOW_TENSORFLOW_SQUEEZE_PARSER_H_ diff --git a/parser/tensorflow/tensorflow_util.cc b/parser/tensorflow/tensorflow_util.cc new file mode 100644 index 0000000..4ab109a --- /dev/null +++ b/parser/tensorflow/tensorflow_util.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "parser/tensorflow/tensorflow_util.h" +#include +#include +#include +#include +#include "common/math/math_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" +#include "framework/omg/parser/parser_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_tensor.h" +#include "graph/utils/type_utils.h" +#include "parser/tensorflow/tensorflow_op_parser.h" + +using domi::tensorflow::DT_INVALID; + +namespace ge { +using AttrValueMap = ::google::protobuf::Map; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( + const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { + GE_CHECK_NOTNULL(node_def); + const google::protobuf::Map &attr = node_def->attr(); + const google::protobuf::Map::const_iterator it = attr.find(attr_name); + if (it != attr.end()) { + attr_value = it->second; + return true; + } + + return false; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( + const domi::tensorflow::AttrValue &attr_value, const string &type) { + uint32_t num_set = 0; +#define VALIDATE_FIELD(name, type_string, oneof_case) \ + do { \ + if (attr_value.has_list()) { \ + if (attr_value.list().name##_size() > 0) { \ + if (type != "list(" type_string ")") { \ + GELOGE(FAILED, "GeAttrValue had value with type 'list(" type_string ")'when '%s' expected", type.c_str()); \ + return FAILED; \ + } \ + ++num_set; \ + } \ + } else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \ + if (type != type_string) { \ + GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \ + return FAILED; \ + } \ + ++num_set; \ + } \ + } while (false) + + VALIDATE_FIELD(s, "string", kS); + VALIDATE_FIELD(i, "int", kI); + VALIDATE_FIELD(f, "float", kF); + VALIDATE_FIELD(b, "bool", kB); + VALIDATE_FIELD(type, "type", kType); + VALIDATE_FIELD(shape, "shape", kShape); + VALIDATE_FIELD(tensor, "tensor", kTensor); + VALIDATE_FIELD(func, "func", kFunc); + +#undef VALIDATE_FIELD + + if (attr_value.value_case() == domi::tensorflow::AttrValue::kPlaceholder) { + GELOGE(FAILED, "GeAttrValue had value with unexpected type 'placeholder'"); + return FAILED; + } + + // Okay to have an empty list, but not to be missing a non-list value. + if ((num_set == 0) && (!ge::StringUtils::StartWith(type, "list("))) { + GELOGE(FAILED, "GeAttrValue missing value with expected type '%s'", type.c_str()); + return FAILED; + } + + // Ref types and DT_INVALID are illegal, and DataTypes must + // be a valid enum type. 
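+ // Example (illustrative): CheckAttrHasType(attr, "type") succeeds when attr holds a DataType value such as
+ // DT_FLOAT, and fails when attr holds a different oneof case (e.g. an int) or the placeholder case.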
+ if (type == "type") { + if (!domi::tensorflow::DataType_IsValid(attr_value.type())) { + GELOGE(FAILED, "GeAttrValue has invalid DataType enum: %d", attr_value.type()); + return FAILED; + } + if (attr_value.type() == DT_INVALID) { + GELOGE(FAILED, "GeAttrValue has invalid DataType"); + return FAILED; + } + } else if (type == "list(type)") { + for (auto &as_int : attr_value.list().type()) { + const domi::tensorflow::DataType dtype = static_cast(as_int); + if (!domi::tensorflow::DataType_IsValid(dtype)) { + GELOGE(FAILED, "GeAttrValue has invalid DataType enum: %d", as_int); + return FAILED; + } + if (dtype == DT_INVALID) { + GELOGE(FAILED, "GeAttrValue contains invalid DataType"); + return FAILED; + } + } + } + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( + const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { + GE_CHECK_NOTNULL(node_src); + + string node_name = node_src->name(); + + // Find the value of attr_src from node_src + domi::tensorflow::AttrValue attr_value; + GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(FindAttrValue(node_src, attr_src, attr_value), + "In NodeDef %s Attr %s is not exist.", node_name.c_str(), attr_src.c_str()); + + // Check whether the attr_src.value contains the type field + GE_RETURN_WITH_LOG_IF_ERROR(CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_TYPE), "check Attr %s failed", + attr_src.c_str()); + + data_type = attr_value.type(); + + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromAttrValueList( + ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) { + const std::string SERIALIZE_FORMAT = "serialize_format"; + const std::string SERIALIZE_DATATYPE = "serialize_datatype"; + const std::string SERIALIZE_SHAPE = "serialize_shape"; + + ge_desc.SetFormat(ge::FORMAT_ND); + ge_desc.SetOriginFormat(ge::FORMAT_ND); + + tf_datatype = a_list.func(i).attr().at(SERIALIZE_DATATYPE).i(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_datatype); + GE_CHK_BOOL_RET_STATUS(type != ge::DataType::DT_UNDEFINED, PARAM_INVALID, + "In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype); + ge_desc.SetDataType(type); + int shape_dim_dim = a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i_size(); + vector data_dim; + for (int j = 0; j < shape_dim_dim; j++) { + data_dim.push_back(a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i(j)); + } + ge_desc.SetShape(ge::GeShape(data_dim)); + ge_desc.SetOriginShape(ge::GeShape(data_dim)); + return true; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( + const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { + GE_CHECK_NOTNULL(op); + if (!attr_value.has_list()) { + return PARAM_INVALID; + } + + vector tf_in_type; + vector tf_out_type; + // list contain many TensorDescriptors + domi::tensorflow::AttrValue_ListValue a_list = attr_value.list(); + for (int32_t i = 0; i < a_list.func_size(); i++) { + ge::GeTensorDesc ge_desc; + int32_t tf_datatype = 0; + GE_CHK_BOOL_RET_STATUS(ParseFromAttrValueList(ge_desc, a_list, i, tf_datatype), PARAM_INVALID, + "parse ge_desc failed."); + + uint32_t size_type = 1; + int64_t tmp_dim = 0; + + auto data_type = ge_desc.GetDataType(); + GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID, + "dataType no 
define size , parse ge_desc failed."); + // get size + for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { + tmp_dim = ge_desc.GetShape().GetDim(j); + + // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. + // Here, special treatment is given to the two operators. + // Adjust shape to fit resnet50 network only. + GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); + break;); + GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), vector data_dim = {tmp_dim}; + ge_desc.SetShape(ge::GeShape(data_dim)); break;); + } + ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); + + GELOGD("IO:%d: after translate tf_desc, datatype: %s, format: %s, size_type: %u", io, + ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(), + ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), size_type); + + if (io == TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG) { + op->InputTensorDesc(ge_desc); + tf_in_type.push_back(tf_datatype); + } else if (io == TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG) { + op->OutputTensorDesc(ge_desc); + tf_out_type.push_back(tf_datatype); + } + } + op->AttrVector(ge::T_IN_DATATYPE, tf_in_type); + op->AttrVector(ge::T_OUT_DATATYPE, tf_out_type); + return SUCCESS; +} +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( + const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { + GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); + node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); +} +} // namespace ge diff --git a/parser/tensorflow/tensorflow_util.h b/parser/tensorflow/tensorflow_util.h new file mode 100644 index 0000000..40e780f --- /dev/null +++ b/parser/tensorflow/tensorflow_util.h @@ -0,0 +1,208 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ +#define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_ + +#include +#include +#include +#include +#include +#include "parser/common/op_def/operator.h" +#include "external/graph/attr_value.h" +#include "external/graph/graph.h" +#include "external/graph/operator.h" +#include "framework/omg/parser/parser_types.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "proto/tensorflow/graph.pb.h" +using std::string; +using std::vector; +using domi::tensorflow::NodeDef; +using domi::tensorflow::FunctionDef; +using domi::tensorflow::AttrValue_ListValue; +using domi::tensorflow::FunctionDefLibrary; + +namespace ge { +/***************************TensorFlow attribute type, constant definition*******************************************/ +static const string TENSORFLOW_ATTR_TYPE_STRING = "string"; +static const string TENSORFLOW_ATTR_TYPE_INT = "int"; +static const string TENSORFLOW_ATTR_TYPE_FLOAT = "float"; +static const string TENSORFLOW_ATTR_TYPE_BOOL = "bool"; +static const string TENSORFLOW_ATTR_TYPE_TYPE = "type"; +static const string TENSORFLOW_ATTR_TYPE_SHAPE = "shape"; +static const string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor"; +static const string TENSORFLOW_ATTR_TYPE_FUNC = "func"; + +static const string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)"; +static const string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)"; + +/***************************constant definition*******************************************/ +static const string TENSORFLOW_ATTR_OUTPUT_OP = "output_op"; + +static const string TENSORFLOW_ATTR_T = "T"; +static const string TENSORFLOW_ATTR_N = "N"; +static const string TENSORFLOW_ATTR_DATA_FORMAT = "data_format"; +static const string TENSORFLOW_ATTR_PADDING = "padding"; +static const string TENSORFLOW_ATTR_KSIZE = "ksize"; +static const string TENSORFLOW_ATTR_STRIDES = "strides"; +static const string TENSORFLOW_ATTR_DILATIONS = "dilations"; +static const string TENSORFLOW_ATTR_DTYPE = "dtype"; +static const string TENSORFLOW_ATTR_VALUE = "value"; +static const string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a"; +static const string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b"; +static const string TENSORFLOW_ATTR_SHAPE = "shape"; +static const string TENSORFLOW_ATTR_TIDX = "Tidx"; +static const string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings"; +static const string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples"; +static const string TENSORFLOW_ATTR_TINDICES = "Tindices"; +static const string TENSORFLOW_ATTR_TPARAMS = "Tparams"; +static const string TENSORFLOW_ATTR_TAXIS = "Taxis"; +static const string TENSORFLOW_ATTR_DSTT = "DstT"; +static const string TENSORFLOW_ATTR_SRCT = "SrcT"; +static const string TENSORFLOW_ATTR_PERM = "perm"; +static const string TENSORFLOW_ATTR_INDEX = "Index"; +static const string TENSORFLOW_ATTR_TSHAPE = "Tshape"; +static const string 
TENSORFLOW_ATTR_AXIS = "Axis"; +static const string TENSORFLOW_ATTR_BIAS = "bias"; +static const string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius"; +static const string TENSORFLOW_ATTR_ALPHA = "alpha"; +static const string TENSORFLOW_ATTR_BETA = "beta"; +static const string TENSORFLOW_ATTR_MODE = "mode"; + +// op:Const +static const string TENSORFLOWF_NODE_OP_CONST = "Const"; +static const string TENSORFLOWF_NODE_OP_IDENTITY = "Identity"; +static const string TENSORFLOWF_NODE_OP_SWITCH = "Switch"; +static const string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder"; +static const string TENSORFLOWF_NODE_OP_ADDN = "AddN"; +static const string TENSORFLOWF_NODE_OP_MATMUL = "MatMul"; +static const string TENSORFLOWF_NODE_OP_RELU = "Relu"; +static const string TENSORFLOWF_NODE_OP_SHAPE = "Shape"; +static const string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose"; +static const string TENSORFLOWF_NODE_OP_MERGE = "Merge"; + +// data_format +static const string TENSORFLOWF_TENSOR_NCHW = "NCHW"; +static const string TENSORFLOWF_TENSOR_NHWC = "NHWC"; + +static const int TENSORFLOW_CONV_STRIDE_NUM = 4; +static const int TENSORFLOW_CONV_DILATION_NUM = 4; + +// padding +static const string TENSORFLOWF_OP_PADDING_VALID = "VALID"; +static const string TENSORFLOWF_OP_PADDING_SAME = "SAME"; + +// normal input size +static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2; +static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE = 1; +static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL = 1; + +// normal weight size +static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL = 1; +static const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE = 1; + +// input or output +static const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG = 1; +static const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG = 2; + +class TensorFlowUtil { + public: + /** + * @ingroup domi_omg + * @brief find the corresponding AttrValue in NodeDef + * @param [in] nodeDef nodedef object to find + * @param [in] attr_name attribute name + * @param [out] attr_value attribute value + * @return true attribute exists + * @return false attribute does not exist + * + */ + static bool FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, const string &attr_name, + domi::tensorflow::AttrValue &attr_value); + + /** + * @ingroup domi_omg + * @brief Check the actual type and expected type of the AttrValue, int, float, list (int), list (bool), etc. 
+ * @param [in] attr_value attrValue to check
+ * @param [in] type expected attribute type
+ * @return SUCCESS success
+ * @return FAILED failed
+ *
+ */
+ static domi::Status CheckAttrHasType(const domi::tensorflow::AttrValue &attr_value, const string &type);
+
+ /**
+ * @ingroup domi_omg
+ * @brief parsing data types
+ * @param [in] node_src node to be parsed
+ * @param [in] attr_src attribute to be parsed
+ * @param [out] data_type parsed data type
+ * @return SUCCESS Parsing success
+ * @return FAILED parsing failed
+ *
+ */
+ static domi::Status ParseDataType(const NodeDef *node_src,
+ const string &attr_src,
+ domi::tensorflow::DataType &data_type);
+
+ /**
+ * @ingroup domi_omg
+ * @brief translate tensor descriptors
+ * @param [in] attr_value attr in NodeDef to be converted
+ * @param [out] op the parsed information is stored in the properties of the parent class
+ * @return SUCCESS conversion success
+ * @return FAILED conversion failed
+ *
+ */
+ static domi::Status TransTensorDescriptor(const domi::tensorflow::AttrValue &attr_value,
+ ParserOperator *op,
+ const uint32_t io,
+ const string &type = "");
+ /*
+ * @brief Add an attribute to a NodeDef
+ * @param [in] attr_name attribute name
+ * @param [in] attr_value attribute Value Object
+ * @param [out] node_def
+ * @return void
+ *
+ */
+ static void AddNodeAttr(const string &attr_name,
+ const domi::tensorflow::AttrValue &value,
+ domi::tensorflow::NodeDef *node_def);
+
+ static domi::Status ClearUnusedParam(ge::ComputeGraphPtr &graph);
+
+ static bool ParseFromAttrValueList(ge::GeTensorDesc &ge_desc,
+ const domi::tensorflow::AttrValue_ListValue &a_list,
+ int32_t i,
+ int32_t &tf_datatype);
+};
+} // namespace ge
+#endif // OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_
diff --git a/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc
new file mode 100644
index 0000000..4ec74bd
--- /dev/null
+++ b/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "framework/common/op/ge_op_utils.h" +#include "parser/common/op_def/var_is_initialized_op_op.h" +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "parser/tensorflow/tensorflow_parser_register.h" + +using namespace ge::parser; + +namespace ge { +Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) { + GE_CHECK_NOTNULL(op_src); + const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + op->Name(node->name()); + + return SUCCESS; +} + +DOMI_REGISTER_TENSORFLOW_PARSER(VARISINITIALIZEDOP, VarIsInitializedOpOperator).SetParseParamsFn(ParseParams); + +DOMI_REGISTER_TENSORFLOW_PARSER(ISVARIABLEINITIALIZED, VarIsInitializedOpOperator).SetParseParamsFn(ParseParams); +} // namespace ge diff --git a/parser/tensorflow/tensorflow_variable_v2_parser.cc b/parser/tensorflow/tensorflow_variable_v2_parser.cc new file mode 100644 index 0000000..139dd0e --- /dev/null +++ b/parser/tensorflow/tensorflow_variable_v2_parser.cc @@ -0,0 +1,255 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/op/ge_op_utils.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/operator.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/tensor_utils.h" +#include "parser/common/op_def/variable_op.h" +#include "parser/common/op_parser_factory.h" +#include "parser/tensorflow/tensorflow_op_parser.h" +#include "parser/tensorflow/tensorflow_parser_register.h" + +using domi::tensorflow::AttrValue; +using domi::tensorflow::NodeDef; +using domi::tensorflow::TensorShapeProto; +using namespace ge::parser; + +namespace ge { +const std::string SERIALIZE_FORMAT = "serialize_format"; +/* Original definition of the VariableV2 operator +node_def { + name: "Variable_7/Momentum" + op: "VariableV2" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "_class" + value { + list { + s: "loc:@Variable_7" + } + } + } + attr { + key: "_var_format" + value { + s: "4D" + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 10 + } + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } + } +*/ +static Status ParseSrcType(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty.
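+ // Read the "dtype" attr with TensorFlowUtil::FindAttrValue, verify it is a type attr, then map the TF data type to a GE DataType.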
+ domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_DTYPE, attr), + GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", + VAR_ATTR_DTYPE.c_str(), node->name().c_str()); + return PARAM_INVALID); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_TYPE), + "check Attr type failed"); + + domi::tensorflow::DataType tf_type = attr.type(); + ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type); + + CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, GELOGE(FAILED, "Data type %s of node %s is not supported.", + DataType_Name(tf_type).c_str(), node->name().c_str()); + return PARAM_INVALID); + + op->SrcType(type); + return SUCCESS; +} + +Status ParseContainer(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty. + domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_CONTAINER, attr), + GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", + VAR_ATTR_CONTAINER.c_str(), node->name().c_str()); + return PARAM_INVALID); + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING), + "check Attr s failed"); + + std::string container = attr.s(); + + op->Container(container); + return SUCCESS; +} + +Status ParseSharedName(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty. + domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC( + TensorFlowUtil::FindAttrValue(node, VAR_ATTR_SHARED_NAME, attr), + GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", VAR_ATTR_SHARED_NAME.c_str(), node->name().c_str()); + return PARAM_INVALID); + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING), + "check Attr s failed"); + + std::string shared_name = attr.s(); + op->SharedName(shared_name); + + return SUCCESS; +} + +static Status ParseVarName(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty. + domi::tensorflow::AttrValue attr; + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, ge::VAR_ATTR_NAME, attr), + GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", ge::VAR_ATTR_NAME.c_str(), + node->name().c_str()); return PARAM_INVALID); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING), + "check Attr s failed"); + + std::string var_name = attr.s(); + op->SharedName(var_name); + + return SUCCESS; +} + +static Status InitOutTensor(const vector<int64_t> &shape, int64_t data_type, ge::GeTensorDesc &out_tensor_desc, + ge::Format format) { + out_tensor_desc.SetFormat(format); + + out_tensor_desc.SetDataType((ge::DataType)data_type); + ge::TensorUtils::SetReuseInput(out_tensor_desc, false); + ge::TensorUtils::SetRealDimCnt(out_tensor_desc, shape.size()); + + out_tensor_desc.SetShape(ge::GeShape(shape)); + int64_t size = out_tensor_desc.GetShape().GetShapeSize(); + size *= sizeof(float); + ge::TensorUtils::SetSize(out_tensor_desc, size); + return SUCCESS; +} + +static Status ParseVarShape(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty.
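+ // Parse the recorded output tensor desc, read the "shape" attr into a dim vector, then build the variable's output GeTensorDesc via InitOutTensor.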
+ string node_src_name = node->name(); + domi::tensorflow::AttrValue attr_value; + + if (!TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, attr_value)) { + GELOGE(FAILED, "In NodeDef %s Attr %s does not exist.", node_src_name.c_str(), + ge::ATTR_NAME_OUTPUT_TENSOR_DESC.c_str()); + return FAILED; + } + + ge::GeTensorDesc infer_shape_domi_desc; + domi::tensorflow::AttrValue_ListValue attr_list = attr_value.list(); + int32_t tf_datatype = 0; + GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(infer_shape_domi_desc, attr_list, 0, tf_datatype), + PARAM_INVALID, "parse domi_desc failed."); + + ge::Format src_format = ge::FORMAT_ND; + + CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_SHAPE, attr_value), + GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", VAR_ATTR_SHAPE.c_str(), + node->name().c_str()); return PARAM_INVALID); + + GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE), + "check Attr s failed"); + + const TensorShapeProto &data_shape = attr_value.shape(); + + vector<int64_t> var_dims_v; + for (int32_t i = 0; i < data_shape.dim_size(); i++) { + var_dims_v.push_back(data_shape.dim(i).size()); + } + + op->VarShape(var_dims_v); + + ge::GeTensorDesc out_tensor_desc; + GE_RETURN_WITH_LOG_IF_ERROR(InitOutTensor(var_dims_v, op->GetVarSrcType(), out_tensor_desc, src_format), + "Init Output Tensor failed"); + + op->OutputTensorDesc(out_tensor_desc); + + return SUCCESS; +} + +static void ParsePlacement(const domi::tensorflow::NodeDef *node, VariableOperator *op) { + // The upper caller guarantees input params are not empty. + string node_src_name = node->name(); + domi::tensorflow::AttrValue attr_value; + GELOGI("Start to parse placement, %s", node_src_name.c_str()); + if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_VARIABLE_PLACEMENT, attr_value)) { + std::string placement = attr_value.s(); + op->Placement(placement); + } +} + +Status ParseParams(const Message *op_src, VariableOperator *op) { + GE_CHECK_NOTNULL(op_src); + const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); + GE_CHECK_NOTNULL(node); + GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); + string node_op = node->op(); + if (node_op == TEMPORARYVARIABLE) { + GE_RETURN_IF_ERROR(ParseVarName(node, op)); + } else { + GE_RETURN_IF_ERROR(ParseContainer(node, op)); + GE_RETURN_IF_ERROR(ParseSharedName(node, op)); + } + + GE_RETURN_IF_ERROR(ParseSrcType(node, op)); + GE_RETURN_IF_ERROR(ParseVarShape(node, op)); + ParsePlacement(node, op); + + GELOGD("VariableV2 OP parser params success. op name: %s.", node->name().c_str()); + + return SUCCESS; +} + +DOMI_REGISTER_TENSORFLOW_PARSER(VARIABLE, VariableOperator).SetParseParamsFn(ParseParams); + +DOMI_REGISTER_TENSORFLOW_PARSER(VARHANDLEOP, VariableOperator).SetParseParamsFn(ParseParams); + +DOMI_REGISTER_TENSORFLOW_PARSER(TEMPORARYVARIABLE, VariableOperator).SetParseParamsFn(ParseParams); +} // namespace ge
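For illustration only (not part of the diff above): the FindAttrValue / CheckAttrHasType pattern used by these parsers could be applied to any other string attribute of the sample node_def, such as "_var_format". A minimal sketch under that assumption follows; the helper name ParseVarFormat and its out-parameter are hypothetical and are not registered anywhere.

// Hypothetical helper, shown only to illustrate the attribute-parsing pattern used above.
static Status ParseVarFormat(const domi::tensorflow::NodeDef *node, std::string &var_format) {
  domi::tensorflow::AttrValue attr;
  // Treat a missing attr as an error, mirroring ParseContainer/ParseSharedName above.
  CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, "_var_format", attr),
                   GELOGE(FAILED, "Attr _var_format does not exist in NodeDef %s.", node->name().c_str());
                   return PARAM_INVALID);
  GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING),
                              "check Attr s failed");
  var_format = attr.s();  // e.g. "4D" in the sample VariableV2 node_def
  return SUCCESS;
}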