Browse Source

add parser code

pull/1/MERGE
wqtshg 4 years ago
parent
commit
6743d3bccc
100 changed files with 18385 additions and 0 deletions
  1. +32
    -0
      inc/external/parser/caffe_parser.h
  2. +33
    -0
      inc/external/parser/tensorflow_parser.h
  3. +150
    -0
      parser/CMakeLists.txt
  4. +144
    -0
      parser/caffe/caffe_custom_parser_adapter.cc
  5. +60
    -0
      parser/caffe/caffe_custom_parser_adapter.h
  6. +160
    -0
      parser/caffe/caffe_data_parser.cc
  7. +57
    -0
      parser/caffe/caffe_data_parser.h
  8. +187
    -0
      parser/caffe/caffe_op_parser.cc
  9. +120
    -0
      parser/caffe/caffe_op_parser.h
  10. +2486
    -0
      parser/caffe/caffe_parser.cc
  11. +433
    -0
      parser/caffe/caffe_parser.h
  12. +143
    -0
      parser/caffe/caffe_reshape_parser.cc
  13. +59
    -0
      parser/caffe/caffe_reshape_parser.h
  14. +1821
    -0
      parser/caffe/proto/caffe/caffe.proto
  15. +190
    -0
      parser/caffe/proto/ge_ir.proto
  16. +396
    -0
      parser/caffe/proto/om.proto
  17. +165
    -0
      parser/caffe/proto/task.proto
  18. +76
    -0
      parser/common/CMakeLists.txt
  19. +492
    -0
      parser/common/acl_graph_parser_util.cc
  20. +161
    -0
      parser/common/acl_graph_parser_util.h
  21. +248
    -0
      parser/common/convert/pb2json.cc
  22. +68
    -0
      parser/common/convert/pb2json.h
  23. +212
    -0
      parser/common/data_op_parser.cc
  24. +109
    -0
      parser/common/data_op_parser.h
  25. +155
    -0
      parser/common/model_saver.cc
  26. +55
    -0
      parser/common/model_saver.h
  27. +95
    -0
      parser/common/module.mk
  28. +38
    -0
      parser/common/op_def/arg_op.cc
  29. +36
    -0
      parser/common/op_def/arg_op.h
  30. +45
    -0
      parser/common/op_def/constant_op.cc
  31. +37
    -0
      parser/common/op_def/constant_op.h
  32. +712
    -0
      parser/common/op_def/defs.cc
  33. +45
    -0
      parser/common/op_def/fill_op.cc
  34. +42
    -0
      parser/common/op_def/fill_op.h
  35. +74
    -0
      parser/common/op_def/frameworkop_op.cc
  36. +49
    -0
      parser/common/op_def/frameworkop_op.h
  37. +205
    -0
      parser/common/op_def/ir_pb_converter.cc
  38. +36
    -0
      parser/common/op_def/ir_pb_converter.h
  39. +30
    -0
      parser/common/op_def/no_op_op.cc
  40. +33
    -0
      parser/common/op_def/no_op_op.h
  41. +215
    -0
      parser/common/op_def/op_schema.cc
  42. +175
    -0
      parser/common/op_def/op_schema.h
  43. +200
    -0
      parser/common/op_def/operator.cc
  44. +117
    -0
      parser/common/op_def/operator.h
  45. +34
    -0
      parser/common/op_def/ref_switch_op.cc
  46. +34
    -0
      parser/common/op_def/ref_switch_op.h
  47. +56
    -0
      parser/common/op_def/shape_n_op.cc
  48. +40
    -0
      parser/common/op_def/shape_n_op.h
  49. +37
    -0
      parser/common/op_def/var_is_initialized_op_op.cc
  50. +34
    -0
      parser/common/op_def/var_is_initialized_op_op.h
  51. +57
    -0
      parser/common/op_def/variable_op.cc
  52. +46
    -0
      parser/common/op_def/variable_op.h
  53. +159
    -0
      parser/common/op_map.cc
  54. +45
    -0
      parser/common/op_map.h
  55. +117
    -0
      parser/common/op_parser_factory.cc
  56. +198
    -0
      parser/common/op_parser_factory.h
  57. +76
    -0
      parser/common/parser_api.cc
  58. +81
    -0
      parser/common/parser_factory.cc
  59. +1270
    -0
      parser/common/parser_fp16_t.cc
  60. +653
    -0
      parser/common/parser_fp16_t.h
  61. +24
    -0
      parser/common/parser_inner_ctx.cc
  62. +494
    -0
      parser/common/parser_types.cc
  63. +83
    -0
      parser/common/pass_manager.cc
  64. +76
    -0
      parser/common/pass_manager.h
  65. +287
    -0
      parser/common/pre_checker.cc
  66. +194
    -0
      parser/common/pre_checker.h
  67. +190
    -0
      parser/common/proto/ge_ir.proto
  68. +136
    -0
      parser/common/proto/insert_op.proto
  69. +396
    -0
      parser/common/proto/om.proto
  70. +62
    -0
      parser/common/proto/tensorflow/attr_value.proto
  71. +100
    -0
      parser/common/proto/tensorflow/function.proto
  72. +56
    -0
      parser/common/proto/tensorflow/graph.proto
  73. +63
    -0
      parser/common/proto/tensorflow/node_def.proto
  74. +164
    -0
      parser/common/proto/tensorflow/op_def.proto
  75. +29
    -0
      parser/common/proto/tensorflow/resource_handle.proto
  76. +94
    -0
      parser/common/proto/tensorflow/tensor.proto
  77. +45
    -0
      parser/common/proto/tensorflow/tensor_shape.proto
  78. +74
    -0
      parser/common/proto/tensorflow/types.proto
  79. +31
    -0
      parser/common/proto/tensorflow/versions.proto
  80. +528
    -0
      parser/common/proto_file_parser.cc
  81. +63
    -0
      parser/common/proto_file_parser.h
  82. +132
    -0
      parser/common/register_tbe.cc
  83. +34
    -0
      parser/common/register_tbe.h
  84. +212
    -0
      parser/common/tbe_plugin_loader.cc
  85. +62
    -0
      parser/common/tbe_plugin_loader.h
  86. +78
    -0
      parser/common/thread_pool.cc
  87. +83
    -0
      parser/common/thread_pool.h
  88. +307
    -0
      parser/common/tuple.h
  89. +53
    -0
      parser/common/types_map.h
  90. +32
    -0
      parser/func_to_graph/CMakeLists.txt
  91. +279
    -0
      parser/func_to_graph/func2graph.py
  92. +9
    -0
      parser/func_to_graph/module.mk
  93. +62
    -0
      parser/func_to_graph/proto/attr_value.proto
  94. +100
    -0
      parser/func_to_graph/proto/function.proto
  95. +56
    -0
      parser/func_to_graph/proto/graph.proto
  96. +14
    -0
      parser/func_to_graph/proto/graph_library.proto
  97. +63
    -0
      parser/func_to_graph/proto/node_def.proto
  98. +164
    -0
      parser/func_to_graph/proto/op_def.proto
  99. +29
    -0
      parser/func_to_graph/proto/resource_handle.proto
  100. +94
    -0
      parser/func_to_graph/proto/tensor.proto

+ 32
- 0
inc/external/parser/caffe_parser.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_
#define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_

#include <memory>
#include <string>
#include <vector>

#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"

namespace ge {
graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_

+ 33
- 0
inc/external/parser/tensorflow_parser.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_
#define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>

#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/graph.h"

namespace ge {
graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph);
} // namespace ge

#endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_

+ 150
- 0
parser/CMakeLists.txt View File

@@ -0,0 +1,150 @@
set(PROTO_LIST
"${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto"
)

set(SRC_LIST
"tensorflow/tensorflow_arg_parser.cc"
"tensorflow/tensorflow_auto_mapping_parser_adapter.cc"
"tensorflow/tensorflow_constant_parser.cc"
"tensorflow/tensorflow_data_parser.cc"
"tensorflow/tensorflow_enter_parser.cc"
"tensorflow/tensorflow_fill_parser.cc"
"tensorflow/tensorflow_frameworkop_parser.cc"
"tensorflow/tensorflow_fusionop_util.cc"
"tensorflow/tensorflow_identity_parser.cc"
"tensorflow/tensorflow_merge_parser.cc"
"tensorflow/tensorflow_no_op_parser.cc"
"tensorflow/tensorflow_parser.cc"
"tensorflow/tensorflow_ref_switch_parser.cc"
"tensorflow/tensorflow_reshape_parser.cc"
"tensorflow/tensorflow_shape_n_parser.cc"
"tensorflow/tensorflow_squeeze_parser.cc"
"tensorflow/tensorflow_var_is_initialized_op_parser.cc"
"tensorflow/tensorflow_variable_v2_parser.cc"
"caffe/caffe_parser.cc"
"caffe/caffe_data_parser.cc"
"caffe/caffe_reshape_parser.cc"
"caffe/caffe_custom_parser_adapter.cc"
"caffe/caffe_op_parser.cc"
"tensorflow/scope/scope_pass_manager.cc"
"tensorflow/graph_functiondef.cc"
"tensorflow/graph_optimizer.cc"
"tensorflow/iterator_fusion_pass.cc"
"common/op_def/arg_op.cc"
"common/op_def/constant_op.cc"
"common/op_def/fill_op.cc"
"common/op_def/frameworkop_op.cc"
"common/op_def/no_op_op.cc"
"common/op_def/ref_switch_op.cc"
"common/op_def/shape_n_op.cc"
"common/op_def/var_is_initialized_op_op.cc"
"common/op_def/variable_op.cc"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})

############ libfmk_parser.so ############
add_library(fmk_parser SHARED ${SRC_LIST} ${PROTO_SRCS})

target_compile_options(fmk_parser PRIVATE
-Werror
)

target_compile_definitions(fmk_parser PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
)

target_include_directories(fmk_parser PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${TOP_DIR}/framework/domi
${TOP_DIR}/framework/domi/common
${TOP_DIR}/framework/domi/parser
${TOP_DIR}/inc
${TOP_DIR}/inc/external
${TOP_DIR}/inc/external/parser
${TOP_DIR}/inc/external/graph
${TOP_DIR}/inc/framework
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
)

target_link_libraries(fmk_parser
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
protobuf
error_manager
parser_common
graph
register
_caffe_parser
c_sec
slog
mmpa
-Wl,--as-needed
json
-lrt
)

##################################################################
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc
COMMAND echo "Generating stub files."
&& ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
&& mv tensorflow_parser.cc stub_tensorflow_parser.cc
&& mv caffe_parser.cc stub_caffe_parser.cc
&& echo "Generating stub files end."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ../stub/gen_stubapi.py ${TOP_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR}
)

##################################################################

############ stub/libfmk_parser.so ############
add_library(fmk_parser_stub SHARED
${CMAKE_CURRENT_BINARY_DIR}/stub_tensorflow_parser.cc
${CMAKE_CURRENT_BINARY_DIR}/stub_caffe_parser.cc
)

target_compile_options(fmk_parser_stub PRIVATE
-O2
)

target_compile_definitions(fmk_parser_stub PRIVATE
$<$<STREQUAL:${PRODUCT_SIDE},host>:FMK_SUPPORT_DUMP>
PROTOBUF_INLINE_NOT_IN_HEADERS=0
REUSE_MEMORY=1
FMK_HOST_INFER
)

target_include_directories(fmk_parser_stub PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${TOP_DIR}/inc
${TOP_DIR}/inc/external
${TOP_DIR}/inc/external/parser
${TOP_DIR}/inc/external/graph
${TOP_DIR}/inc/framework
${CMAKE_BINARY_DIR}
${CMAKE_CURRENT_BINARY_DIR}
)

target_link_libraries(fmk_parser_stub PRIVATE
$<BUILD_INTERFACE:intf_pub>
)

set_target_properties(fmk_parser_stub PROPERTIES
OUTPUT_NAME fmk_parser
LIBRARY_OUTPUT_DIRECTORY stub
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS fmk_parser OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

install(TARGETS fmk_parser_stub OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/stub
)

+ 144
- 0
parser/caffe/caffe_custom_parser_adapter.cc View File

@@ -0,0 +1,144 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/caffe/caffe_custom_parser_adapter.h"
#include <memory>
#include <vector>
#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/omg_inner_types.h"
#include "framework/omg/parser/parser_types.h"
#include "graph/utils/graph_utils.h"
#include "parser/common/op_parser_factory.h"
#include "register/op_registry.h"

using domi::ParseParamByOpFunc;
using domi::ParseParamFunc;
using std::vector;

namespace ge {
namespace {
const char *const kConvolution = "Convolution";
const char *const kInnerProduct = "InnerProduct";
const int64_t kDimDedaultValue = 1;
const int kBlobIndexOne = 1;
} // namespace

Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src);
const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src);
GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str());
GE_CHECK_NOTNULL(op_dest);

ParseParamFunc customOpParser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest->GetType(), layer->type());
GE_CHECK_NOTNULL(customOpParser);

op_dest->SetName(layer->name());
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
GE_CHK_BOOL_RET_STATUS(customOpParser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed");
return SUCCESS;
}

Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) {
GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(),
op_src.GetOpType().c_str());
GE_CHECK_NOTNULL(op_dest);

ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType());
GE_CHECK_NOTNULL(custom_op_parser);

op_dest->SetName(op_src.GetName());
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);

GE_CHK_BOOL_RET_STATUS(custom_op_parser(op_src, op) == SUCCESS, FAILED, "Custom parser params failed");
return SUCCESS;
}

Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto op = node->GetOpDesc();
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src);

GE_CHK_BOOL_RET_STATUS(nullptr != layer, FAILED, "Dynamic cast op_src to LayerParameter failed");
GELOGI("layer: %s blobs_size: %d bottom_size: %d", layer->name().c_str(), layer->blobs_size(), layer->bottom_size());
if (layer->blobs_size() == 0) {
return SUCCESS;
}

bool bias_en = false;
int start_pos = layer->bottom_size();
for (int i = 0; i < layer->blobs_size(); ++i) {
ge::GeTensorPtr weight = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(weight);
GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i,
layer->name().c_str());
GE_IF_BOOL_EXEC(layer->type() == kConvolution && i == kBlobIndexOne,
const ConvolutionParameter &conv_params_src = layer->convolution_param();
bias_en = conv_params_src.bias_term(););
GE_IF_BOOL_EXEC(layer->type() == kInnerProduct && i == kBlobIndexOne,
const InnerProductParameter &fc_params_src = layer->inner_product_param();
bias_en = fc_params_src.bias_term(););
auto bias_shape = weight->MutableTensorDesc().GetShape();
// The num 0, 1, 2, 3 represet the dim index.
bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) &&
bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1;
if (matched) {
weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)}));
}
matched = layer->type() == kInnerProduct && i == 0 &&
bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) &&
bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1;
if (matched) {
weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)}));
}

// construct const node
auto const_opdesc = ge::OpDescUtils::CreateConstOp(weight); // use org weight before SetWeights Overwrite
GE_CHECK_NOTNULL(const_opdesc);
auto owner_graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(owner_graph);

// add edge from const to current node
auto const_node = owner_graph->AddNodeFront(const_opdesc);
GE_CHECK_NOTNULL(const_node);
auto index = start_pos + i;
auto valid_input_name = op->GetValidInputNameByIndex(static_cast<uint32_t>(index));
if (valid_input_name.empty()) {
if (node->AddLinkFrom(static_cast<const uint32_t &>(index), const_node) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %d", const_node->GetName().c_str(),
node->GetName().c_str(), index);
}
} else {
if (node->AddLinkFrom(valid_input_name, const_node) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "AddEdge failed of from Node %s output to Node %s input %s", const_node->GetName().c_str(),
node->GetName().c_str(), valid_input_name.c_str());
}
}

std::vector<ge::NodePtr> original_nodes;
ge::GraphUtils::RecordOriginalNames(original_nodes, const_node);
}
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, "tvm_origin_input_num", layer->bottom_size())),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return

return SUCCESS;
}
REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(CAFFE, CaffeCustomParserAdapter);
} // namespace ge

+ 60
- 0
parser/caffe/caffe_custom_parser_adapter.h View File

@@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_
#define PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_

#include "parser/caffe/caffe_op_parser.h"

namespace ge {
class CaffeCustomParserAdapter : public CaffeOpParser {
public:
/**
* @ingroup domi_omg
* @brief parse params of the operation
* @param [in] op_src params to be parsed
* @param [out] op_dest params after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
* @author
*/
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;

/**
* @ingroup domi_omg
* @brief parse params of the operation
* @param [in] op_src params to be parsed
* @param [out] op_dest params after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
* @author
*/
Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest);

/**
* @ingroup domi_omg
* @brief parse weight of the operation
* @param [in] op_src params to be parsed
* @param [out] node params after parsing
* @return SUCCESS parse successfullyparse failed
* @return FAILED
* @author
*/
Status ParseWeights(const Message *op_src, ge::NodePtr &node) override;
};
} // namespace ge

#endif // PARSER_CAFFE_CAFFE_CUSTOM_PARSER_ADAPTER_H_

+ 160
- 0
parser/caffe/caffe_data_parser.cc View File

@@ -0,0 +1,160 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/caffe/caffe_data_parser.h"
#include <unordered_map>
#include <utility>
#include "common/debug/log.h"
#include "framework/omg/parser/parser_types.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "parser/common/op_parser_factory.h"

using namespace ge::parser;

namespace ge {
Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims,
ge::OpDescPtr &op) {
GE_CHECK_NOTNULL(op);
GELOGI("The input dim size is %zu in layer %s.", input_dims.size(), name.c_str());

// Caffe default data type is float32
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, DATA_ATTR_NAME_DATA_TYPE, ge::DT_FLOAT)),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return

// Initialize input and output description of OP according to input_dims information
GE_RETURN_WITH_LOG_IF_ERROR(ParseShape(input_dims, op), "data layer %s ParseShape failed", name.c_str());

return SUCCESS;
}

Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
const domi::caffe::LayerParameter *layer = DOMI_DYNAMIC_CAST<const domi::caffe::LayerParameter *>(op_src);
GE_CHECK_NOTNULL(layer);
GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str());

if (layer->type() == ge::parser::INPUT_TYPE) {
GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed",
layer->name().c_str(), layer->type().c_str());
} else if(layer->type() == ge::parser::DUMMY_DATA) {
GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed",
layer->name().c_str(), layer->type().c_str());
} else {
ErrorManager::GetInstance().ATCReportErrMessage("E11030");
GELOGE(PARAM_INVALID, "Caffe prototxt has no optype [Input]");
return FAILED;
}
return SUCCESS;
}

Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) {
if (layer->has_input_param()) {
const domi::caffe::InputParameter &input_param = layer->input_param();
if (input_param.shape_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11027", {"layername", "layertype"}, {layer->name(), layer->type()});
GELOGE(PARAM_INVALID,
"input_param shape size is zero, caffe layer name [%s], layer type [%s].",
layer->name().c_str(), layer->type().c_str());
return FAILED;
}
for (int i = 0; i < input_param.shape_size(); i++) {
const domi::caffe::BlobShape &blob_shape = input_param.shape(i);
vector<int64_t> shape;
unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
}
string name = layer->name();
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name));
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s",
name.c_str());
}
} else {
// Get from external input
const ge::ParserContext &ctx = GetParserContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name();
auto search = input_dims.find(name);
if (search == input_dims.end()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11028", {"layername", "layertype"}, {layer->name(), layer->type()});
GELOGE(PARAM_INVALID,
"Caffe prototxt has no input_param or user should set --input_shape in atc parameter, "
"caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str());
return FAILED;
}
std::vector<int64_t> dims = search->second;
GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.",
name.c_str());
}
return SUCCESS;
}

Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) {
if (layer->has_dummy_data_param()) {
const domi::caffe::DummyDataParameter &dummy_data_param = layer->dummy_data_param();
if (dummy_data_param.shape_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11027", {"layername", "layertype"}, {layer->name(), layer->type()});
GELOGE(PARAM_INVALID,
"input_param shape size is zero, caffe layer name [%s], layer type [%s].",
layer->name().c_str(), layer->type().c_str());
return FAILED;
}
for (int i = 0; i < dummy_data_param.shape_size(); i++) {
const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i);

vector<int64_t> shape;
unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims;
std::vector<int64_t> model_dims;
for (auto &blob_shape_dim_temp : blob_shape.dim()) {
model_dims.push_back(blob_shape_dim_temp);
}

string name = layer->name();
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name));
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s",
name.c_str());
}
} else {
// Get from external input
const ge::ParserContext &ctx = GetParserContext();
std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims;
string name = layer->name();
auto search = input_dims.find(name);
if (search == input_dims.end()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11028", {"layername", "layertype"}, {layer->name(), layer->type()});
GELOGE(PARAM_INVALID,
"Caffe prototxt has no input_param or user should set --input_shape in atc parameter, "
"caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str());
return FAILED;
}
std::vector<int64_t> dims = search->second;
GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), "Get output desc failed in layer %s.",
name.c_str());
}
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(CAFFE, DATA, CaffeDataParser);
} // namespace ge

+ 57
- 0
parser/caffe/caffe_data_parser.h View File

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_CAFFE_CAFFE_DATA_PARSER_H_
#define PARSER_CAFFE_CAFFE_DATA_PARSER_H_

#include <string>
#include <vector>
#include "parser/caffe/caffe_op_parser.h"
#include "parser/common/data_op_parser.h"

namespace ge {
class CaffeDataParser : public CaffeOpParser, public DataOpParser {
public:
/**
* @ingroup domi_omg
* @brief parse params of the operation
* @param [in] op_src params to be parsed
* @param [out] graph params after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;

private:
/**
* @ingroup domi_omg
* @brief Get the output dimension according to the input dimension
* @param [in] name the name of the input layer
* @param [in] input_dims the dimension of the input layer
* @param [out] op_def op after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status GetOutputDesc(const std::string &name, int dim_size,
const std::vector<int64_t> &input_dims, ge::OpDescPtr &op);

// caffe data layer type could be type of `Input` or `DummyData`
Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op);
Status ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op);
};
} // namespace ge

#endif // PARSER_CAFFE_CAFFE_DATA_PARSER_H_

+ 187
- 0
parser/caffe/caffe_op_parser.cc View File

@@ -0,0 +1,187 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/caffe/caffe_op_parser.h"
#include <memory>
#include "parser/common/op_parser_factory.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/omg/parser/parser_types.h"

using namespace ge::parser;

using domi::CAFFE;

namespace ge {
Status CaffeOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { return SUCCESS; }

Status CaffeOpParser::ParseWeights(const Message *op_src, ge::NodePtr &node) { return SUCCESS; }

Status CaffeOpParser::AddConstInput(ge::NodePtr &node) { return SUCCESS; }

void CaffeOpParser::ConvertShape(const BlobProto &proto, std::vector<int64_t> &shape) {
shape.clear();

if (proto.has_num() || proto.has_channels() || proto.has_height() || proto.has_width()) {
// Compatible with old formats, shape description: (num, channels, height, width)
shape.push_back(proto.num());
shape.push_back(proto.channels());
shape.push_back(proto.height());
shape.push_back(proto.width());
} else {
// The shape of the new format is described with "repeated Int64 dim"
for (int i = 0; i < proto.shape().dim_size(); ++i) {
shape.push_back(proto.shape().dim(i));
}
}
}

Status CaffeOpParser::ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight) {
GE_CHECK_NOTNULL(weight);
std::vector<int64_t> shape_vec;
ConvertShape(proto, shape_vec);
ge::GeShape shape(shape_vec);
// Calculate the number of data in weight
int count = 1;
for (size_t i = 0; i < shape.GetDimNum(); ++i) {
int dim = shape.GetDim(i);
if (dim <= 0) {
GELOGE(FAILED, "Convert weight fail, Blob size invalid");
return FAILED;
}

if (dim >= INT64_MAX / count) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(dim) + "*" + std::to_string(count),
"it exceeds INT64_MAX[" + std::to_string(INT64_MAX) + "]"});
GELOGE(FAILED, "Convert weight fail, Blob size exceeds INT64_MAX, dim:%d, count:%d", dim, count);
return FAILED;
}

count *= dim;
}
return ParseWeightType(proto, shape, count, lay_name, weight);
}

Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, int size,
const string &lay_name, ge::GeTensorPtr &weight) {
// Extract weight data and store it in weightdef by float type
GE_CHECK_NOTNULL(weight);
ge::DataType dtype = ge::DT_FLOAT;
if (proto.double_data_size() > 0) {
// Convert by double type
if (size != proto.double_data_size()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(proto.double_data_size()),
"it does not match shape size[" + std::to_string(size) + "]"});
GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob size:%d", size,
proto.double_data_size());
return FAILED;
}
std::unique_ptr<float[]> buf(new (std::nothrow) float[size]());
GE_CHECK_NOTNULL(buf);
for (int i = 0; i < size; ++i) {
buf[i] = proto.double_data(i);
}
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return
} else if (proto.int8_data().length() > 0) {
if (size != static_cast<int>(proto.int8_data().length())) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(proto.int8_data().length()),
"it does not match shape size[" + std::to_string(size) + "]"});
GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%ld", size,
proto.int8_data().length());
return FAILED;
}
const char *data_ptr = proto.int8_data().data();
GE_CHECK_NOTNULL(data_ptr);
GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_INT8;
} else if (proto.int32_data_size() > 0) {
if (size != proto.int32_data_size()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(proto.int32_data_size()),
"it does not match shape size[" + std::to_string(size) + "]"});
GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size,
proto.int32_data_size());
return FAILED;
}
std::unique_ptr<int32_t[]> int32_weight_buf(new (std::nothrow) int32_t[size]());
GE_CHECK_NOTNULL(int32_weight_buf);
for (int i = 0; i < size; ++i) {
int32_weight_buf[i] = proto.int32_data(i);
}
GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<uint8_t *>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_INT32;
} else if (proto.uint64_data_size() > 0) {
if (size != proto.uint64_data_size()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(proto.uint64_data_size()),
"it does not match shape size[" + std::to_string(size) + "]"});
GELOGE(FAILED, "Convert weight failed, Blob size does not match shape size, shape size:%d, blob size:%d", size,
proto.uint64_data_size());
return FAILED;
}
std::unique_ptr<uint64_t[]> uint64_weight_buf(new (std::nothrow) uint64_t[size]());
GE_CHECK_NOTNULL(uint64_weight_buf);
for (int i = 0; i < size; ++i) {
uint64_weight_buf[i] = proto.uint64_data(i);
}
GE_IF_BOOL_EXEC(weight->SetData(reinterpret_cast<uint8_t *>(uint64_weight_buf.get()), size * sizeof(uint64_t)) !=
ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return
dtype = ge::DT_UINT64;
} else {
// Convert by float type
if (size != proto.data_size()) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E11033", {"opname", "blobsize", "reason"},
{lay_name, std::to_string(proto.data_size()),
"it does not match shape size[" + std::to_string(size) + "]"});
GELOGE(FAILED, "Convert weight fail, Blob size does not match shape size, shape size:%d, blob.data_size:%d", size,
proto.data_size());
return FAILED;
}
const float *data_ptr = proto.data().data();
GE_CHECK_NOTNULL(data_ptr);
GE_IF_BOOL_EXEC(
weight->SetData(reinterpret_cast<const uint8_t *>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return
}
ge::GeTensorDesc weight_desc = ge::GeTensorDesc();
weight_desc.Update(shape, ge::FORMAT_NCHW, dtype);
weight->SetTensorDesc(weight_desc);
return SUCCESS;
}

// Dropout's corresponding op_parser is registered as caffeopparser, optimized in optimization stage.
REGISTER_OP_PARSER_CREATOR(CAFFE, DROPOUT, CaffeOpParser);

// A new operator added by framework in OM model is used to
// collect and arrange all outputs in the order of the original model's output
// Net output operator does not need special processing in the parse stage,
// and directly registers in the op_parser file
REGISTER_OP_PARSER_CREATOR(CAFFE, NETOUTPUT, CaffeOpParser);
} // namespace ge

+ 120
- 0
parser/caffe/caffe_op_parser.h View File

@@ -0,0 +1,120 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_CAFFE_CAFFE_OP_PARSER_H_
#define PARSER_CAFFE_CAFFE_OP_PARSER_H_

#include <vector>
#include "graph/debug/ge_attr_define.h"
#include "common/util.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "omg/parser/op_parser.h"
#include "proto/caffe/caffe.pb.h"

using domi::caffe::ArgMaxParameter;
using domi::caffe::BatchNormParameter;
using domi::caffe::BlobProto;
using domi::caffe::BlobShape;
using domi::caffe::ConcatParameter;
using domi::caffe::ConvolutionParameter;
using domi::caffe::DetectionOutputParameter;
using domi::caffe::EltwiseParameter;
using domi::caffe::FillerParameter;
using domi::caffe::InnerProductParameter;
using domi::caffe::LayerParameter;
using domi::caffe::PoolingParameter;
using domi::caffe::PReLUParameter;
using domi::caffe::ReshapeParameter;
using domi::caffe::ROIAlignParameter;
using domi::caffe::TanHParameter;
using domi::caffe::UpsampleParameter;

namespace ge {
/**
* @ingroup ge_omg
* @brief Used to parse Caffe operator information
*/
class CaffeOpParser : public OpParser {
public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;

Status ParseParams(const Message *op_src, ge::Operator &op_dest) override {
return domi::SUCCESS;
}

/**
* @ingroup ge_omg
* @brief parse weight information of the operation
* @param [in] op_src Weight data to be parsed
* @param [out] graph Weight data after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
* @author
*/
Status ParseWeights(const Message *op_src, ge::NodePtr &node) override;

/**
* @ingroup ge_omg
* @brief add const input node
* @param [in] node to add const input
* @param [out] node after add const input
* @return SUCCESS add const input successfully
* @return FAILED add const input failed
* @author
*/
virtual Status AddConstInput(ge::NodePtr &node);

protected:
/**
* @ingroup ge_omg
* @brief Convert blob proto to weight definition
* @param [in] proto Weight data to be parsed
* @param [out] weight Weight data after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
static Status ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight);

/**
* @ingroup ge_omg
* @brief Convert blob proto to shape definition
* @param [in] proto Shape information before conversion
* @param [out] shape Save converted shape information
*/
static void ConvertShape(const BlobProto &proto, std::vector<int64_t> &shape);

private:
/**
* @ingroup ge_omg
* @brief Convert blob proto to weight definition
* @param [in] proto Weight data to be parsed
* @param [out] weight Weight data after parsing
* @return SUCCESS parse weight type successfully
* @return FAILED parse failed
*/
static Status ParseWeightType(const BlobProto &proto, const ge::GeShape &shape,
int size, const string &lay_name, ge::GeTensorPtr &weight);
};
} // namespace ge

#endif // PARSER_CAFFE_CAFFE_OP_PARSER_H_

+ 2486
- 0
parser/caffe/caffe_parser.cc
File diff suppressed because it is too large
View File


+ 433
- 0
parser/caffe/caffe_parser.h View File

@@ -0,0 +1,433 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_CAFFE_CAFFE_PARSER_H_
#define PARSER_CAFFE_CAFFE_PARSER_H_

#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "external/graph/operator.h"
#include "omg/parser/op_parser.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/weights_parser.h"
#include "proto/caffe/caffe.pb.h"
#include "proto/om.pb.h"

namespace ge {
using domi::caffe::NetParameter;
using std::map;
using std::set;
using std::string;
using std::unordered_map;
using std::vector;
static std::map<std::vector<std::string>, std::vector<std::string>> params_share_map;

class CaffeModelParser : public domi::ModelParser {
public:
CaffeModelParser() {}
virtual ~CaffeModelParser() {}

/**
* @ingroup domi_omg
* @brief Parse the relevant data from the model file and save it to graph
* @param [in] file Path of model file
* @param [in|out] graph graph for saving model information
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status Parse(const char *file, ge::Graph &graph) override;
Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;

/**
* @ingroup domi_omg
* @brief Convert model files to JSON format
* @param [in] model_file Path of model file
* @param [out] json_file Converted JSON file path
* @return SUCCESS parse successfully
* @return others parse failed
*/
Status ToJson(const char *model_file, const char *json_file) override;
/**
* @ingroup domi_omg
* @brief Parse the relevant data from the model file and save it to graph
* @param [in] graph_def input tensorflow model
* @param [in|out] graph graph for saving model information
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override;
Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, domi::GetGraphCallback callback,
ge::ComputeGraphPtr &graph) override;
/*
* @ingroup domi_omg
* @brief Mapping CAFFE's datatype to GE's datatype
* @param [in] type, datatype types of operators in CAFFE networks
* @return ge::DataType
*/
ge::DataType ConvertToGeDataType(const uint32_t type) override { return ge::DT_FLOAT; }

Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override {
return domi::SUCCESS;
}

private:
Status Parse(const char *file, ge::ComputeGraphPtr &graph);

/**
* @ingroup domi_omg
* @brief Add the Layer in the model to the PreChecker
* @param [in] net caffe net information
* @return SUCCESS build successfully
* @return FAILED build failed
*/
Status PreCheck(const domi::caffe::NetParameter &net);

/**
* @ingroup domi_omg
* @brief Parsing input related information from model files
* @param [in] proto_message caffe net information
* @param [in|out] net_input_name Used to store the acquired input name information
* @param [in|out] net_input_data Used to store the acquired input data information
* @return SUCCESS build successfully
* @return FAILED build failed
*/
Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag);

/*
* @ingroup domi_omg
* @brief Parse model by custom proto and save info to operators
* @param [in] model_path, file path of model(prototxt file)
* @param [in] custom_proto, file path of custom proto
* @param [in] caffe_proto, file path of caffe proto
* @param [out] operators, operators saving custom info
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto,
std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Parse model by custom proto and save info to operators
* @param [in] model_path, file path of model(prototxt file)
* @param [in] custom_proto_path, file path of custom proto
* @param [in] custom_proto_name, custom proto name
* @param [out] operators, operators saving custom info
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path,
const string &custom_proto_name, std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Parse caffe proto file
* @param [in] proto_file, file path of caffe proto
* @param [out] identifier_op_map, identifer and op map
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseProtoFile(const string &proto_file, std::map<int32_t, string> &identifier_op_map);

/*
* @ingroup domi_omg
* @brief Save identifier op map info
* @param [in] line, line of proto
* @param [out] identifier_op_map, identifer and op map
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status SaveIdentifierOpMapInfo(const string &line, std::map<int32_t, string> &identifier_op_map);

/*
* @ingroup domi_omg
* @brief Get op identifier
* @param [in] line, line of proto
* @param [out] identifier, identifer of op
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status GetIdentifier(const std::string &line, int32_t &identifier);
/*
* @ingroup domi_omg
* @brief Read caffe model and shield google warning
* @param [in] model_path, file path of model(prototxt file)
* @param [out] message, message saving custom info
* @return SUCCESS read file successfully
* @return FAILED read file failed
*/
Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message);

/*
* @ingroup domi_omg
* @brief Read caffe model and save it to message
* @param [in] model_path, file path of model(prototxt file)
* @param [out] message, message saving custom info
* @return SUCCESS read file successfully
* @return FAILED read file failed
*/
Status ReadCaffeModelFromText(const char *model_path, google::protobuf::Message *message);

/*
* @ingroup domi_omg
* @brief Parse layer message and save custom info to operators
* @param [in] layer_descriptor, layer description of message
* @param [in] message, message of model
* @param [out] operators, operators saving custom info
* @return SUCCESS parse layer successfully
* @return FAILED parse layer failed
*/
Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message, std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Create custom operator by op_name and op_type
* @param [in] op_name, name of operator
* @param [in] op_type, type of operator
* @param [in] message, message of model
* @param [in] index, index of field
* @param [out] operators, operators saving custom info
* @return SUCCESS create operator successfully
* @return FAILED create operator failed
*/
Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message,
int index, std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Parse message and set operator attrs
* @param [in] message, message of model
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse message successfully
* @return FAILED parse message failed
*/
Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse repeated field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info by vector
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/**
* @ingroup domi_omg
* @brief Add blob information to the bottom_blobs_map and top_blobs_map_
* @param [in] layer layer information
* @param [in|out] inplace_blob_name_remapping save blob information
* @return Status
*/
Status AddBlobsToMap(const domi::caffe::LayerParameter &layer,
std::map<std::string, std::string> &inplace_blob_name_remapping);
/**
* @ingroup domi_omg
* @brief Add node information to graph
* @param [in] layer layer infromation
* @param [in|out] graph graph for saving model information
* @return SUCCESS add successfully
* @return FAILED add failed
*/
Status AddNode(const domi::caffe::LayerParameter &layer, ge::ComputeGraphPtr &graph);
/**
* @ingroup domi_omg
* @brief Add edge information to graph
* @param [in|out] graph graph for saving model information
* @return SUCCESS add successfully
* @return FAILED add failed
*/
Status AddEdges(ge::ComputeGraphPtr &graph);

/**
* @ingroup domi_omg
* @brief Add edge information to graph
* @param [in|out] graph graph for saving model information
* @return SUCCESS add successfully
* @return FAILED add failed
*/
Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph);

/**
* @ingroup domi_omg
* @brief Check if the current layer is valid
* @return true valid
* @return false invalid
*/
bool CheckValidLayer(const domi::caffe::LayerParameter &layer);

/**
* @ingroup domi_omg
* @brief Check whether the top of the current layer is 'Inplace'
* @return true is 'Inplace'
* @return false not is 'Inplace'
*/
bool IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name);

/**
* @ingroup domi_omg
* @brief Check whether the top of the current layer is user's specified output top
* @return true yes
* @return false no
*/
bool IsOutputTop(const string &op_name, int32_t index);

/**
* @ingroup domi_omg
* @brief Find a layer set with the same param
* @param [in] Param name set of each layer
* @param [in|out] Layer set of the same param
* @return Status
*/
Status FindShareParamLayers(const std::map<std::string, std::vector<std::string>> &);

Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer);

Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
const string &op_type);

Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph);

std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index);

Status GetCustomOp(const domi::caffe::LayerParameter &layer, vector<ge::Operator> &operators);

bool IsOpAttrEmpty(const ge::Operator &op, const std::string &type);

Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op,
std::shared_ptr<ge::OpParser> &op_parser);

Status GetLeafNodeTops(ge::ComputeGraphPtr &graph);

void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer);

Status ReorderInput(domi::caffe::NetParameter &net);

void AddOutputInfoToContext(string layer_name, int32_t top_index);

Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message);

std::map<std::string, ge::NodePtr> node_map;

// key: blob name, value: layer name and index
std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> bottom_blobs_map_;

// key: blob name, value: layer name and index
std::unordered_map<std::string, std::vector<std::pair<std::string, int32_t>>> top_blobs_map_;

std::vector<ge::Operator> custom_operator_;
std::map<std::string, std::vector<std::string>> layer_tops_map_;
};

/**
* @ingroup domi_omg
* @brief Caffe weight parser
*/
class CaffeWeightsParser : public domi::WeightsParser {
public:
/**
* @ingroup domi_omg
* @brief Parse weight data from file and save to graph
* @param [in] file Path of weight file after training
* @param [in|out] graph Save weight information after parsing
* @return SUCCESS parse successfully
* @return PARAM_INVALID param invalid
* @return PARSE_WEIGHTS_FAILED parse failed
*/
Status Parse(const char *file, ge::Graph &graph) override;

Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;

private:
Status CheckNodes(ge::ComputeGraphPtr &graph);
/**
* @ingroup domi_omg
* @brief Convert netparameter to modedef and save in graph
* @param [in] param Caffe network parameters to be converted
* @param [in|out] graph Save weight information after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
static Status ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph);

Status Parse(const char *file, ge::ComputeGraphPtr &graph);

Status ParseWeightByFusionProto(const char *model_path, const string &custom_proto_path,
const string &custom_proto_name, ge::ComputeGraphPtr &graph);

Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
const google::protobuf::Message *message,
ge::ComputeGraphPtr &graph);

Status ConvertLayerParameter(const google::protobuf::Message *layer_message,
ge::ComputeGraphPtr &graph);

Status CheckLayersSize(const google::protobuf::Message *message);

Status ConvertLayerProto(const google::protobuf::Message *message,
google::protobuf::Message *layer);

Status ParseLayerField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
google::protobuf::Message *layer);

Status ConvertBlobsProto(const google::protobuf::Message *message,
google::protobuf::Message *blobs);

Status ConvertBlobShapeProto(const google::protobuf::Message *message,
google::protobuf::Message *dest_message);

Status ConvertInnerProdcutProto(const google::protobuf::Message *message,
google::protobuf::Message *dest_message);

Status ConvertConvParamProto(const google::protobuf::Message *message,
google::protobuf::Message *dest_message);
/**
* @ingroup domi_omg
* @brief Layer types to be ignored in weight resolution
*/
static const set<string> skiped_layer_type_;
std::map<std::string, int32_t> layer_name_record_map_;
};
} // namespace domi

#endif // PARSER_CAFFE_CAFFE_PARSER_H_

+ 143
- 0
parser/caffe/caffe_reshape_parser.cc View File

@@ -0,0 +1,143 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/caffe/caffe_reshape_parser.h"
#include <vector>
#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/op/op_parser_util.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/graph_utils.h"
#include "parser/common/op_parser_factory.h"
#include "framework/omg/parser/parser_types.h"
#include "proto/om.pb.h"

using namespace ge::parser;
using domi::CAFFE;

namespace ge {
namespace {
const int kAnchorIndexZero = 0;
const int kAnchorIndexOne = 1;
} // namespace

Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
const LayerParameter *layer = DOMI_DYNAMIC_CAST<const LayerParameter *>(op_src);
if (layer == nullptr) {
GELOGE(FAILED, "Reshape Dynamic cast op_src to LayerParameter failed");
return FAILED;
}

GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str());
const ReshapeParameter &reshape_parameter = layer->reshape_param();

GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, RESHAPE_AXIS_DEFAULT_VALUE)),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, RESHAPE_NUM_AXES_DEFAULT_VALUE)),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return

if (!reshape_parameter.has_shape()) {
GELOGE(FAILED, "Reshape has no shape info, ret fail");
return FAILED;
}
const BlobShape &blob_shape = reshape_parameter.shape();
std::vector<int64_t> dims;
for (int i = 0; i < blob_shape.dim_size(); i++) {
dims.push_back(blob_shape.dim(i));
}

if (reshape_parameter.has_axis()) {
GE_LOGW_IF(reshape_parameter.axis() == -1,
"axis with -1 may lead to calculation errors when input less than 4 dims.");
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_AXIS, reshape_parameter.axis())),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return
}
if (reshape_parameter.has_num_axes()) {
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetInt(op, RESHAPE_ATTR_NUM_AXES, reshape_parameter.num_axes())),
GELOGW("SetInt failed for op %s.", op->GetName().c_str());); // no need to return
}
GE_IF_BOOL_EXEC(!(ge::AttrUtils::SetListInt(op, RESHAPE_ATTR_SHAPE, dims)),
GELOGW("SetListInt failed for op %s.", op->GetName().c_str());); // no need to return
return SUCCESS;
}

Status CaffeReshapeParser::ParseWeights(const Message *op_src, ge::OpDescPtr &op) {
(void)op_src;
(void)op;
return SUCCESS;
}

Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) {
GE_CHECK_NOTNULL(node);
auto owner_graph = node->GetOwnerComputeGraph();
if (owner_graph == nullptr) {
GELOGE(FAILED, "node's graph is empty, name: %s", node->GetName().c_str());
return FAILED;
}
ge::OpDescPtr op = node->GetOpDesc();
GE_CHECK_NOTNULL(op);
vector<int64_t> attr_shape;
GE_IF_BOOL_EXEC(!(ge::AttrUtils::GetListInt(op, RESHAPE_ATTR_SHAPE, attr_shape)),
GELOGW("GetListInt failed for op %s.", op->GetName().c_str());); // no need to return
size_t dims_size = attr_shape.size();

// construct GeTensorDesc
ge::GeTensorDesc const_desc = ge::GeTensorDesc();
std::vector<int64_t> shape_vec = {static_cast<int64_t>(dims_size)};
ge::GeShape shape(shape_vec);
const_desc.Update(shape, ge::FORMAT_NCHW, ge::DT_INT64);
ge::graphStatus state = op->UpdateInputDesc(RESHAPE_ATTR_SHAPE, const_desc);
if (state != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "Updata input_shape desc failed.");
return FAILED;
}

// construct GeTensorPtr
ge::GeTensorPtr constTensor = ge::MakeShared<ge::GeTensor>();
GE_CHECK_NOTNULL(constTensor);
constTensor->SetTensorDesc(const_desc);

std::unique_ptr<int64_t[]> data(new (std::nothrow) int64_t[dims_size]());
GE_CHECK_NOTNULL(data);
for (size_t i = 0; i < dims_size; ++i) {
data[i] = attr_shape[i];
}
GE_IF_BOOL_EXEC(
constTensor->SetData(reinterpret_cast<uint8_t *>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS,
GELOGW("SetData failed for GeTensor.");); // no need to return

// construct const node and add edge
auto const_opdesc = ge::OpDescUtils::CreateConstOp(constTensor);
GE_CHECK_NOTNULL(const_opdesc);
auto const_node = owner_graph->AddNodeFront(const_opdesc);
GE_CHECK_NOTNULL(const_node);
ge::OutDataAnchorPtr out_archor_ptr = const_node->GetOutDataAnchor(kAnchorIndexZero);
GE_CHECK_NOTNULL(out_archor_ptr);
ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(kAnchorIndexOne);
GE_CHECK_NOTNULL(in_archor_ptr);
state = ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr);
if (state != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "AddEdge failed of from Node %s to Node %s", const_node->GetName().c_str(), node->GetName().c_str());
return domi::FAILED;
}
return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(CAFFE, RESHAPE, CaffeReshapeParser);
} // namespace ge

+ 59
- 0
parser/caffe/caffe_reshape_parser.h View File

@@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_
#define PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_

#include "parser/caffe/caffe_op_parser.h"

namespace ge {
class CaffeReshapeParser : public CaffeOpParser {
public:
/**
* @ingroup domi_omg
* @brief parse params of the operation
* @param [in] op_src params to be parsed
* @param [out] op_dest params after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
*/
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;

/**
* @ingroup domi_omg
* @brief parse weight of the operation
* @param [in] op_src params to be parsed
* @param [out] op_dest params after parsing
* @return SUCCESS parse successfully
* @return FAILED parse failed
* @author
*/
Status ParseWeights(const Message *op_src, ge::OpDescPtr &op);

/**
* @ingroup domi_omg
* @brief add const input node
* @param [in] node to add const input
* @param [out] node after add const input
* @return SUCCESS add const input successfully
* @return FAILED add const input failed
* @author
*/
Status AddConstInput(ge::NodePtr &node) override;
};
} // namespace ge

#endif // PARSER_CAFFE_CAFFE_RESHAPE_PARSER_H_

+ 1821
- 0
parser/caffe/proto/caffe/caffe.proto
File diff suppressed because it is too large
View File


+ 190
- 0
parser/caffe/proto/ge_ir.proto View File

@@ -0,0 +1,190 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 396
- 0
parser/caffe/proto/om.proto View File

@@ -0,0 +1,396 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 165
- 0
parser/caffe/proto/task.proto View File

@@ -0,0 +1,165 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

message ModelTaskDef {
string version = 1;

map<string, string> attr = 9; // Extended field
repeated TaskDef task = 10;

uint64 memory_size = 11;
uint32 stream_num = 12;
uint32 event_num = 13;
uint64 weight_size = 14;

repeated bytes op = 15; // input/output opdef in bytes

uint64 base_addr = 16; // base addr
uint64 weight_addr = 17; // weight addr
uint32 batch_num = 18;
}


message TaskDef {
uint32 id = 1;
uint32 type = 2;

uint32 stream_id = 10;
uint32 event_id = 11;

KernelDef kernel = 20;
KernelExDef kernel_ex = 21;
KernelHcclDef kernel_hccl = 25;
EventExDef event_ex = 26;
LogTimeStampDef log_timestamp = 28;

uint32 label_id = 30;

MemcpyAsyncDef memcpy_async = 31;
StreamSwitchDef stream_switch = 32;
StreamActiveDef stream_active = 33;
bytes private_def = 34;
uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future
StreamSwitchNDef stream_switch_n = 36;

LabelSetDef label_set = 37;
LabelGotoExDef label_goto_ex = 38;
LabelSwitchByIndexDef label_switch_by_index = 39;
}

message KernelDef {
KernelContext context = 1;

string stub_func = 10;
uint32 block_dim = 11;
uint32 args_size = 12;
bytes args = 13;
bytes sm_desc = 14;
bytes flowtable = 15;
string so_name = 16;
string kernel_name = 17;
bytes kernel_ext_info = 18;
uint32 kernel_ext_info_size = 19;
}

message KernelContext {
uint32 kernel_type = 1;
uint32 op_id = 2; // OP type in CCE
uint32 kernel_func_id = 3;
uint32 op_index = 4; // TE/Custom operator
bool is_flowtable = 5; // Identify whether args is a flowtable structure
bytes args_offset = 6; // args offset information
uint32 args_count = 7; // args count
repeated uint32 origin_op_index = 8;
}


message KernelExDef {
uint32 flags = 1;

uint32 op_index = 4;
uint32 args_size = 12;
bytes args = 13;
bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput
uint32 task_info_size = 15;
bytes kernel_ext_info = 16;
uint32 kernel_ext_info_size = 17;
}


message KernelHcclDef {
uint32 op_index = 8;
string hccl_type = 9;
}


message EventExDef {
uint32 op_index = 1;
uint32 event_type = 2;
}

message LogTimeStampDef {
uint64 logid = 1;
bool notify = 2;
uint32 flat = 3;
}

message MemcpyAsyncDef {
uint64 dst = 1;
uint64 dst_max = 2;
uint64 src = 3;
uint64 count = 4;
uint32 kind = 5;
uint32 op_index = 6;
}

message StreamSwitchDef {
uint32 op_index = 1;
uint32 true_stream_id = 2;
int64 value = 3;
uint64 value_ptr = 4;
uint32 data_type = 5;
}

message StreamActiveDef {
uint32 op_index = 1;
uint32 active_stream_id = 2;
}

message StreamSwitchNDef {
uint32 op_index = 1;
uint32 size = 2;
repeated int64 target_value = 3;
repeated uint32 true_stream_id = 4;
uint32 element_size = 5;
uint32 data_type = 6;
}

message LabelSetDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelGotoExDef {
uint32 op_index = 1;
uint32 label_id = 2;
uint32 model_id = 3;
}

message LabelSwitchByIndexDef {
uint32 op_index = 1;
uint32 label_max = 2;
}

+ 76
- 0
parser/common/CMakeLists.txt View File

@@ -0,0 +1,76 @@
set(SRC_LIST
"parser_factory.cc"
"data_op_parser.cc"
"op_parser_factory.cc"
"pre_checker.cc"
"register_tbe.cc"
"parser_api.cc"
"parser_inner_ctx.cc"
"proto_file_parser.cc"
"acl_graph_parser_util.cc"
"tbe_plugin_loader.cc"
"model_saver.cc"
"../tensorflow/tensorflow_custom_parser_adapter.cc"
"../tensorflow/tensorflow_fusion_custom_parser_adapter.cc"
"../tensorflow/tensorflow_fusion_op_parser.cc"
"../tensorflow/tensorflow_util.cc"
"convert/pb2json.cc"
"op_def/ir_pb_converter.cc"
"op_def/defs.cc"
"op_def/op_schema.cc"
"op_def/operator.cc"
"op_map.cc"
"parser_types.cc"
"pass_manager.cc"
"parser_fp16_t.cc"
"thread_pool.cc"
)

############ libparser_common.so ############
add_library(parser_common SHARED ${SRC_LIST})

target_compile_options(parser_common PRIVATE
-Werror
)

target_compile_definitions(parser_common PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
)

target_include_directories(parser_common PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${TOP_DIR}/framework/domi
${TOP_DIR}/framework/domi/common
${TOP_DIR}/framework/domi/parser
${TOP_DIR}/inc
${TOP_DIR}/inc/common/util
${TOP_DIR}/inc/external
${TOP_DIR}/inc/external/graph
${TOP_DIR}/inc/framework
${CMAKE_BINARY_DIR}
${CMAKE_BINARY_DIR}/proto/ge
)

target_link_libraries(parser_common PRIVATE
$<BUILD_INTERFACE:intf_pub>
-Wl,--no-as-needed
graph
protobuf
register
c_sec
slog
mmpa
error_manager
-Wl,--as-needed
json
-lrt
-ldl
)

############ install ############
set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(TARGETS parser_common OPTIONAL
LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}
)

+ 492
- 0
parser/common/acl_graph_parser_util.cc View File

@@ -0,0 +1,492 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/acl_graph_parser_util.h"

#include <dlfcn.h>
#include <cstdlib>
#include <fstream>
#include <regex.h>
#include <ctime>

#include "common/string_util.h"
#include "common/debug/log.h"
#include "common/op/ge_op_utils.h"
#include "ge/ge_api_types.h"
#include "graph/opsproto_manager.h"
#include "omg/parser/parser_inner_ctx.h"
#include "tbe_plugin_loader.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/parser_types.h"
#include "common/util/error_manager/error_manager.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

using google::protobuf::io::CodedInputStream;
using google::protobuf::io::FileInputStream;
using google::protobuf::io::ZeroCopyInputStream;
using namespace ge::parser;

namespace {
/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
const int kMaxBuffSize = 256;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M

static string GetSoPath() {
Dl_info dl_info;
if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) {
GELOGW("Failed to read so_path!");
return string();
} else {
std::string so_path = dl_info.dli_fname;
char path[PATH_MAX] = {0};
if (so_path.length() >= PATH_MAX) {
GELOGW("File path is too long!");
return string();
}
if (realpath(so_path.c_str(), path) == nullptr) {
GELOGW("Failed to get realpath of %s", so_path.c_str());
return string();
}

so_path = path;
so_path = so_path.substr(0, so_path.rfind('/') + 1);
return so_path;
}
}

static void GetOpsProtoPath(string &opsproto_path) {
GELOGD("Start to get ops proto path schedule.");
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
string path = path_env;
string file_path = ge::parser::RealPath(path.c_str());
if (file_path.empty()) {
GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str());
return;
}
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
GELOGI("Get opsproto so path from env : %s", path.c_str());
return;
}
string path_base = GetSoPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
}
} // namespace

namespace ge {
domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node,
std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
if (tmpDescPtr == nullptr) {
GELOGE(domi::FAILED, "Get outnode op desc fail.");
return domi::FAILED;
}
size_t size = tmpDescPtr->GetOutputsSize();
if (node->GetType() != NETOUTPUT) {
for (size_t index = 0; index < size; ++index) {
output_nodes_info.push_back(std::make_pair(node, index));
}
} else {
const auto in_anchors = node->GetAllInDataAnchors();
for (auto in_anchor : in_anchors) {
auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
GELOGE(domi::FAILED, "Get leaf node op desc fail.");
return domi::FAILED;
}
auto out_node = out_anchor->GetOwnerNode();
output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx()));
}
}
return SUCCESS;
}

void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name) {
output_nodes_name.clear();
if (ge::GetParserContext().out_top_names.empty()) {
// tf process, no top name.
for (const auto output_node_info : output_nodes_info) {
std::string node_name = output_node_info.first->GetName();
int32_t index = output_node_info.second;
output_nodes_name.push_back(node_name + ":" + std::to_string(index));
}
return;
}
// caffe process reserved place;
}

domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) {
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
if (compute_graph == nullptr) {
GELOGE(FAILED, "compute_graph is nullptr.");
return FAILED;
}

std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
std::vector<std::string> output_nodes_name;

for (ge::NodePtr node : compute_graph->GetDirectNode()) {
if (!node->GetInAllNodes().empty() && node->GetOutAllNodes().empty()) {
Status ret = AclGrphParseUtil::GetOutputLeaf(node, output_nodes_info);
if (ret != SUCCESS) {
GELOGE(FAILED, "find leaf fail.");
return FAILED;
}
}
}

AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
compute_graph->SetGraphOutNodesInfo(output_nodes_info);
ge::GetParserContext().net_out_nodes = output_nodes_name;
GELOGI("Set graph %s default output node success.", graph.GetName().c_str());
return SUCCESS;
}

domi::Status AclGrphParseUtil::LoadOpsProtoLib() {
string opsproto_path;
GetOpsProtoPath(opsproto_path);
GELOGI("Get opsproto path is %s", opsproto_path.c_str());
OpsProtoManager *manager = OpsProtoManager::Instance();
map<string, string> option_tmp;
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
bool is_proto_init = manager->Initialize(option_tmp);
if (!is_proto_init) {
GELOGE(FAILED, "Load ops_proto lib failed, ops proto path is invalid.");
return FAILED;
}
return SUCCESS;
}

void AclGrphParseUtil::SaveCustomCaffeProtoPath() {
GELOGD("Enter save custom caffe proto path.");
std::string path_base = GetSoPath();
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
ge::GetParserContext().caffe_proto_path = path_base + "include/proto/";

string custom_op_path;
const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
std::string path = path_env;
custom_op_path = path + "/framework/custom/caffe/";
GELOGI("Get custom proto path from env : %s", path_env);
GetParserContext().custom_proto_path = custom_op_path;
return;
}
custom_op_path = path_base + "ops/framework/custom/caffe/";
ge::GetParserContext().custom_proto_path = custom_op_path;
return;
}

// Initialize PARSER, load custom op plugin
// options will be used later for parser decoupling
domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, std::string> &options) {
GELOGT(TRACE_INIT, "AclParserInitialize start");
// check init status
if (parser_initialized) {
GELOGW("AclParserInitialize is called more than once");
return SUCCESS;
}

// load custom op plugin
TBEPluginLoader::Instance().LoadPluginSo(options);

// load and save custom op proto for prediction
(void)LoadOpsProtoLib();
SaveCustomCaffeProtoPath();

auto op_registry = domi::OpRegistry::Instance();
if (op_registry == nullptr) {
GELOGE(FAILED, "Get OpRegistry instance failed");
return FAILED;
}

std::vector<OpRegistrationData> registrationDatas = op_registry->registrationDatas;
GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size());
for (OpRegistrationData &reg_data : registrationDatas) {
(void)OpRegistrationTbe::Instance()->Finalize(reg_data, false);
domi::OpRegistry::Instance()->Register(reg_data);
}

// set init status
if (!parser_initialized) {
// Initialize success, first time calling initialize
parser_initialized = true;
}

GELOGT(TRACE_STOP, "AclParserInitialize finished");
return SUCCESS;
}
namespace parser {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) {
if (path == nullptr) {
GELOGE(ge::FAILED, "path pointer is NULL.");
return "";
}
if (strlen(path) >= PATH_MAX) {
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)});
GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX);
return "";
}
// Nullptr is returned when the path does not exist or there is no permission
// Return absolute path when path is accessible
std::string res;
char resolved_path[PATH_MAX] = {0};
if (realpath(path, resolved_path) != nullptr) {
res = resolved_path;
}

return res;
}

// Get file length
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null.");

std::string real_path = RealPath(input_file.c_str());

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());
unsigned long long file_length = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK,
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
{input_file, strerror(errno)});
return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0),
ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file});
return -1, "File[%s] size is 0, not valid.", input_file.c_str());

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit,
ErrorManager::GetInstance().ATCReportErrMessage(
"E19016", {"filepath", "filesize", "maxlen"},
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)});
return -1, "File[%s] size %lld is out of limit: %d.",
input_file.c_str(), file_length, kMaxFileSizeLimit);
return static_cast<long>(file_length);
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
struct timeval tv{};
int ret = gettimeofday(&tv, nullptr);
GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret);
auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds
return static_cast<uint64_t>(total_use_time);
}

static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr,
return false, "incorrect parameter. nullptr == proto");

coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold);
return proto->ParseFromCodedStream(&coded_stream);
}

/** @ingroup domi_common
* @brief Read all data from binary file
* @param [in] file_name File path
* @param [out] buffer The address of the output memory, which needs to be released by the caller
* @param [out] length Output memory size
* @return false fail
* @return true success
*/
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer,
int &length) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr");

std::string real_path = RealPath(file_name);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name);

std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate);
if (!file.is_open()) {
GELOGE(ge::FAILED, "Read file %s failed.", file_name);
return false;
}

length = static_cast<int>(file.tellg());

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0");

file.seekg(0, std::ios::beg);

*buffer = new(std::nothrow) char[length]();
GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed.");

file.read(*buffer, length);
file.close();
return true;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr),
return false,
"Input parameter file or proto is nullptr!");

std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
return false, "pb file path '%s' not valid", file);

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");

std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary);
if (!fs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"});
GELOGE(ge::FAILED, "Open real path[%s] failed.", file);
return false;
}

google::protobuf::io::IstreamInputStream istream(&fs);
google::protobuf::io::CodedInputStream coded_stream(&istream);

bool ret = ReadProtoFromCodedInputStream(coded_stream, proto);

fs.close();

if (!ret) {
ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file});
GELOGE(ge::FAILED, "Parse file[%s] failed.", file);
return ret;
}

return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false,
"incorrect parameter. proto is nullptr || data is nullptr || size is 0");

google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size);
return ReadProtoFromCodedInputStream(coded_stream, proto);
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file,
google::protobuf::Message *message) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false,
"incorrect parameter. nullptr == file || nullptr == message");

std::string real_path = RealPath(file);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(),
ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"},
{file, strerror(errno)});
return false, "Path[%s]'s realpath is empty, errmsg[%s]", file,
strerror(errno));

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid.");

std::ifstream fs(real_path.c_str(), std::ifstream::in);

if (!fs.is_open()) {
ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file});
GELOGE(ge::FAILED,
"Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file);
return false;
}

google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message);
GE_IF_BOOL_EXEC(!ret,
ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file});
GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, "
"please check whether the file is a valid protobuf format file.", file));
fs.close();

return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size,
google::protobuf::Message *message) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false,
"incorrect parameter. data is nullptr || message is nullptr");
std::string str(data, static_cast<size_t>(size));
std::istringstream fs(str);

google::protobuf::io::IstreamInputStream input(&fs);
bool ret = google::protobuf::TextFormat::Parse(&input, message);
GE_IF_BOOL_EXEC(
!ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));

return ret;
}

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
Status GetOriginalType(const ge::NodePtr &node, string &type) {
GE_CHECK_NOTNULL(node);
type = node->GetType();
GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
GE_CHECK_NOTNULL(node->GetOpDesc());
bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
if (!ret) {
GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
return INTERNAL_ERROR;
}
GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) {
char ebuff[kMaxBuffSize];
regex_t reg;
int cflags = REG_EXTENDED | REG_NOSUB;
int ret = regcomp(&reg, mode.c_str(), cflags);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGW("regcomp failed, reason: %s", ebuff);
regfree(&reg);
return true;
}

ret = regexec(&reg, str.c_str(), 0, nullptr, 0);
if (ret) {
regerror(ret, &reg, ebuff, kMaxBuffSize);
GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff);
regfree(&reg);
return false;
}

regfree(&reg);
return true;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() {
std::time_t now = std::time(nullptr);
std::tm *ptm = std::localtime(&now);
if (ptm == nullptr) {
GELOGE(ge::FAILED, "Localtime failed.");
return "";
}

const int kTimeBufferLen = 32;
char buffer[kTimeBufferLen + 1] = {0};
// format: 20171122042550
std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
return std::string(buffer);
}
} // namespace parser
} // namespace ge

+ 161
- 0
parser/common/acl_graph_parser_util.h View File

@@ -0,0 +1,161 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ACL_GRAPH_PARSE_UTIL_
#define ACL_GRAPH_PARSE_UTIL_

#include <map>
#include <string>
#include <google/protobuf/text_format.h>
#include <sstream>

#include "framework/omg/parser/parser_types.h"
#include "register/register_error_codes.h"
#include "graph/utils/graph_utils.h"

namespace ge {

using google::protobuf::Message;

class AclGrphParseUtil {
public:
AclGrphParseUtil() {}
virtual ~AclGrphParseUtil() {}
domi::Status LoadOpsProtoLib();
void SaveCustomCaffeProtoPath();
domi::Status AclParserInitialize(const std::map<std::string, std::string> &options);
domi::Status SetDefaultOutputNode(ge::Graph &graph);

private:
bool parser_initialized = false;
domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
};

namespace parser {
///
/// @ingroup: domi_common
/// @brief: get length of file
/// @param [in] input_file: path of file
/// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
///
extern long GetFileLength(const std::string &input_file);

///
/// @ingroup domi_common
/// @brief Absolute path for obtaining files.
/// @param [in] path of input file
/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
///
std::string RealPath(const char *path);

///
/// @ingroup domi_common
/// @brief Obtains the absolute time (timestamp) of the current system.
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestamp();

///
/// @ingroup domi_common
/// @brief Reads all data from a binary file.
/// @param [in] file_name path of file
/// @param [out] buffer Output memory address, which needs to be released by the caller.
/// @param [out] length Output memory size
/// @return false fail
/// @return true success
///
bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);

///
/// @ingroup domi_common
/// @brief proto file in bianary format
/// @param [in] file path of proto file
/// @param [out] proto memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromBinaryFile(const char *file, Message *proto);

///
/// @ingroup domi_common
/// @brief Reads the proto structure from an array.
/// @param [in] data proto data to be read
/// @param [in] size proto data size
/// @param [out] proto Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromArray(const void *data, int size, Message *proto);

///
/// @ingroup domi_proto
/// @brief Reads the proto file in the text format.
/// @param [in] file path of proto file
/// @param [out] message Memory for storing the proto file
/// @return true success
/// @return false fail
///
bool ReadProtoFromText(const char *file, google::protobuf::Message *message);

bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);

///
/// @brief get the Original Type of FrameworkOp
/// @param [in] node
/// @param [out] type
/// @return Status
///
domi::Status GetOriginalType(const ge::NodePtr &node, string &type);

///
/// @ingroup domi_common
/// @brief Check whether the file path meets the whitelist verification requirements.
/// @param [in] filePath file path
/// @param [out] result
///
bool ValidateStr(const std::string &filePath, const std::string &mode);

///
/// @ingroup domi_common
/// @brief Obtains the current time string.
/// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
///
std::string CurrentTimeInStr();
} // namespace parser
} // namespace ge

/*lint --emacro((773),GE_TIMESTAMP_START)*/
/*lint -esym(773,GE_TIMESTAMP_START)*/
#define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp()

#define PARSER_TIMESTAMP_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \
do { \
uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
(endUsec_##stage - startUsec_##stage)); \
} while (0);

#endif // ACL_GRAPH_PARSE_UTIL_

+ 248
- 0
parser/common/convert/pb2json.cc View File

@@ -0,0 +1,248 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// File: pb2json.h
// Description: This imply file for protobuf message and json interconversion

#include "common/convert/pb2json.h"
#include <set>
#include <string>
#include "securec.h"
#include "framework/common/fmk_types.h"
#include "framework/common/debug/ge_log.h"

using std::set;
using std::string;

namespace ge {
namespace {
const int kSignificantDigits = 10;
}
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json,
bool enum2str) {
auto descriptor = message.GetDescriptor();
auto reflection = message.GetReflection();
if (descriptor == nullptr || reflection == nullptr) {
return;
}

auto count = descriptor->field_count();

for (auto i = 0; i < count; ++i) {
const auto field = descriptor->field(i);
if (field == nullptr) {
return;
}

// Do not display weight data
if (black_fields.find(field->name()) != black_fields.end()) {
continue;
}

if (field->is_repeated()) {
if (reflection->FieldSize(message, field) > 0) {
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str);
}
continue;
}

if (!reflection->HasField(message, field)) {
continue;
}

OneField2Json(message, field, reflection, black_fields, json, enum2str);
}
}

void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
if (0 != tmp_message.ByteSize()) {
Message2Json(tmp_message, black_fields, json[field->name()], enum2str);
}
break;
}

case ProtobufFieldDescriptor::TYPE_BOOL:
json[field->name()] = reflection->GetBool(message, field);
break;

case ProtobufFieldDescriptor::TYPE_ENUM: {
auto *enum_value_desc = reflection->GetEnum(message, field);
Enum2Json(enum_value_desc, field, enum2str, json);
break;
}

case ProtobufFieldDescriptor::TYPE_INT32:
case ProtobufFieldDescriptor::TYPE_SINT32:
case ProtobufFieldDescriptor::TYPE_SFIXED32:
json[field->name()] = reflection->GetInt32(message, field);
break;

case ProtobufFieldDescriptor::TYPE_UINT32:
case ProtobufFieldDescriptor::TYPE_FIXED32:
json[field->name()] = reflection->GetUInt32(message, field);
break;

case ProtobufFieldDescriptor::TYPE_INT64:
case ProtobufFieldDescriptor::TYPE_SINT64:
case ProtobufFieldDescriptor::TYPE_SFIXED64:
json[field->name()] = reflection->GetInt64(message, field);
break;

case ProtobufFieldDescriptor::TYPE_UINT64:
case ProtobufFieldDescriptor::TYPE_FIXED64:
json[field->name()] = reflection->GetUInt64(message, field);
break;

case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){
json[field->name()] = str;
} else {
json[field->name()] = reflection->GetFloat(message, field);
}

break;

case ProtobufFieldDescriptor::TYPE_STRING:
json[field->name()] = reflection->GetString(message, field);
break;

case ProtobufFieldDescriptor::TYPE_BYTES: {
string field_name = field->name();
string type_bytes = reflection->GetString(message, field);
json[field_name] = TypeBytes2String(field_name, type_bytes);
break;
}

default:
break;
}
}

string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {
if (field_name != "offset") {
return type_bytes;
}
string result = "";
for (char temp_value : type_bytes) {
uint8_t *value = 0;
value = reinterpret_cast<uint8_t *>(&temp_value);
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){
GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str());
continue;
}
result += str;
}
return result;
}

void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
if ((field == nullptr) || (reflection == nullptr)) {
Message2Json(message, black_fields, json, enum2str);
return;
}

for (auto i = 0; i < reflection->FieldSize(message, field); ++i) {
Json tmp_json;
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i);
if (0 != tmp_message.ByteSize()) {
Message2Json(tmp_message, black_fields, tmp_json, enum2str);
}
} break;

case ProtobufFieldDescriptor::TYPE_BOOL:
tmp_json = reflection->GetRepeatedBool(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_ENUM: {
auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i);
RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json);
} break;

case ProtobufFieldDescriptor::TYPE_INT32:
case ProtobufFieldDescriptor::TYPE_SINT32:
case ProtobufFieldDescriptor::TYPE_SFIXED32:
tmp_json = reflection->GetRepeatedInt32(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_UINT32:
case ProtobufFieldDescriptor::TYPE_FIXED32:
tmp_json = reflection->GetRepeatedUInt32(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_INT64:
case ProtobufFieldDescriptor::TYPE_SINT64:
case ProtobufFieldDescriptor::TYPE_SFIXED64:
tmp_json = reflection->GetRepeatedInt64(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_UINT64:
case ProtobufFieldDescriptor::TYPE_FIXED64:
tmp_json = reflection->GetRepeatedUInt64(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_FLOAT:
tmp_json = reflection->GetRepeatedFloat(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_STRING:
case ProtobufFieldDescriptor::TYPE_BYTES:
tmp_json = reflection->GetRepeatedString(message, field, i);
break;

default:
break;
}
json += tmp_json;
}
}

void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json) {
if (enum_value_desc != nullptr) {
if (field == nullptr) {
return;
}
if (enum2str) {
json[field->name()] = enum_value_desc->name();
} else {
json[field->name()] = enum_value_desc->number();
}
}
}

void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) {
if (enum_value_desc != nullptr) {
if (enum2str) {
json = enum_value_desc->name();
} else {
json = enum_value_desc->number();
}
}
}
} // namespace ge

+ 68
- 0
parser/common/convert/pb2json.h View File

@@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// File: pb2json.h
// Description: This header file for protobuf message and json interconversion

#ifndef PARSER_COMMON_CONVERT_PB2JSON_H_
#define PARSER_COMMON_CONVERT_PB2JSON_H_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "nlohmann/json.hpp"

namespace ge {
using Json = nlohmann::json;
using ProtobufMsg = ::google::protobuf::Message;
using ProtobufReflection = ::google::protobuf::Reflection;
using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor;
using ProtobufDescriptor = ::google::protobuf::Descriptor;
using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor;

class Pb2Json {
public:
/**
* @ingroup domi_omg
* @brief Transfer protobuf object to JSON object
* @param [out] json Converted JSON object
* @return void success
* @author
*/
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json,
bool enum2str = false);

protected:
static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);

static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json);

static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json);

static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json,
bool enum2str);

static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes);
};
} // namespace ge

#endif // PARSER_COMMON_CONVERT_PB2JSON_H_

+ 212
- 0
parser/common/data_op_parser.cc View File

@@ -0,0 +1,212 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/data_op_parser.h"
#include <cstdlib>
#include "common/debug/log.h"
#include "common/op/ge_op_utils.h"
#include "common/math/math_util.h"
#include "common/util.h"
#include "graph/utils/type_utils.h"
#include "omg/omg.h"

using namespace cce;
namespace {
const int kDataMemAlignSize = 32;
const int kTwoTimesAlign = 2;
const int kDynamicBatchInputSize = -1;
const uint32_t kScalarLength = 1;
} // namespace

namespace ge {
FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op) {
GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null");

const string &data_op_name = op->GetName();
GetParserContext().input_dims.emplace(data_op_name, shape);

int64_t attr_type = 0;
ge::DataType data_type;
if (ge::AttrUtils::GetInt(op, ge::DATA_ATTR_NAME_DATA_TYPE, attr_type)) {
data_type = static_cast<ge::DataType>(attr_type);
} else {
data_type = ge::DT_FLOAT;
}

// convert input
vector<int64_t> def_format_shape(shape);

ge::GeTensorDesc i_tensor_desc;
ge::GeTensorDesc o_tensor_desc;
const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = GetParserContext().input_nodes_format_map;
auto map_iter = input_nodes_format_map.find(data_op_name);
if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) {
// Input 5D NC1HWC0
GE_RETURN_WITH_LOG_IF_ERROR(Init5DInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed");
// Output
GE_RETURN_WITH_LOG_IF_ERROR(Init5DOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed");
} else {
// No need to consider AIPP here,
// The adjustdatanodedesc function of model_builder will process the
// input_desc and output_desc of AIPP's data node.
// Without AIPP, the data of input float is kept in cctranstensor implementation.
// The cast operator can not run in the pvmodel simulation environment,
// so the input data conversion processing maintains the original state.
// To be modified after AICPU operators support pvmodel.
if (data_type == ge::DT_FLOAT) {
// Input
GE_RETURN_WITH_LOG_IF_ERROR(InitInputTensor(def_format_shape, i_tensor_desc), "InitInputTensor failed");
// Output
GE_RETURN_WITH_LOG_IF_ERROR(InitOutputTensor(def_format_shape, o_tensor_desc), "InitOutputTensor failed");
} else {
// Input
GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, i_tensor_desc),
"Init ND InputTensor failed");
// Output
GE_RETURN_WITH_LOG_IF_ERROR(InitNDTensor(def_format_shape, data_type, o_tensor_desc),
"Init ND Output Tensor failed");
}
}
i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format));
if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) {
GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str());
return FAILED;
}
if (op->AddOutputDesc(o_tensor_desc) != ge::GRAPH_SUCCESS) {
GELOGE(domi::INTERNAL_ERROR, "AddOutputDesc failed for op %s.", op->GetName().c_str());
return FAILED;
}
return SUCCESS;
}

Status DataOpParser::Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc) {
tensor_desc.SetDataType(ge::DT_FLOAT16);
tensor_desc.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0));
ge::TensorUtils::SetReuseInput(tensor_desc, false);
ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size());
tensor_desc.SetShape(ge::GeShape(shape));

int64_t tensor_size = 0;
ge::graphStatus graph_status = ge::TensorUtils::GetTensorSizeInBytes(tensor_desc, tensor_size);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "GetTensorSizeInBytes failed!");
return domi::FAILED;
}
// Set the actual occupied space size
ge::TensorUtils::SetSize(tensor_desc, tensor_size);
return SUCCESS;
}

Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc) {
// Fixed input ND
tensor_desc.SetFormat(static_cast<ge::Format>(DOMI_TENSOR_ND));
tensor_desc.SetDataType(data_type);
tensor_desc.SetOriginDataType(data_type);
ge::TensorUtils::SetReuseInput(tensor_desc, false);
ge::TensorUtils::SetRealDimCnt(tensor_desc, shape.size());
tensor_desc.SetShape(ge::GeShape(shape));
tensor_desc.SetOriginShape(ge::GeShape(shape));

int64_t size = kScalarLength;
if (!tensor_desc.GetShape().GetDims().empty()) {
size = tensor_desc.GetShape().GetShapeSize();
}
uint32_t type_size = 0;
if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) {
FMK_INT64_UINT32_MULCHECK(size, type_size);
size *= type_size;
} else {
FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
size *= sizeof(float);
}
ge::TensorUtils::SetSize(tensor_desc, size);
return SUCCESS;
}

Status DataOpParser::Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) {
output.SetDataType(ge::DT_FLOAT16);
output.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0));
ge::TensorUtils::SetReuseInput(output, false);
ge::TensorUtils::SetRealDimCnt(output, shape.size());
output.SetShape(ge::GeShape(shape));

int64_t output_size = 0;
ge::graphStatus graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(output, output_size);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "GetTensorMemorySizeInBytes failed!");
return domi::FAILED;
}
// Set the actual occupied space size
ge::TensorUtils::SetSize(output, output_size);
return SUCCESS;
}

Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input) {
input.SetFormat(static_cast<ge::Format>(domiTensorFormat_t(DOMI_TENSOR_ND)));
input.SetDataType(ge::DT_FLOAT);
input.SetOriginDataType(ge::DT_FLOAT);
ge::TensorUtils::SetReuseInput(input, false);

input.SetShape(ge::GeShape(shape));
input.SetOriginShape(ge::GeShape(shape));
int64_t size = 0;
// No need to check dynamic_batch_size since its first dim is -1.
if (input.GetShape().GetDim(0) != -1) {
size = input.GetShape().GetShapeSize();
}
FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float)));
ge::TensorUtils::SetSize(input, size * sizeof(float));

return SUCCESS;
}

Status DataOpParser::InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) {
int64_t output_size = 0;
ge::GeShape output_shape = ge::GeShape(shape);
ge::Format format = ge::FORMAT_ND;
ge::DataType data_type = ge::DT_FLOAT;
output.SetFormat(format);
output.SetDataType(data_type);
ge::TensorUtils::SetReuseInput(output, false);
ge::TensorUtils::SetRealDimCnt(output, shape.size());
output.SetShape(output_shape);

ge::graphStatus graph_status = ge::TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_size);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "CalcTensorMemSize failed!");
return FAILED;
}

if (output_size == kDynamicBatchInputSize) {
GELOGI("After calc tensor memory size, output_mem_size = %ld", output_size);
return SUCCESS;
}

int64_t size = output_size;
auto valid_max_size = INT64_MAX - kTwoTimesAlign * kDataMemAlignSize;
if (size > valid_max_size || size < 0) {
GELOGE(FAILED, "The updated mem size is out of data range [0, %ld]", valid_max_size);
return FAILED;
} else {
size = ((size + kTwoTimesAlign * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
}
// Set the actual occupied space size
ge::TensorUtils::SetSize(output, size);
return SUCCESS;
}
} // namespace ge

+ 109
- 0
parser/common/data_op_parser.h View File

@@ -0,0 +1,109 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_DATA_OP_PARSER_H_
#define PARSER_COMMON_DATA_OP_PARSER_H_

#include <google/protobuf/text_format.h>
#include <vector>
#include "common/debug/log.h"
#include "common/op/attr_value_util.h"
#include "framework/omg/parser/parser_types.h"
#include "omg/omg_inner_types.h"
#include "proto/om.pb.h"

#include "graph/attr_value.h"
#include "graph/compute_graph.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"

using google::protobuf::Message;
using std::vector;

namespace ge {
/**
* @ingroup domi_omg
* @brief Provide a public interface for DataOp
*
*/
class DataOpParser {
public:
virtual ~DataOpParser() {}

protected:
/**
* @ingroup domi_omg
* @brief parser the Shape information of DataOp
* @param [in] shape 4D shape information (dimensions)
* @param [out] op Save converted shap information
* @return SUCCESS Parsing success
* @return FAILED Parsing failed
*/
static Status ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op);

private:
/**
* @ingroup domi_omg
* @brief Convert Input's Shape Information
* @param [in] 4D shape information (dimensions)
* @param [out] Save converted shap information
*/
static Status Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc);

/**
* @ingroup domi_omg
* @brief Convert Shape of Output
* @param [in] shape 4D shape information (dimensions)
* @param [out] output Save converted shap information
* @return SUCCESS Convert success
* @return FAILED Convert failed
*/
static Status Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output);

/**
* @ingroup domi_omg
* @brief 4D shape information (dimensions)4D shape information (dimensions)4D shape information (dimensions)
* @param [in] 4D shape information (dimensions)
* @param [out] input Save converted shap information
*/
static Status InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input);

/**
* @ingroup domi_omg
* @brief Convert Shape of Output
* @param [in] shape 4D shape information (dimensions)
* @param [out] output Save converted shap information
* @return SUCCESS Convert success
* @return FAILED Convert failed
*/
static Status InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output);

/**
* @ingroup domi_omg
* @brief Convert Shape of Output
* @param [in] shape 4D shape information (dimensions)
* @param [out] output Save converted shap information
* @return SUCCESS Convert success
* @return FAILED Convert failed
*/
static Status InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc);
};
} // namespace ge

#endif // PARSER_COMMON_DATA_OP_PARSER_H_

+ 155
- 0
parser/common/model_saver.cc View File

@@ -0,0 +1,155 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <sys/stat.h>
#include <fcntl.h>

#include "parser/common/model_saver.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "common/util/error_manager/error_manager.h"
#include "mmpa/mmpa_api.h"

namespace {
const int kFileOpSuccess = 0;
} // namespace

namespace ge {
namespace parser {
const uint32_t kInteval = 2;

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path,
const Json &model) {
Status ret = SUCCESS;
if (file_path == nullptr || SUCCESS != CheckPath(file_path)) {
GELOGE(FAILED, "Check output file failed.");
return FAILED;
}
std::string model_str;
try {
model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()});
GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
ErrorManager::GetInstance().ATCReportErrMessage("E19008");
GELOGE(FAILED, "Failed to convert JSON to string.");
return FAILED;
}

char real_path[PATH_MAX] = {0};
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!");
if (realpath(file_path, real_path) == nullptr) {
GELOGI("File %s does not exit, it will be created.", file_path);
}

// Open file
mode_t mode = S_IRUSR | S_IWUSR;
int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode);
if (fd == EN_ERROR || fd == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)});
GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno));
return FAILED;
}
const char *model_char = model_str.c_str();
uint32_t len = static_cast<uint32_t>(model_str.length());
// Write data to file
mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len);
if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19004", {"file", "errmsg"}, {file_path, strerror(errno)});
// Need to both print the error info of mmWrite and mmClose, so return ret after mmClose
GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno));
ret = FAILED;
}
// Close file
if (mmClose(fd) != EN_OK) {
GELOGE(FAILED, "Close file failed.");
ret = FAILED;
}
return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) {
// Determine file path length
if (file_path.size() >= PATH_MAX) {
GELOGE(FAILED, "Path is too long:%zu", file_path.size());
return FAILED;
}

// Find the last separator
int path_split_pos = static_cast<int>(file_path.size() - 1);
for (; path_split_pos >= 0; path_split_pos--) {
if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') {
break;
}
}

if (path_split_pos == 0) {
return SUCCESS;
}

// If there is a path before the file name, create the path
if (path_split_pos != -1) {
if (CreateDirectory(std::string(file_path).substr(0, static_cast<size_t>(path_split_pos))) != kFileOpSuccess) {
GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str());
return FAILED;
}
}

return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) {
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty.");
auto dir_path_len = directory_path.length();
if (dir_path_len >= PATH_MAX) {
ErrorManager::GetInstance().ATCReportErrMessage(
"E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)});
GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX);
return -1;
}
char tmp_dir_path[PATH_MAX] = {0};
for (size_t i = 0; i < dir_path_len; i++) {
tmp_dir_path[i] = directory_path[i];
if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) {
if (access(tmp_dir_path, F_OK) != 0) {
int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.",
directory_path.c_str());
return ret;
}
}
}
}
}
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700
if (ret != 0) {
if (errno != EEXIST) {
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path});
GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str());
return ret;
}
}
return 0;
}

} // namespace parser
} // namespace ge

+ 55
- 0
parser/common/model_saver.h View File

@@ -0,0 +1,55 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_FILE_SAVER_H_
#define PARSER_COMMON_FILE_SAVER_H_

#include <string>

#include "ge/ge_api_error_codes.h"
#include "register/register_types.h"
#include "nlohmann/json.hpp"

namespace ge {
namespace parser {
using Json = nlohmann::json;
using std::string;

class ModelSaver {
public:
/**
* @ingroup domi_common
* @brief Save JSON object to file
* @param [in] file_path File output path
* @param [in] model json object
* @return Status result
*/
static Status SaveJsonToFile(const char *file_path, const Json &model);

private:
///
/// @ingroup domi_common
/// @brief Check validity of the file path
/// @return Status result
///
static Status CheckPath(const string &file_path);

static int CreateDirectory(const std::string &directory_path);
};
} // namespace parser
} // namespace ge

#endif //PARSER_COMMON_FILE_SAVER_H_

+ 95
- 0
parser/common/module.mk View File

@@ -0,0 +1,95 @@
LOCAL_PATH := $(call my-dir)

include $(CLEAR_VARS)

LOCAL_MODULE := libparser_common

LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0
LOCAL_CFLAGS += -Werror
ifeq ($(DEBUG), 1)
LOCAL_CFLAGS += -g -O0
endif

COMMON_LOCAL_SRC_FILES := \
parser_factory.cc \
data_op_parser.cc \
op_parser_factory.cc \
pre_checker.cc \
register_tbe.cc \
parser_api.cc \
parser_inner_ctx.cc \
proto_file_parser.cc \
acl_graph_parser_util.cc \
tbe_plugin_loader.cc \
model_saver.cc \
../tensorflow/tensorflow_custom_parser_adapter.cc \
../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \
../tensorflow/tensorflow_fusion_op_parser.cc \
../tensorflow/tensorflow_util.cc \
convert/pb2json.cc \
op_def/ir_pb_converter.cc \
op_def/defs.cc \
op_def/op_schema.cc \
op_def/operator.cc \
op_map.cc \
parser_types.cc \
pass_manager.cc \
parser_fp16_t.cc \
thread_pool.cc \

FMK_COMMON_SRC_FILES := \
# ../../common/fmk_error_codes.cc \
../../common/auth/cipher.cc \
../../common/context/ctx.cc \
../../graph/passes/pass_manager.cc \
../../graph/common/omg_util.cc \
../../common/types.cc \
../../common/auth/file_saver.cc \
../../common/util.cc \
../../common/model_saver.cc \
../../common/fp16_t.cc \
../../common/thread_pool.cc \

LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES)
LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES)

LOCAL_C_INCLUDES := \
proto/om.proto \
proto/insert_op.proto \
proto/ge_ir.proto \
proto/tensorflow/graph.proto \
proto/tensorflow/node_def.proto \
proto/tensorflow/tensor_shape.proto \
proto/tensorflow/attr_value.proto \
proto/tensorflow/function.proto \
proto/tensorflow/op_def.proto \
proto/tensorflow/resource_handle.proto \
proto/tensorflow/tensor.proto \
proto/tensorflow/types.proto \
proto/tensorflow/versions.proto \
$(LOCAL_PATH) \
$(TOPDIR)inc \
$(TOPDIR)inc/external \
$(TOPDIR)inc/external/graph \
$(TOPDIR)inc/framework \
$(TOPDIR)inc/common/util \
$(TOPDIR)framework/domi \
$(TOPDIR)framework/domi/common \
$(TOPDIR)framework/domi/parser \
$(TOPDIR)third_party/json/include \
$(TOPDIR)third_party/protobuf/include \
libc_sec/include \
third_party/openssl/include/x86/include \

LOCAL_SHARED_LIBRARIES := \
libprotobuf \
libslog \
libgraph \
libmmpa \
libc_sec \
liberror_manager \
libregister \

LOCAL_LDFLAGS := -lrt -ldl

include $(BUILD_HOST_SHARED_LIBRARY)

+ 38
- 0
parser/common/op_def/arg_op.cc View File

@@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/op_def/arg_op.h"
#include <string>
#include "framework/common/fmk_types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator::ArgOpOperator() : ParserOperator("Data") {}

ArgOpOperator::~ArgOpOperator() {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Name(const std::string &name) {
(void)ParserOperator::Name(name);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ArgOpOperator &ArgOpOperator::Index(int64_t index) {
Attr("index", static_cast<int64_t>(index));

return *this;
}

int64_t ArgOpOperator::GetIndex() const { return GetIntAttr("index"); }
} // namespace ge

+ 36
- 0
parser/common/op_def/arg_op.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_OP_ARG_OP_H_
#define DOMI_OP_ARG_OP_H_
#include "parser/common/op_def/operator.h"

namespace ge {
class ArgOpOperator : public ParserOperator {
public:
ArgOpOperator();

~ArgOpOperator();

ArgOpOperator &Name(const std::string &name);

ArgOpOperator &Index(int64_t index);

int64_t GetIndex() const;
};
} // namespace ge

#endif // DOMI_OP_ARG_OP_H_

+ 45
- 0
parser/common/op_def/constant_op.cc View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_def/constant_op.h"
#include <string>
#include <vector>

#include "graph/debug/ge_attr_define.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::ConstantOperator() : ParserOperator("Constant") {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator::~ConstantOperator() {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::VectorAttr(
std::string key, std::vector<int64_t> &value) {
Attr(key, value);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ConstantOperator &ConstantOperator::DType(ge::DataType t) {
Attr(VAR_ATTR_DTYPE, (int64_t)t);
return *this;
}

ge::DataType ConstantOperator::GetDType() const { return (ge::DataType)GetIntAttr(VAR_ATTR_DTYPE); }
} // namespace ge

+ 37
- 0
parser/common/op_def/constant_op.h View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_CONSTANT_OP_H_
#define DOMI_OP_CONSTANT_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class ConstantOperator : public ParserOperator {
public:
ConstantOperator();
~ConstantOperator();

ConstantOperator &Name(const std::string &name);
ConstantOperator &VectorAttr(std::string key, std::vector<int64_t> &value);

ConstantOperator &DType(ge::DataType t);
ge::DataType GetDType() const;
};
} // namespace ge

#endif // DOMI_OP_CONSTANT_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 712
- 0
parser/common/op_def/defs.cc View File

@@ -0,0 +1,712 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_def/op_schema.h"

namespace ge {
DOMI_OP_SCHEMA(Data).Output("y");

DOMI_OP_SCHEMA(Const).Output("y");

DOMI_OP_SCHEMA(ConvolutionDepthwise)
.Input("x")
.Input("w")
.Input("b", OpSchema::Optional)
.Output("y")
.Attr("group", AttributeType::INT, static_cast<int64_t>(1))
.Attr("num_output", AttributeType::INT, static_cast<int64_t>(1))
.Attr("pad_mode", AttributeType::INT, static_cast<int64_t>(0))
.Attr("mode", AttributeType::INT, static_cast<int64_t>(1))
.Attr("pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0})
.Attr("stride", AttributeType::INTLIST, IntTuple{1, 1})
.Attr("dilation", AttributeType::INTLIST, IntTuple{1, 1})
.Attr("kernel", AttributeType::INTLIST, IntTuple{0, 0})
.Attr("before_pad", AttributeType::INTLIST, IntTuple{0, 0, 0, 0});

DOMI_OP_SCHEMA(Region)
.Input("x")
.Output("y")
.Attr("casses", AttributeType::INT, static_cast<int64_t>(20))
.Attr("coords", AttributeType::INT, static_cast<int64_t>(4))
.Attr("boxes", AttributeType::INT, static_cast<int64_t>(1))
.Attr("background", AttributeType::BOOL, static_cast<bool>(false))
.Attr("softmax", AttributeType::BOOL, static_cast<bool>(false))
.Attr("softmax_tree", AttributeType::BOOL, static_cast<bool>(false))
.Attr("yolo_version", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(Gather)
.Input("params")
.Input("indices")
.Input("axis", OpSchema::Optional)
.Output("y")
.Attr("params_type", AttributeType::INT, static_cast<int64_t>(1))
.Attr("indices_type", AttributeType::INT, static_cast<int64_t>(3))
.Attr("validate_indices", AttributeType::BOOL, static_cast<bool>(true));

DOMI_OP_SCHEMA(ArgMax)
.Input("input")
.Output("output")
.Attr("axis", AttributeType::INT, static_cast<int64_t>(0))
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(true))
.Attr("axis_type", AttributeType::INT, static_cast<int64_t>(3))
.Attr("outmaxval", AttributeType::BOOL, static_cast<bool>(false))
.Attr("topk", AttributeType::UINT, static_cast<uint32_t>(1));

DOMI_OP_SCHEMA(Split)
.Input("x")
.Input("axis", OpSchema::Optional)
.Output("y")
.Attr("T", AttributeType::INT, static_cast<int64_t>(1))
.Attr("num_split", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(SplitV)
.Input("x")
.Input("axis", OpSchema::Optional)
.Output("y")
.Attr("T", AttributeType::INT, static_cast<int64_t>(1))
.Attr("Tlen", AttributeType::INT, static_cast<int64_t>(1))
.Attr("num_split", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(Fill).Input("x").Input("value").Output("y").Attr("T", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(Rsqrt).Input("x").Output("y");
DOMI_OP_SCHEMA(BiasAdd)
.Input("x")
.Input("bias")
.Output("y")
.Attr("format", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(Reverse)
.Input("x")
.Input("axis")
.Output("y")
.Attr("T", AttributeType::INT, static_cast<int64_t>(1))
.Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(Unpack)
.Input("x")
.Output("y")
.Attr("T", AttributeType::INT, static_cast<int64_t>(1))
.Attr("axis", AttributeType::INT, static_cast<int64_t>(0))
.Attr("num", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(Yolo2Reorg)
.Input("x")
.Output("y")
.Attr("reverse", AttributeType::BOOL, static_cast<bool>(1))
.Attr("stride", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(ReduceSum)
.Input("x")
.Output("y")
.Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1))
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(1));

DOMI_OP_SCHEMA(Concat)
.Input("x")
.Output("y")
.Attr("Tidx", AttributeType::INT, static_cast<int64_t>(1))
.Attr("N", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(ResizeBilinear)
.Input("x")
.Input("sizes")
.Output("y")
.Attr("output_dim_mode", AttributeType::INT, static_cast<int64_t>(1))
.Attr("align_corners", AttributeType::BOOL, static_cast<bool>(1))
.Attr("zoom_factor", AttributeType::INT, static_cast<int64_t>(1))
.Attr("shrink_factor", AttributeType::INT, static_cast<int64_t>(1))
.Attr("height", AttributeType::INT, static_cast<int64_t>(1))
.Attr("width", AttributeType::INT, static_cast<int64_t>(1))
.Attr("pad_begin", AttributeType::INT, static_cast<int64_t>(1))
.Attr("pad_end", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(LRN)
.Input("x")
.Output("y")
.Attr("lrn_normregion", AttributeType::UINT, static_cast<uint32_t>(0))
.Attr("lrn_k", AttributeType::FLOAT, static_cast<float>(1))
.Attr("lrn_localsize", AttributeType::UINT, static_cast<uint32_t>(5))
.Attr("lrn_alpha", AttributeType::FLOAT, static_cast<float>(1))
.Attr("lrn_beta", AttributeType::FLOAT, static_cast<float>(0.75));

DOMI_OP_SCHEMA(Maximum).Input("x").Input("w").Output("y");

DOMI_OP_SCHEMA(Slice)
.Input("x")
.Output("y")
.Attr("axis", AttributeType::INT, static_cast<int64_t>(2))
.AttrRequired("offsets", AttributeType::INTLIST);

DOMI_OP_SCHEMA(Pad)
.Input("x")
.Input("paddings")
.Input("constant_values", OpSchema::Optional)
.Output("y")
.Attr("T", AttributeType::INT, static_cast<int64_t>(1))
.Attr("t_paddings", AttributeType::INT, static_cast<int64_t>(1))
.Attr("mode", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(PadV2)
.Input("input")
.Output("output")
.Attr("constant_values", AttributeType::INT, static_cast<int64_t>(0))
.AttrRequired("paddings", AttributeType::INTLIST);

DOMI_OP_SCHEMA(MirrorPad)
.Input("input")
.Output("output")
.AttrRequired("paddings", AttributeType::INTLIST)
.Attr("mode", AttributeType::INT, static_cast<int64_t>(2));

DOMI_OP_SCHEMA(Upsample)
.Input("input")
.Input("scales")
.Output("output")
.Attr("mode", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(Cast)
.Input("x")
.Output("y")
.Attr("DstT", AttributeType::INT, static_cast<int64_t>(1))
.Attr("SrcT", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(LogicalNot).Input("x").Output("y");
DOMI_OP_SCHEMA(LogicalAnd).Input("x1").Input("x2").Output("y");
DOMI_OP_SCHEMA(LogicalOr).Input("x1").Input("x2").Output("y");
DOMI_OP_SCHEMA(Equal).Input("x1").Input("x2").Output("y").Attr("T", AttributeType::INT, static_cast<int64_t>(1));

DOMI_OP_SCHEMA(MatMul)
.Input("a")
.Input("b")
.Output("product")
.Attr("transposeX", AttributeType::BOOL, static_cast<bool>(false))
.Attr("transposeW", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(RNN)
.Input("x")
.Input("cont")
.Input("xstatic", OpSchema::Optional)
.Input("w") // filter
.Input("b") // bias
.Input("seqlen") // T
.Input("hx") // Hx
.Input("cx") // cx
.Output("y")
.Output("cyfw")
.Output("hyfw")
.Output("cybw")
.Output("hybw")
.Attr("hidden_size", AttributeType::INT, static_cast<int64_t>(0))
.Attr("num_layers", AttributeType::INT, static_cast<int64_t>(1))
.Attr("support_cont", AttributeType::BOOL, static_cast<bool>(false))
.Attr("support_xstatic", AttributeType::BOOL, static_cast<bool>(false))
.Attr("input_mode", AttributeType::INT, static_cast<int64_t>(0))
.Attr("direction_mode", AttributeType::INT, static_cast<int64_t>(0))
.Attr("mode", AttributeType::INT, static_cast<int64_t>(0))
.Attr("input_data_layout", AttributeType::INT, static_cast<int64_t>(0))
.Attr("output_data_layout", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(FrameworkOp).Attr("framework_type", AttributeType::INT, static_cast<int64_t>(3));
DOMI_OP_SCHEMA(Multinomial)
.Input("logits")
.Output("output")
.Attr("num_samples", AttributeType::INT, static_cast<int64_t>(0))
.AttrRequired("seed", AttributeType::INT)
.AttrRequired("seed2", AttributeType::INT);
DOMI_OP_SCHEMA(ReverseSequence)
.Input("input")
.Input("seq_lengths")
.Output("output")
.AttrRequired("seq_dim", AttributeType::INT)
.AttrRequired("batch_dim", AttributeType::INT);

DOMI_OP_SCHEMA(Interp)
.Input("x")
.Output("y")
.Attr("output_dim_mode", AttributeType::INT, static_cast<int64_t>(2))
.Attr("zoom_factor", AttributeType::INT, static_cast<int64_t>(1))
.Attr("shrink_factor", AttributeType::INT, static_cast<int64_t>(1))
.Attr("height", AttributeType::INT, static_cast<int64_t>(0))
.Attr("width", AttributeType::INT, static_cast<int64_t>(0))
.Attr("pad_begin", AttributeType::INT, static_cast<int64_t>(0))
.Attr("pad_end", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(ShuffleChannel).Input("x").Output("y").Attr("group", AttributeType::UINT, static_cast<uint32_t>(1));

DOMI_OP_SCHEMA(Conv2DBackpropFilter)
.Input("x")
.Input("w")
.Input("b", OpSchema::Optional)
.Output("y")
.Attr("padding", AttributeType::INT, static_cast<int64_t>(1))
.Attr("pads", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0})
.Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1})
.Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1});

DOMI_OP_SCHEMA(Conv2DBackpropInput)
.Input("input_sizes")
.Input("filter")
.Input("out_backprop")
.Output("output")
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"))
.Attr("group", AttributeType::UINT, static_cast<uint32_t>(1))
.Attr("padding", AttributeType::INT, static_cast<int64_t>(0))
.Attr("dilations", AttributeType::UINTLIST, UintTuple{1, 1})
.Attr("strides", AttributeType::UINTLIST, UintTuple{1, 1})
.Attr("pad", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0});
DOMI_OP_SCHEMA(BiasAddGrad).Input("dy").Output("db").Attr("format", AttributeType::INT, static_cast<int64_t>(1));
DOMI_OP_SCHEMA(ReluGrad).Input("dy").Input("x").Output("dx");

DOMI_OP_SCHEMA(MeanGrad).Input("dy").Output("dx");

DOMI_OP_SCHEMA(NonMaxSuppression)
.Input("boxes")
.Input("scores")
.Output("selected_indices")
.Attr("max_output_size", AttributeType::INT, static_cast<int64_t>(-1))
.Attr("iou_threshold", AttributeType::FLOAT, static_cast<float>(0.5))
.Attr("score_threshold", AttributeType::FLOAT, static_cast<float>(-1));

DOMI_OP_SCHEMA(CropAndResize)
.Input("image")
.Input("boxes")
.Input("box_ind")
.Output("crops")
.Attr("method", AttributeType::INT, static_cast<int64_t>(0))
.Attr("extrapolation_value", AttributeType::FLOAT, static_cast<float>(0))
.Attr("crop_size_h", AttributeType::INT, static_cast<int64_t>(0))
.Attr("crop_size_w", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(TopKV2)
.Input("input")
.Input("k")
.Output("value")
.Output("indices")
.AttrRequired("sorted", AttributeType::BOOL);

DOMI_OP_SCHEMA(InvertPermutation).Input("x").Output("y");

DOMI_OP_SCHEMA(GatherV2)
.Input("params")
.Input("indices")
.Input("axis", OpSchema::Optional)
.Output("y")
.Attr("Tparams", AttributeType::INT, static_cast<int64_t>(0)) // default: DT_FLOAT
.Attr("Tindices", AttributeType::INT, static_cast<int64_t>(3)) // default: DT_INT32
.Attr("Taxis", AttributeType::INT, static_cast<int64_t>(3)); // default: DT_INT32

DOMI_OP_SCHEMA(HighWay)
.Input("x")
.Input("tw") // filter
.Input("tb") // bias
.Input("uw") // filter
.Input("ub") // bias
.Output("y");

DOMI_OP_SCHEMA(Reciprocal).Input("x").Output("y");

DOMI_OP_SCHEMA(Asinh).Input("input").Output("output");

DOMI_OP_SCHEMA(Acosh).Input("input").Output("output");

DOMI_OP_SCHEMA(Minimum).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(Clip).Input("input").Input("min").Input("max").Output("output");

DOMI_OP_SCHEMA(FusedBatchNorm)
.Input("x")
.Input("scale")
.Input("offset")
.Input("mean")
.Input("variance")
.Output("y")
.Output("batch_mean")
.Output("batch_variance")
.Output("reserve_space_1")
.Output("reserve_space_2")
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"))
.Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.0001))
.Attr("is_training", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(FusedBatchNormGrad)
.Input("dy")
.Input("x")
.Input("bnscale")
.Input("save_mean")
.Input("save_variance")
.Output("dx")
.Output("result_bn_scale_diff")
.Output("result_bn_bias_diff")
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"))
.Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.0))
.Attr("is_training", AttributeType::BOOL, static_cast<bool>(true));

DOMI_OP_SCHEMA(MaxPoolWithArgmax)
.Input("x")
.Output("y")
.Output("argmax")
.AttrRequired("window", AttributeType::INTLIST)
.AttrRequired("stride", AttributeType::INTLIST)
.AttrRequired("pad_mode", AttributeType::INT)
.AttrRequired("ceil_mode", AttributeType::BOOL)
.AttrRequired("data_mode", AttributeType::INT);

DOMI_OP_SCHEMA(MaxPoolGradWithArgmax)
.Input("input")
.Input("grad")
.Output("output")
.AttrRequired("window", AttributeType::INTLIST)
.AttrRequired("stride", AttributeType::INTLIST)
.AttrRequired("pad_mode", AttributeType::INT)
.AttrRequired("ceil_mode", AttributeType::BOOL)
.AttrRequired("data_mode", AttributeType::INT);

DOMI_OP_SCHEMA(HcomBroadcast)
.AttrRequired("root_rank", AttributeType::INT)
.AttrRequired("group", AttributeType::STRING);

DOMI_OP_SCHEMA(HcomAllReduce)
.Input("x")
.Output("y")
.AttrRequired("reduction", AttributeType::STRING)
.AttrRequired("group", AttributeType::STRING);

DOMI_OP_SCHEMA(HcomAllGather)
.Input("x")
.Output("y")
.AttrRequired("rank_size", AttributeType::INT)
.AttrRequired("group", AttributeType::STRING);

DOMI_OP_SCHEMA(SparseSoftmaxCrossEntropyWithLogits)
.Input("features")
.Input("labels")
.Output("loss")
.Output("backprop")
.AttrRequired("T", AttributeType::INT)
.Attr("Tlabels", AttributeType::INT, static_cast<int64_t>(9));

DOMI_OP_SCHEMA(Snapshot).Input("input").Output("output").AttrRequired("T", AttributeType::INT);

DOMI_OP_SCHEMA(ReduceProd)
.Input("bottom")
.Output("top")
.AttrRequired("axes", AttributeType::INTLIST)
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(ReduceAll)
.Input("x")
.Output("y")
.AttrRequired("axes", AttributeType::INTLIST)
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(ReduceMax)
.Input("x")
.Output("y")
.AttrRequired("axis", AttributeType::INTLIST)
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(AddN).Input("x").Output("y");

DOMI_OP_SCHEMA(ShapeN)
.Input("x")
.Output("y")
.AttrRequired("N", AttributeType::INT)
.AttrRequired("in_type", AttributeType::INT)
.AttrRequired("dtype", AttributeType::INT);

DOMI_OP_SCHEMA(ReduceMin)
.Input("x")
.Output("y")
.AttrRequired("axis", AttributeType::INTLIST)
.Attr("keep_dims", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(Sqrt).Input("x").Output("y");

DOMI_OP_SCHEMA(L2Loss).Input("x").Output("y");

DOMI_OP_SCHEMA(Multiply).Input("x").Input("y").Output("z");

DOMI_OP_SCHEMA(Add).Input("x").Output("y");

DOMI_OP_SCHEMA(Constant).Output("y");

DOMI_OP_SCHEMA(ApplyMomentum)
.Input("variable")
.Input("accumulation")
.Input("learningRate")
.Input("gradient")
.Input("momuntum")
.Input("fp16variable")
.Attr("algo", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(AvgPoolGrad)
.Input("shape")
.Input("grad")
.Output("output")
.Attr("padding", AttributeType::INT, static_cast<int64_t>(0))
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"))
.Attr("strides", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0})
.Attr("ksize", AttributeType::UINTLIST, UintTuple{0, 0, 0, 0});

DOMI_OP_SCHEMA(Lars)
.Input("w")
.Input("g")
.Input("weight_decay")
.Output("y")
.Attr("hyperpara", AttributeType::FLOAT, static_cast<float>(0.001))
.Attr("epsilon", AttributeType::FLOAT, static_cast<float>(0.00001));

DOMI_OP_SCHEMA(AssignSub)
.Input("variable")
.Input("input")
.Input("output")
.Attr("mode", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(AssignAdd)
.Input("variable")
.Input("input")
.Output("output")
.Attr("mode", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(SpaceToBatchND).Input("input").Input("block_shape").Input("paddings").Output("output");

DOMI_OP_SCHEMA(Variable)
.Output("variable")
.Attr("container", AttributeType::STRING, static_cast<std::string>(""))
.Attr("shared_name", AttributeType::STRING, static_cast<std::string>(""))
.AttrRequired("dtype", AttributeType::INT);

DOMI_OP_SCHEMA(Assign).Input("variable").Input("value").Output("variable");

DOMI_OP_SCHEMA(VarIsInitializedOp).Input("variable").Output("value");

DOMI_OP_SCHEMA(NoOp).Attr("algo", AttributeType::INT, static_cast<int64_t>(0));

DOMI_OP_SCHEMA(LogTimeStamp)
.Attr("logid", AttributeType::STRING, static_cast<std::string>(""))
.Attr("notify", AttributeType::BOOL, static_cast<bool>(false));

DOMI_OP_SCHEMA(ResizeNearestNeighbor)
.Input("images")
.Output("resized_images")
.Attr("align_corners", AttributeType::BOOL, static_cast<bool>(false))
.AttrRequired("height", AttributeType::INT)
.AttrRequired("width", AttributeType::INT);

DOMI_OP_SCHEMA(BatchToSpaceND).Input("input").Input("block_shape").Input("crops").Output("output");

DOMI_OP_SCHEMA(Assert).Input("x").Input("w").Output("y");

DOMI_OP_SCHEMA(Pow).Input("x").Input("y").Output("z");

DOMI_OP_SCHEMA(GreaterEqual).Input("x1").Input("x2").Output("y");

DOMI_OP_SCHEMA(SpaceToDepth)
.Input("input")
.Output("output")
.Attr("block_size", AttributeType::INT, static_cast<int64_t>(0))
.AttrRequired("T", AttributeType::INT)
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"));

DOMI_OP_SCHEMA(DepthToSpace)
.Input("input")
.Output("output")
.Attr("block_size", AttributeType::INT, static_cast<int64_t>(0))
.AttrRequired("T", AttributeType::INT)
.Attr("data_format", AttributeType::STRING, static_cast<std::string>("NHWC"));

DOMI_OP_SCHEMA(Rint).Input("input").Output("output").AttrRequired("T", AttributeType::INT);

DOMI_OP_SCHEMA(ExtractImagePatches)
.Input("images")
.Output("y")
.AttrRequired("ksizes", AttributeType::INTLIST)
.AttrRequired("strides", AttributeType::INTLIST)
.AttrRequired("rates", AttributeType::INTLIST)
.AttrRequired("padding", AttributeType::STRING);

DOMI_OP_SCHEMA(Atan).Input("x").Output("output");

DOMI_OP_SCHEMA(Atanh).Input("x").Output("output");

DOMI_OP_SCHEMA(Acos).Input("x").Output("y");

DOMI_OP_SCHEMA(Asin).Input("x").Output("y");

DOMI_OP_SCHEMA(Log)
.Input("x")
.Output("output")
.AttrRequired("scale", AttributeType::INT)
.AttrRequired("shift", AttributeType::INT)
.AttrRequired("base", AttributeType::INT);

DOMI_OP_SCHEMA(Neg).Input("input").Output("output");

DOMI_OP_SCHEMA(Tan).Input("x").Output("output");

DOMI_OP_SCHEMA(Round).Input("x").Output("output");

DOMI_OP_SCHEMA(Exp)
.Input("x")
.Output("y")
.Attr("scale", AttributeType::FLOAT, static_cast<float>(1))
.Attr("shift", AttributeType::FLOAT, static_cast<float>(0))
.Attr("base", AttributeType::FLOAT, static_cast<float>(-1));

DOMI_OP_SCHEMA(Less).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(LessEqual).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(OneHot).Input("indices").Input("depth").Input("on_value").Input("off_value").Output("output");

DOMI_OP_SCHEMA(ZerosLike).Input("x").Output("y");

DOMI_OP_SCHEMA(Where).Input("x").Output("y");

DOMI_OP_SCHEMA(RefSwitch).Input("x").Output("y");

DOMI_OP_SCHEMA(FakeQuantWithMinMaxVars)
.Input("x")
.Input("min")
.Input("max")
.Output("y")
.Attr("narrow_range", AttributeType::BOOL, static_cast<bool>(false))
.Attr("num_bits", AttributeType::INT, static_cast<int64_t>(8));

DOMI_OP_SCHEMA(Sinh).Input("x").Output("y");

DOMI_OP_SCHEMA(Cosh).Input("x").Output("y");

DOMI_OP_SCHEMA(Floor).Input("x").Output("output");

DOMI_OP_SCHEMA(RandomUniform).Input("input").Output("output");

DOMI_OP_SCHEMA(BatchMatMul).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(FloorMod).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(SquaredDifference).Input("x").Input("y").Output("output");

DOMI_OP_SCHEMA(LayerNorm).Input("x").Output("output").AttrRequired("Epsilon", AttributeType::FLOAT);

DOMI_OP_SCHEMA(SSDPostProcessor)
.Input("trueImgShape")
.Input("boxEncoding")
.Input("anchors")
.Input("clsPred")
.Output("detectBoxes")
.Output("detectScores")
.Output("detectNum")
.Output("detectClasses")
.AttrRequired("numClasses", AttributeType::INT)
.AttrRequired("scoreThreshold", AttributeType::FLOAT)
.AttrRequired("iouThreshold", AttributeType::FLOAT)
.AttrRequired("maxDetectionsPerClass", AttributeType::INT)
.AttrRequired("maxTotalDetections", AttributeType::INT)
.AttrRequired("boxTypeNum", AttributeType::UINT)
.AttrRequired("scaleFactors_0", AttributeType::UINT)
.AttrRequired("scaleFactors_1", AttributeType::UINT)
.AttrRequired("scaleFactors_2", AttributeType::UINT)
.AttrRequired("scaleFactors_3", AttributeType::UINT)
.AttrRequired("imgH", AttributeType::INT)
.AttrRequired("imgW", AttributeType::INT)
.AttrRequired("useStaticShape", AttributeType::BOOL)
.AttrRequired("convertScoresMode", AttributeType::INT);

DOMI_OP_SCHEMA(RetinaPostProcessor)
.Input("anchors")
.Input("regression")
.Input("classification")
.Output("detectBoxes")
.Output("detectScores")
.Output("detectLabels")
.Output("detectNum")
.AttrRequired("numClasses", AttributeType::INT)
.AttrRequired("maxDetections", AttributeType::INT)
.AttrRequired("nmsThreshold", AttributeType::FLOAT)
.AttrRequired("scoreThreshold", AttributeType::FLOAT)
.AttrRequired("imgH", AttributeType::INT)
.AttrRequired("imgW", AttributeType::INT)
.AttrRequired("boxTypeNum", AttributeType::UINT)
.AttrRequired("means", AttributeType::FLOATLIST)
.AttrRequired("stds", AttributeType::FLOATLIST);

DOMI_OP_SCHEMA(ROIInterPooling)
.Input("input")
.Input("input_1")
.Output("maxPool")
.AttrRequired("hStride", AttributeType::INT)
.AttrRequired("wStride", AttributeType::INT)
.AttrRequired("hKernel", AttributeType::INT)
.AttrRequired("wKernel", AttributeType::INT)
.AttrRequired("hResize", AttributeType::INT)
.AttrRequired("wResize", AttributeType::INT)
.AttrRequired("hFeatureMap", AttributeType::INT)
.AttrRequired("wFeatureMap", AttributeType::INT);

DOMI_OP_SCHEMA(FirstStageProcessor)
.Input("anchors")
.Input("boxEncoding")
.Input("clsPred")
.Input("trueImgShape")
.Output("detectBoxes")
.Output("detectScores")
.Output("detectLables")
.Output("detectNum")
.AttrRequired("scaleFactorsNum", AttributeType::INT)
.AttrRequired("iouThreshold", AttributeType::FLOAT)
.AttrRequired("scoreThreshold", AttributeType::FLOAT)
.AttrRequired("maxSizePerClass", AttributeType::INT)
.AttrRequired("maxTotalSize", AttributeType::INT)
.AttrRequired("imgH", AttributeType::INT)
.AttrRequired("imgW", AttributeType::INT)
.AttrRequired("boxTypeNum", AttributeType::UINT)
.AttrRequired("scaleFactors_0", AttributeType::UINT)
.AttrRequired("scaleFactors_1", AttributeType::UINT)
.AttrRequired("scaleFactors_2", AttributeType::UINT)
.AttrRequired("scaleFactors_3", AttributeType::UINT);

DOMI_OP_SCHEMA(SecondStageProcessor)
.Input("anchors")
.Input("boxEncoding")
.Input("clsPred")
.Input("validBoxNum")
.Input("trueImgShape")
.Output("detectBoxes")
.Output("detectScores")
.Output("detectLables")
.Output("detectNum")
.AttrRequired("scaleFactorsNum", AttributeType::INT)
.AttrRequired("iouThreshold", AttributeType::FLOAT)
.AttrRequired("scoreThreshold", AttributeType::FLOAT)
.AttrRequired("maxSizePerClass", AttributeType::INT)
.AttrRequired("maxTotalSize", AttributeType::INT)
.AttrRequired("numClasses", AttributeType::INT)
.AttrRequired("scaleFactors_0", AttributeType::UINT)
.AttrRequired("scaleFactors_1", AttributeType::UINT)
.AttrRequired("scaleFactors_2", AttributeType::UINT)
.AttrRequired("scaleFactors_3", AttributeType::UINT);

DOMI_OP_SCHEMA(StreamSwitch)
.Input("loopIndex")
.Input("itersPerLoop")
.AttrRequired("switch_condition", AttributeType::UINT)
.AttrRequired("true_branch_stream", AttributeType::INT);

DOMI_OP_SCHEMA(StreamActive).AttrRequired("active_stream_list", AttributeType::INTLIST);

DOMI_OP_SCHEMA(MemcpyAsync).Input("in").Output("out");

DOMI_OP_SCHEMA(CleanAddr)
.AttrRequired("automic_add_addr_start", AttributeType::INT)
.AttrRequired("automic_add_mem_size", AttributeType::INT);
} // namespace ge

+ 45
- 0
parser/common/op_def/fill_op.cc View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_def/fill_op.h"
#include "framework/common/fmk_types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator::FillOperator() : ParserOperator("Fill") {}

FMK_FUNC_DEV_VISIBILITY FillOperator::~FillOperator() {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::DataType(int64_t dataType) {
Attr("T", static_cast<int64_t>(dataType));
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Alpha(float alpha) {
Attr("alpha", static_cast<float>(alpha));
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FillOperator &FillOperator::Beta(float beta) {
Attr("beta", static_cast<float>(beta));
return *this;
}

int64_t FillOperator::GetDataType() const { return GetIntAttr("T"); }

float FillOperator::GetAlpha() const { return GetFloatAttr("alpha"); }

float FillOperator::GetBeta() const { return GetFloatAttr("beta"); }
} // namespace ge

+ 42
- 0
parser/common/op_def/fill_op.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_OP_FILL_OP_H_
#define DOMI_OP_FILL_OP_H_
#include "parser/common/op_def/operator.h"

namespace ge {
class FillOperator : public ParserOperator {
public:
FillOperator();

~FillOperator();

FillOperator &DataType(int64_t dataType);

FillOperator &Alpha(float alpha);

FillOperator &Beta(float beta);

int64_t GetDataType() const;

float GetAlpha() const;

float GetBeta() const;
};
} // namespace ge

#endif // DOMI_OP_FILL_OP_H_

+ 74
- 0
parser/common/op_def/frameworkop_op.cc View File

@@ -0,0 +1,74 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_def/frameworkop_op.h"
#include <string>
#include "framework/common/fmk_types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator::FrameworkOpOperator()
: ParserOperator("FrameworkOp") {}

FrameworkOpOperator::~FrameworkOpOperator() {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Name(
const std::string &name) {
ParserOperator::Name(name);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Index(int64_t index) {
Attr(RETVAL_ATTR_NAME_INDEX, static_cast<int64_t>(index));
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::NodeDefPkg(
const std::string &nodedef_pkg) {
Attr_bt(ATTR_NAME_FRAMEWORK_NODE_DEF, nodedef_pkg);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::Frameworktype(
int64_t framework_type) {
Attr(ATTR_NAME_FRAMEWORK_FWK_TYPE, static_cast<int64_t>(framework_type));
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::TfOpDef(
const std::string &opdef_string) {
Attr(ATTR_NAME_FRAMEWORK_OP_DEF, opdef_string);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::OriginalType(
const std::string &type) {
Attr(ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FrameworkOpOperator &FrameworkOpOperator::FuncDefPkg(const std::string &func_string) {
Attr_bt(ATTR_NAME_FRAMEWORK_FUNC_DEF, func_string);
return *this;
}

FMK_FUNC_HOST_VISIBILITY int64_t FrameworkOpOperator::GetFrameworkType() const {
return GetIntAttr(ATTR_NAME_FRAMEWORK_FWK_TYPE);
}

FMK_FUNC_HOST_VISIBILITY std::string FrameworkOpOperator::GetNodeDefPkg() const {
return GetStringAttr(ATTR_NAME_FRAMEWORK_NODE_DEF);
}
} // namespace ge

+ 49
- 0
parser/common/op_def/frameworkop_op.h View File

@@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_
#define DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_
#include "graph/debug/ge_attr_define.h"
#include "parser/common/op_def/operator.h"

namespace ge {
class FrameworkOpOperator : public ParserOperator {
public:
FrameworkOpOperator();

~FrameworkOpOperator();

FrameworkOpOperator &Name(const std::string &name);

FrameworkOpOperator &OriginalType(const std::string &type);

FrameworkOpOperator &NodeDefPkg(const std::string &nodedef_pkg);

FrameworkOpOperator &Frameworktype(int64_t framework_type);

FrameworkOpOperator &TfOpDef(const std::string &opdef_string);

FrameworkOpOperator &Index(int64_t index);

FrameworkOpOperator &FuncDefPkg(const std::string &func_string);

int64_t GetFrameworkType() const;

std::string GetNodeDefPkg() const;
};
} // namespace ge

#endif // DOMI_OP_FRAMEWORKOP_OP_OPERATOR_H_

+ 205
- 0
parser/common/op_def/ir_pb_converter.cc View File

@@ -0,0 +1,205 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/op_def/ir_pb_converter.h"
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "google/protobuf/map.h"
#include "graph/ge_tensor.h"
#include "graph/buffer.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/util.h"

namespace ge {
static void ConvertList(const std::pair<std::string, OpAttribute> &op_attr_pair, ge::OpDescPtr op_def) {
domi::AttrDef_ListValue a_list = op_attr_pair.second.value_.list();

vector<int64_t> v_i;
for (int32_t i = 0; i < a_list.i_size(); i++) {
v_i.push_back((int64_t)a_list.i(i));
}
if (v_i.size() > 0) {
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i);
return;
}
vector<float> v_f;
for (int32_t i = 0; i < a_list.f_size(); i++) {
v_f.push_back(a_list.f(i));
}
if (v_f.size() > 0) {
(void)ge::AttrUtils::SetListFloat(op_def, op_attr_pair.first, v_f);
return;
}
vector<bool> v_b;
for (int32_t i = 0; i < a_list.b_size(); i++) {
v_b.push_back(a_list.b(i));
}
if (v_b.size() > 0) {
(void)ge::AttrUtils::SetListBool(op_def, op_attr_pair.first, v_b);
return;
}
vector<int32_t> v_u;
for (int32_t i = 0; i < a_list.u_size(); i++) {
v_u.push_back((int32_t)a_list.u(i));
}
if (v_u.size() > 0) {
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_u);
return;
}
// set for empty list
(void)ge::AttrUtils::SetListInt(op_def, op_attr_pair.first, v_i);
GELOGI("set empty list for node %s attr %s", op_def->GetName().c_str(), op_attr_pair.first.c_str());
}

static void UpdateTensorForOpDesc(const ParserOperator &op, ge::OpDescPtr op_def) {
if (op_def == nullptr) {
return;
}
uint32_t in_index = 0;
for (const ge::GeTensorDesc &input_desc : op.GetInputTensorDesc()) {
if (in_index < op_def->GetInputsSize()) {
(void)op_def->UpdateInputDesc(in_index++, input_desc);
} else {
(void)op_def->AddInputDesc(input_desc);
in_index++;
}
}

uint32_t out_index = 0;
for (const ge::GeTensorDesc &output_desc : op.GetOutputTensorDesc()) {
if (out_index < op_def->GetOutputsSize()) {
op_def->UpdateOutputDesc(out_index++, output_desc);
} else {
op_def->AddOutputDesc(output_desc);
out_index++;
}
}
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(const ParserOperator &op,
ge::OpDescPtr op_def) {
GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null.");
GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str());
op_def->SetName(op.GetName());
op_def->SetType(op.GetType());
GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION));

UpdateTensorForOpDesc(op, op_def);
GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(),
op_def->GetInputsSize(), op_def->GetOutputsSize());

for (const auto &op_attr_pair : op.GetOpAttrs()) {
if (op_attr_pair.second.value_.has_list()) {
ConvertList(op_attr_pair, op_def);
} else {
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kBt) {
auto &buffer = op_attr_pair.second.value_.bt();
(void)ge::AttrUtils::SetZeroCopyBytes(op_def, op_attr_pair.first,
ge::Buffer::CopyFrom(reinterpret_cast<uint8_t *>(const_cast<char *>(buffer.data())), buffer.size()));
}

if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kS) {
(void)ge::AttrUtils::SetStr(op_def, op_attr_pair.first, op_attr_pair.second.value_.s());
}
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kI) {
(void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.i());
}
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kF) {
(void)ge::AttrUtils::SetFloat(op_def, op_attr_pair.first, op_attr_pair.second.value_.f());
}
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kB) {
(void)ge::AttrUtils::SetBool(op_def, op_attr_pair.first, op_attr_pair.second.value_.b());
}
if (op_attr_pair.second.value_.value_case() == domi::AttrDef::kU) {
(void)ge::AttrUtils::SetInt(op_def, op_attr_pair.first, op_attr_pair.second.value_.u());
}
}
}
GE_CHK_BOOL_RET_STATUS(op.GetSchema()->Verify(op_def), domi::PARAM_INVALID, "Op schema verify failed, op name: %s",
op.GetName().c_str());

return domi::SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def,
ParserOperator &op) {
GE_RETURN_WITH_LOG_IF_TRUE(op_def == nullptr, "parameter is null.");
op.Name(op_def->GetName());

map<string, ge::GeAttrValue> allattrs = op_def->GetAllAttrs();

for (const auto &attr : allattrs) {
ge::GeAttrValue::ValueType v_t = attr.second.GetValueType();
switch (v_t) {
case ge::GeAttrValue::ValueType::VT_LIST_STRING: {
std::vector<string> vec;
(void)ge::AttrUtils::GetListStr(op_def, attr.first, vec);
op.Attr(attr.first, vec);
break;
}
case ge::GeAttrValue::ValueType::VT_LIST_FLOAT: {
std::vector<float> vec;
(void)ge::AttrUtils::GetListFloat(op_def, attr.first, vec);
op.Attr(attr.first, vec);
break;
}
case ge::GeAttrValue::ValueType::VT_LIST_BOOL: {
std::vector<bool> vec;
(void)ge::AttrUtils::GetListBool(op_def, attr.first, vec);
op.Attr(attr.first, vec);
break;
}
case ge::GeAttrValue::ValueType::VT_LIST_INT: {
std::vector<int64_t> vec;
(void)ge::AttrUtils::GetListInt(op_def, attr.first, vec);
op.Attr(attr.first, vec);
break;
}
case ge::GeAttrValue::ValueType::VT_STRING: {
string s = "";
(void)ge::AttrUtils::GetStr(op_def, attr.first, s);
op.Attr(attr.first, s);
break;
}
case ge::GeAttrValue::ValueType::VT_FLOAT: {
float f = 0.0;
(void)ge::AttrUtils::GetFloat(op_def, attr.first, f);
op.Attr(attr.first, f);
break;
}
case ge::GeAttrValue::ValueType::VT_BOOL: {
bool b = false;
(void)ge::AttrUtils::GetBool(op_def, attr.first, b);
op.Attr(attr.first, b);
break;
}
case ge::GeAttrValue::ValueType::VT_INT: {
int64_t i = 0;
(void)ge::AttrUtils::GetInt(op_def, attr.first, i);
op.Attr(attr.first, i);
break;
}
default:
break;
}
}

return domi::SUCCESS;
}
} // namespace ge

+ 36
- 0
parser/common/op_def/ir_pb_converter.h View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H
#define DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H

#include "framework/common/fmk_error_codes.h"
#include "common/op_def/op_schema.h"
#include "parser/common/op_def/operator.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "proto/om.pb.h"

namespace ge {
domi::Status ConvertToOpDesc(const ParserOperator &op, ge::OpDescPtr op_def);

domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, ParserOperator &op);
} // namespace ge

#endif // DOMI_COMMON_OP_DEF_IR_PB_CONVERTER_H

+ 30
- 0
parser/common/op_def/no_op_op.cc View File

@@ -0,0 +1,30 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/no_op_op.h"
#include <string>

namespace ge {
FMK_FUNC_HOST_VISIBILITY NoOpOperator::NoOpOperator() : ParserOperator("NoOp") {}

FMK_FUNC_HOST_VISIBILITY NoOpOperator::~NoOpOperator() {}

FMK_FUNC_HOST_VISIBILITY NoOpOperator &NoOpOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}
} // namespace ge

+ 33
- 0
parser/common/op_def/no_op_op.h View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_NO_OP_OP_H_
#define DOMI_OP_NO_OP_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class NoOpOperator : public ParserOperator {
public:
NoOpOperator();
~NoOpOperator();

NoOpOperator &Name(const std::string &name);
};
} // namespace ge

#endif // DOMI_OP_NO_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 215
- 0
parser/common/op_def/op_schema.cc View File

@@ -0,0 +1,215 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_def/op_schema.h"
#include <iostream>
#include <utility>
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"

namespace ge {
OpSchema::FormalParameter::FormalParameter(const std::string &name, FormalParameterOption param_option)
: name_(name), param_option_(param_option) {}

OpSchema::FormalParameter::~FormalParameter() {}

const std::string &OpSchema::FormalParameter::Name() const { return name_; }

OpSchema::FormalParameterOption OpSchema::FormalParameter::Option() const { return param_option_; }

OpSchema::OpSchema(const std::string &name) : name_(name) {}

OpSchema::~OpSchema() {}

OpSchema &OpSchema::Input(const std::string &name, FormalParameterOption param_option) {
inputs_.emplace_back(FormalParameter(name, param_option));
return *this;
}

OpSchema &OpSchema::Output(const std::string &name, FormalParameterOption param_option) {
outputs_.emplace_back(FormalParameter(name, param_option));
return *this;
}

OpSchema &OpSchema::Attr(const Attribute &attr) {
(void)attributes_.insert(std::make_pair(attr.name_, attr));
return *this;
}

#if defined(CFG_BUILD_DEBUG)
#define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \
if (attrtype != attr_type) { \
GELOGE(FAILED, "Attribute specification param_type mismatch, input attr type %u, required attr type %u.", \
(uint32_t)attr_type, (uint32_t)attrtype); \
return *this; \
} \
\
domi::AttrDef a; \
a.set_##field(default_value); \
Attr(Attribute(name, attr_type, a)); \
return *this; \
}
#else
#define ATTR_SETTER_WITH_SINGLE_VALUE(Type, field, attrtype) \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Type &default_value) { \
if (attrtype != attr_type) { \
return *this; \
} \
domi::AttrDef a; \
a.set_##field(default_value); \
Attr(Attribute(name, attr_type, a)); \
return *this; \
}

#endif

#if defined(CFG_BUILD_DEBUG)
#define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector<Type> &default_value) { \
if (attrtype != attr_type) { \
GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \
(uint32_t)attr_type, (uint32_t)attrtype); \
return *this; \
} \
domi::AttrDef vec_a; \
for (const auto &v : default_value) { \
vec_a.mutable_list()->add_##field(v); \
} \
Attr(Attribute(name, attr_type, vec_a)); \
return *this; \
} \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple<Type> &default_value) { \
if (attrtype != attr_type) { \
GELOGE(FAILED, "Attribute specification vector param_type mismatch, input attr type %u, required attr type %u.", \
(uint32_t)attr_type, (uint32_t)attrtype); \
return *this; \
} \
domi::AttrDef tuple_a; \
for (const auto &v : default_value) { \
tuple_a.mutable_list()->add_##field(v); \
} \
Attr(Attribute(name, attr_type, tuple_a)); \
return *this; \
}
#else
#define ATTR_SETTER_WITH_LIST_VALUE(Type, field, attrtype) \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const std::vector<Type> &default_value) { \
if (attrtype != attr_type) { \
return *this; \
} \
domi::AttrDef vec_a; \
for (const auto &v : default_value) { \
vec_a.mutable_list()->add_##field(v); \
} \
Attr(Attribute(name, attr_type, vec_a)); \
return *this; \
} \
OpSchema &OpSchema::Attr(const std::string &name, AttributeType attr_type, const Tuple<Type> &default_value) { \
if (attrtype != attr_type) { \
return *this; \
} \
domi::AttrDef tuple_a; \
for (const auto &v : default_value) { \
tuple_a.mutable_list()->add_##field(v); \
} \
Attr(Attribute(name, attr_type, tuple_a)); \
return *this; \
}

#endif
ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, u, AttributeType::UINT)
ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i, AttributeType::INT)
ATTR_SETTER_WITH_SINGLE_VALUE(bool, b, AttributeType::BOOL)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f, AttributeType::FLOAT)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s, AttributeType::STRING)

ATTR_SETTER_WITH_LIST_VALUE(uint32_t, u, AttributeType::UINTLIST)
ATTR_SETTER_WITH_LIST_VALUE(int64_t, i, AttributeType::INTLIST)
ATTR_SETTER_WITH_LIST_VALUE(bool, b, AttributeType::BOOLLIST)
ATTR_SETTER_WITH_LIST_VALUE(float, f, AttributeType::FLOATLIST)
ATTR_SETTER_WITH_LIST_VALUE(std::string, s, AttributeType::STRINGLIST)

OpSchema &OpSchema::AttrRequired(const std::string &name, AttributeType attr_type) {
Attr(Attribute(name, attr_type, true));
return *this;
}

bool OpSchema::HasDefaultAttr(const std::string &name) const {
auto it = attributes_.find(name);
if (it == attributes_.end()) {
return false;
}

// required does not need a default value
return !it->second.required_;
}

const domi::AttrDef &OpSchema::GetDefaultAttr(const std::string &name) const {
auto it = attributes_.find(name);
if (it == attributes_.end()) {
const static domi::AttrDef attr_def;
return attr_def;
}
return it->second.default_value_;
}

bool OpSchema::Verify(const ge::OpDescPtr op_def) const {
if (op_def->GetType() != name_) {
GELOGE(FAILED, "Name not math, op schema name: %s, opdef type: %s.", name_.c_str(), op_def->GetType().c_str());
return false;
}

// Required field verification
for (const auto &pair : attributes_) {
const auto &attr = pair.second;
if (!attr.required_) {
continue;
}
if (!op_def->HasAttr(attr.name_)) {
GELOGE(FAILED, "Required attribute: %s of op: %s is missing.", attr.name_.c_str(), op_def->GetName().c_str());
return false;
}
}

return true;
}

OpSchemaFactory &OpSchemaFactory::Instance() {
static OpSchemaFactory instance;
return instance;
}

const OpSchema *OpSchemaFactory::Get(const std::string &op) const {
auto it = op_schema_map_.find(op);
if (it == op_schema_map_.end()) {
return nullptr;
}
return &it->second;
}

OpSchemaRegistry::OpSchemaRegistry(OpSchema &op_schema) {
OpSchemaFactory &op_factory = OpSchemaFactory::Instance();

// save op_schema to the map
if (op_factory.op_schema_map_.count(op_schema.name_)) {
GELOGD("Failed to register op schema: %s., reason: already exist!", op_schema.name_.c_str());
return;
}

(void)op_factory.op_schema_map_.emplace(std::make_pair(op_schema.name_, op_schema));
}
} // namespace ge

+ 175
- 0
parser/common/op_def/op_schema.h View File

@@ -0,0 +1,175 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_COMMON_OP_SCHEMA_H
#define DOMI_COMMON_OP_SCHEMA_H

#include <string>
#include <unordered_map>
#include <vector>
#include "common/tuple.h"
#include "graph/op_desc.h"
#include "proto/om.pb.h"
#include "framework/common/fmk_types.h"

namespace ge {
enum class AttributeType {
UNDEFINED,
INT,
UINT,
BOOL,
FLOAT,
STRING,
BYTES,

INTLIST,
UINTLIST,
BOOLLIST,
FLOATLIST,
STRINGLIST
};

class OpSchema;

class OpSchemaRegistry;

class FMK_FUNC_HOST_VISIBILITY OpSchema {
public:
// Formal parameter options.
enum FormalParameterOption {
// The input formal parameter is single and not optional.
// Number of this input is 1.
Single = 0,
// The input formal parameter is single and optional.
// Number of this input is 0 or 1.
Optional = 1,
// The input formal parameter is variadic.
// Number of this input is [1, n].
Variadic = 2,
};

// Formal parameter represenation, including input/output name, typeStr,
// description, and type constraints.
class FormalParameter {
public:
// Constructor.
FormalParameter() = default;

explicit FormalParameter(const std::string &name, FormalParameterOption param_option = Single);

~FormalParameter();

// Get formal parameter name.
const std::string &Name() const;

// Get the parameter option, it could be Single, Optional or Variadic.
FormalParameterOption Option() const;

private:
friend class OpSchema;

// Formal parameter name.
std::string name_;

// Formal parameter option.
FormalParameterOption param_option_;
};

explicit OpSchema(const std::string &name);

~OpSchema();

OpSchema &Input(const std::string &name, FormalParameterOption param_option = Single);

OpSchema &Output(const std::string &name, FormalParameterOption param_option = Single);

struct Attribute {
Attribute(const std::string &name, AttributeType type, bool required)
: name_(name), type_(type), required_(required) {}

Attribute(const std::string &name, AttributeType type, domi::AttrDef default_value)
: name_(name), type_(type), required_(false), default_value_(default_value) {}

const std::string name_;
AttributeType type_;
bool required_;
domi::AttrDef default_value_;
};

OpSchema &Attr(const Attribute &attr);

// Register "optional" attribute with default value.
#define ATTR_SETTER_WITH_DEFAULT_VALUE(TypeName) \
OpSchema &Attr(const std::string &name, AttributeType type, const TypeName &default_value); \
OpSchema &Attr(const std::string &name, AttributeType type, const std::vector<TypeName> &default_value); \
OpSchema &Attr(const std::string &name, AttributeType type, const Tuple<TypeName> &default_value);

ATTR_SETTER_WITH_DEFAULT_VALUE(uint32_t)
ATTR_SETTER_WITH_DEFAULT_VALUE(int64_t)
ATTR_SETTER_WITH_DEFAULT_VALUE(bool)
ATTR_SETTER_WITH_DEFAULT_VALUE(float)
ATTR_SETTER_WITH_DEFAULT_VALUE(std::string)

// Register "required" attribute without default value.
OpSchema &AttrRequired(const std::string &name, AttributeType type);

bool HasDefaultAttr(const std::string &name) const;

const domi::AttrDef &GetDefaultAttr(const std::string &name) const;

// verify op_def
bool Verify(const ge::OpDescPtr op_def) const;

private:
friend class OpSchemaRegistry;

std::string name_;

std::vector<FormalParameter> inputs_;

std::vector<FormalParameter> outputs_;

std::unordered_map<std::string, Attribute> attributes_;
};

class OpSchemaFactory {
public:
// this is a singleton object
static OpSchemaFactory &Instance();

const OpSchema *Get(const std::string &op) const;

private:
OpSchemaFactory() = default;
~OpSchemaFactory() = default;

friend class OpSchemaRegistry;
// the op schema map
std::unordered_map<std::string, OpSchema> op_schema_map_;
};

class FMK_FUNC_HOST_VISIBILITY OpSchemaRegistry {
public:
OpSchemaRegistry(OpSchema &op_schema);
~OpSchemaRegistry() = default;
};

#define DOMI_OP_SCHEMA(name) DOMI_OP_SCHEMA_UNIQ_HELPER(__COUNTER__, name)
#define DOMI_OP_SCHEMA_UNIQ_HELPER(ctr, name) DOMI_OP_SCHEMA_UNIQ(ctr, name)
#define DOMI_OP_SCHEMA_UNIQ(ctr, name) \
static OpSchemaRegistry op_schema_registry##ctr __attribute__((unused)) = OpSchema(#name)
} // namespace ge
#endif // DOMI_COMMON_OP_SCHEMA_H

+ 200
- 0
parser/common/op_def/operator.cc View File

@@ -0,0 +1,200 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "operator.h"
#include <utility>
#include "framework/common/fmk_types.h"
#include "framework/common/util.h"
#include "framework/common/debug/ge_log.h"

using ge::BoolTuple;
using ge::FloatTuple;
using ge::IntTuple;
using ge::StringTuple;
using ge::UintTuple;

namespace ge {
ParserOperator::ParserOperator(const std::string &type) {
type_ = type;
op_schema_ = ge::OpSchemaFactory::Instance().Get(type);
if (op_schema_ == nullptr) {
GELOGW("Cannot find op schema of op type: %s", type.c_str());
}
}

ParserOperator &ParserOperator::Input(const ParserOperator &in_op, uint32_t index) {
if (index == 0) {
inputs_.push_back(in_op.GetName());
} else {
inputs_.push_back(in_op.GetName() + ":" + std::to_string(index));
}
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Name(const std::string &name) {
name_ = name;
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::Type(const std::string &type) {
type_ = type;
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::InputTensorDesc(
const ge::GeTensorDesc &input_tensordesc) {
input_descs_.push_back(input_tensordesc);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::OutputTensorDesc(
const ge::GeTensorDesc &output_tensordesc) {
output_descs_.push_back(output_tensordesc);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector(
std::string key,
std::vector<int32_t> &value) {
domi::AttrDef out;
auto it = op_attrs_.find(key);
if (it != op_attrs_.end()) {
out = it->second.value_;
}
for (auto &v : value) {
out.mutable_list()->add_i(v);
}
(void)op_attrs_.erase(key);
(void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out)));
return *this;
}
FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector(
std::string key,
std::vector<int64_t> &value) {
domi::AttrDef out;
auto it = op_attrs_.find(key);
if (it != op_attrs_.end()) {
out = it->second.value_;
}
for (auto &v : value) {
out.mutable_list()->add_i(v);
}
(void)op_attrs_.erase(key);
(void)op_attrs_.insert(std::make_pair(key, OpAttribute(key, out)));
return *this;
}

ParserOperator &ParserOperator::Attr(const OpAttribute &attr) {
auto it = op_attrs_.find(attr.name_);
if (it != op_attrs_.end()) {
(void)op_attrs_.erase(it);
}
(void)op_attrs_.insert(std::make_pair(attr.name_, attr));
return *this;
}

ParserOperator &ParserOperator::Attr_bt(const std::string &name, const std::string &value) {
domi::AttrDef a;
a.set_bt(value);
Attr(OpAttribute(name, a));
return *this;
}

#define ATTR_SETTER_WITH_SINGLE_VALUE(type, field) \
ParserOperator &ParserOperator::Attr(const std::string &name, const type &value) { \
domi::AttrDef a; \
a.set_##field(value); \
Attr(OpAttribute(name, a)); \
return *this; \
}

#define ATTR_SETTER_WITH_LIST_VALUE(type, field) \
ParserOperator &ParserOperator::Attr(const std::string &name, const std::vector<type> &value) { \
domi::AttrDef a; \
auto attr_list = a.mutable_list(); \
for (size_t i = 0; i < value.size(); ++i) { \
attr_list->add_##field(value[i]); \
} \
Attr(OpAttribute(name, a)); \
return *this; \
} \
ParserOperator &ParserOperator::Attr(const std::string &name, const ge::Tuple<type> &value) { \
domi::AttrDef a; \
auto attr_list = a.mutable_list(); \
for (uint32_t i = 0; i < value.ndim(); ++i) { \
attr_list->add_##field(value[i]); \
} \
Attr(OpAttribute(name, a)); \
return *this; \
}

ATTR_SETTER_WITH_SINGLE_VALUE(int64_t, i)
ATTR_SETTER_WITH_SINGLE_VALUE(bool, b)
ATTR_SETTER_WITH_SINGLE_VALUE(float, f)
ATTR_SETTER_WITH_SINGLE_VALUE(std::string, s)
ATTR_SETTER_WITH_SINGLE_VALUE(uint32_t, i)

ATTR_SETTER_WITH_LIST_VALUE(int64_t, i)
ATTR_SETTER_WITH_LIST_VALUE(bool, b)
ATTR_SETTER_WITH_LIST_VALUE(float, f)
ATTR_SETTER_WITH_LIST_VALUE(std::string, s)
ATTR_SETTER_WITH_LIST_VALUE(uint32_t, i)

#define ATTR_GET_SINGLE_VALUE(type, field, type_name) \
type ParserOperator::Get##type_name##Attr(const std::string &name) const { \
domi::AttrDef single_val; \
auto it = op_attrs_.find(name); \
if (it != op_attrs_.end()) { \
single_val = it->second.value_; \
} else { \
if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \
single_val = op_schema_->GetDefaultAttr(name); \
} \
} \
return single_val.field(); \
}
ATTR_GET_SINGLE_VALUE(uint32_t, i, Uint)
ATTR_GET_SINGLE_VALUE(int64_t, i, Int)
ATTR_GET_SINGLE_VALUE(float, f, Float)
ATTR_GET_SINGLE_VALUE(bool, b, Bool)
ATTR_GET_SINGLE_VALUE(std::string, s, String)

#define ATTR_GET_TUPLE_VALUE(type, field, tuple_type_name) \
tuple_type_name ParserOperator::Get##tuple_type_name##Attr(const std::string &name) const { \
domi::AttrDef value; \
auto it = op_attrs_.find(name); \
if (it != op_attrs_.end()) { \
value = it->second.value_; \
} else { \
if (op_schema_ && op_schema_->HasDefaultAttr(name)) { \
value = op_schema_->GetDefaultAttr(name); \
} \
} \
const auto attr_def = value.list(); \
std::size_t n = attr_def.field##_size(); \
std::vector<type> vec(n); \
for (std::size_t i = 0; i < n; i++) { \
vec[i] = attr_def.field(i); \
} \
return tuple_type_name(vec); \
}

ATTR_GET_TUPLE_VALUE(uint32_t, i, UintTuple)
ATTR_GET_TUPLE_VALUE(int64_t, i, IntTuple)
ATTR_GET_TUPLE_VALUE(float, f, FloatTuple)
ATTR_GET_TUPLE_VALUE(bool, b, BoolTuple)
ATTR_GET_TUPLE_VALUE(std::string, s, StringTuple)
} // namespace domi

+ 117
- 0
parser/common/op_def/operator.h View File

@@ -0,0 +1,117 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef DOMI_COMMON_OP_OPERATOR_H
#define DOMI_COMMON_OP_OPERATOR_H

#include <string>
#include <unordered_map>
#include <vector>
#include "framework/common/fmk_types.h"
#include "common/op_def/op_schema.h"
#include "common/tuple.h"
#include "graph/ge_tensor.h"
#include "proto/om.pb.h"
namespace ge {
struct OpAttribute {
OpAttribute(const std::string &name, const domi::AttrDef &value) : name_(name), value_(value) {}
const std::string name_;
domi::AttrDef value_;
};

class FMK_FUNC_HOST_VISIBILITY ParserOperator {
public:
explicit ParserOperator(const std::string &type);
ParserOperator() { op_schema_ = nullptr; }

virtual ~ParserOperator() { op_schema_ = nullptr; }

ParserOperator &Input(const ParserOperator &in_op, uint32_t index = 0);

ParserOperator &Attr(const OpAttribute &op_attr);

ParserOperator &AttrVector(std::string key, std::vector<int32_t> &value);
ParserOperator &AttrVector(std::string key, std::vector<int64_t> &value);

ParserOperator &Name(const std::string &name);

ParserOperator &Type(const std::string &type);

ParserOperator &InputTensorDesc(const ge::GeTensorDesc &input_tensordesc);

ParserOperator &OutputTensorDesc(const ge::GeTensorDesc &output_tensordesc);

ParserOperator &Attr_bt(const std::string &name, const std::string &value);

// Register "optional" attribute with default value.
#define ATTR_SETTER_WITH_VALUE(TypeName) \
ParserOperator &Attr(const std::string &name, const TypeName &value); \
ParserOperator &Attr(const std::string &name, const std::vector<TypeName> &value); \
ParserOperator &Attr(const std::string &name, const ge::Tuple<TypeName> &value)

ATTR_SETTER_WITH_VALUE(uint32_t);
ATTR_SETTER_WITH_VALUE(int64_t);
ATTR_SETTER_WITH_VALUE(bool);
ATTR_SETTER_WITH_VALUE(float);
ATTR_SETTER_WITH_VALUE(std::string);

const std::string &GetName() const { return name_; }

const std::string &GetType() const { return type_; }

const std::vector<std::string> &GetInputs() const { return inputs_; }

const std::vector<ge::GeTensorDesc> &GetInputTensorDesc() const { return input_descs_; }

const std::vector<ge::GeTensorDesc> &GetOutputTensorDesc() const { return output_descs_; }

const std::unordered_map<std::string, OpAttribute> GetOpAttrs() const { return op_attrs_; }

bool HasAttr(const std::string &name) const { return op_attrs_.find(name) != op_attrs_.end(); }

const ge::OpSchema *GetSchema() const { return op_schema_; }

int64_t GetIntAttr(const std::string &name) const;

uint32_t GetUintAttr(const std::string &name) const;

float GetFloatAttr(const std::string &name) const;

bool GetBoolAttr(const std::string &name) const;

std::string GetStringAttr(const std::string &name) const;

ge::IntTuple GetIntTupleAttr(const std::string &name) const;

ge::UintTuple GetUintTupleAttr(const std::string &name) const;

ge::FloatTuple GetFloatTupleAttr(const std::string &name) const;

ge::BoolTuple GetBoolTupleAttr(const std::string &name) const;

ge::StringTuple GetStringTupleAttr(const std::string &name) const;

private:
const ge::OpSchema *op_schema_;
std::string name_;
std::string type_;
std::vector<std::string> inputs_;
std::unordered_map<std::string, OpAttribute> op_attrs_;
std::vector<ge::GeTensorDesc> input_descs_;
std::vector<ge::GeTensorDesc> output_descs_;
};
} // namespace domi
#endif // DOMI_COMMON_OP_OPERATOR_H

+ 34
- 0
parser/common/op_def/ref_switch_op.cc View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/ref_switch_op.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::RefSwitchOperator() : ParserOperator("RefSwitch") {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator::~RefSwitchOperator() {}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY RefSwitchOperator &RefSwitchOperator::T(ge::DataType t) {
Attr("T", (int64_t)t);
return *this;
}
} // namespace ge AUTO GEN PLEASE DO NOT MODIFY IT

+ 34
- 0
parser/common/op_def/ref_switch_op.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_REF_SWITCH_H_
#define DOMI_OP_REF_SWITCH_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class RefSwitchOperator : public ParserOperator {
public:
RefSwitchOperator();
~RefSwitchOperator();

RefSwitchOperator &Name(const std::string &name);
RefSwitchOperator &T(ge::DataType t);
};
} // namespace ge

#endif // DOMI_OP_REF_SWITCH_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 56
- 0
parser/common/op_def/shape_n_op.cc View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/shape_n_op.h"
#include "graph/debug/ge_attr_define.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {}

FMK_FUNC_HOST_VISIBILITY ShapeNOperator::~ShapeNOperator() {}

FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}

FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::N(int64_t n) {
Attr(SHAPEN_ATTR_N, n);
return *this;
}

FMK_FUNC_HOST_VISIBILITY int64_t ShapeNOperator::GetN() const { return GetIntAttr(SHAPEN_ATTR_N); }

FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::InType(ge::DataType t) {
Attr(SHAPEN_ATTR_IN_TYPE, (int64_t)t);
return *this;
}

FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetInType() const {
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_IN_TYPE);
}

FMK_FUNC_HOST_VISIBILITY ShapeNOperator &ShapeNOperator::OutType(ge::DataType t) {
Attr(SHAPEN_ATTR_OUT_TYPE, (int64_t)t);
return *this;
}

FMK_FUNC_HOST_VISIBILITY ge::DataType ShapeNOperator::GetOutType() const {
return (ge::DataType)GetIntAttr(SHAPEN_ATTR_OUT_TYPE);
}
} // namespace ge

+ 40
- 0
parser/common/op_def/shape_n_op.h View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_SHAPE_N_OP_H_
#define DOMI_OP_SHAPE_N_OP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class ShapeNOperator : public ParserOperator {
public:
ShapeNOperator();
~ShapeNOperator();

ShapeNOperator &Name(const std::string &name);

ShapeNOperator &N(int64_t n);
int64_t GetN() const;
ShapeNOperator &InType(ge::DataType t);
ge::DataType GetInType() const;
ShapeNOperator &OutType(ge::DataType t);
ge::DataType GetOutType() const;
};
} // namespace ge

#endif // DOMI_OP_SHAPE_N_OP_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 37
- 0
parser/common/op_def/var_is_initialized_op_op.cc View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#include "common/op_def/var_is_initialized_op_op.h"
#include <string>
#include <vector>

namespace ge {
VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {}

VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {}

VarIsInitializedOpOperator &VarIsInitializedOpOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}

VarIsInitializedOpOperator &VarIsInitializedOpOperator::VectorAttr(const std::string &key,
std::vector<int64_t> &value) {
Attr(key, value);
return *this;
}
} // namespace ge

+ 34
- 0
parser/common/op_def/var_is_initialized_op_op.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_VARISINITIALIZEDOP_H_
#define DOMI_OP_VARISINITIALIZEDOP_H_
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class VarIsInitializedOpOperator : public ParserOperator {
public:
VarIsInitializedOpOperator();
~VarIsInitializedOpOperator();

VarIsInitializedOpOperator &Name(const std::string &name);
VarIsInitializedOpOperator &VectorAttr(const std::string &key, std::vector<int64_t> &value);
};
} // namespace ge

#endif // DOMI_OP_VARISINITIALIZEDOP_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 57
- 0
parser/common/op_def/variable_op.cc View File

@@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/op_def/variable_op.h"

#include "graph/debug/ge_attr_define.h"

namespace ge {
VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {}

VariableOperator::~VariableOperator() {}

VariableOperator &VariableOperator::Name(const std::string &name) {
ParserOperator::Name(name);
return *this;
}

VariableOperator &VariableOperator::Container(const std::string &container) {
Attr(VAR_ATTR_CONTAINER, container);
return *this;
}

VariableOperator &VariableOperator::SharedName(const std::string &sharedname) {
Attr(VAR_ATTR_SHARED_NAME, sharedname);
return *this;
}

VariableOperator &VariableOperator::Placement(const std::string &placement) {
Attr(ATTR_VARIABLE_PLACEMENT, placement);
return *this;
}

VariableOperator &VariableOperator::SrcType(const int64_t &dtype) {
Attr(VAR_ATTR_DTYPE, dtype);
return *this;
}

VariableOperator &VariableOperator::VarShape(const std::vector<int64_t> &shape_value) {
Attr(VAR_ATTR_SHAPE, shape_value);
return *this;
}

int64_t VariableOperator::GetVarSrcType() const { return GetIntAttr(VAR_ATTR_DTYPE); }
} // namespace ge

+ 46
- 0
parser/common/op_def/variable_op.h View File

@@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// AUTO GEN PLEASE DO NOT MODIFY IT
#ifndef DOMI_OP_VARIABLE_H_
#define DOMI_OP_VARIABLE_H_
#include <vector>
#include "parser/common/op_def/operator.h"
#include "framework/omg/parser/parser_types.h"

namespace ge {
class VariableOperator : public ParserOperator {
public:
VariableOperator();
~VariableOperator();

VariableOperator &Name(const std::string &name);

VariableOperator &Container(const std::string &container);

VariableOperator &SharedName(const std::string &sharedname);

VariableOperator &Placement(const std::string &placement);

VariableOperator &SrcType(const int64_t &dtype);

VariableOperator &VarShape(const std::vector<int64_t> &shape_value);

int64_t GetVarSrcType() const;
};
} // namespace ge

#endif // DOMI_OP_VAR_H_ AUTO GEN PLEASE DO NOT MODIFY IT

+ 159
- 0
parser/common/op_map.cc View File

@@ -0,0 +1,159 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/op_map.h"

#include <map>
#include <string>
#include <vector>

#include "framework/omg/parser/parser_types.h"
#include "register/op_registry.h"

using std::map;
using std::string;
using std::vector;
using namespace ge::parser;

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> caffe_op_map = {
{"Input", DATA},
{"DummyData", DATA},
{"Reshape", RESHAPE},
{"Dropout", DROPOUT},
{"NetOutput", NETOUTPUT},
};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> tensorflow_op_map = {
{"BroadcastGradientArgs", BROADCASTGRADIENTARGS},
{"StopGradient", STOPGRADIENT},
{"ExpandDims", EXPANDDIMS},
{"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE},
{"GuaranteeConst", GUARANTEECONST},
{"BroadcastArgs", BROADCASTARGS},
{"PreventGradient", PREVENTGRADIENT},
{"Empty", EMPTY},
{"Placeholder", DATA},
{"ControlTrigger", CONTROLTRIGGER},
{"_ParallelConcatStart", PARALLELCONCATSTART},
{"Const", CONSTANT},
{"FrameworkOp", FRAMEWORKOP},
{"Reshape", RESHAPE},
{"Squeeze", SQUEEZE},
{"Enter", ENTER},
{"RefEnter", REFENTER},
{"Exit", EXIT},
{"RefExit", REFEXIT},
{"LoopCond", LOOPCOND},
{"NextIteration", NEXTITERATION},
{"RefNextIteration", REFNEXTITERATION},
{"Identity", IDENTITY},
{"IdentityN", IDENTITYN},
{"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT},
{"Size", SIZE},
{"Shape", SHAPE},
{"ShapeN", SHAPEN},
{"Fill", FILL},
{"Rank", RANK},
{"Merge", MERGE},
{"RefMerge", REFMERGE},
{"Switch", SWITCH},
{"RefSwitch", REFSWITCH},
{"LayerNorm", LAYERNORM},
{"RNN", RNN},
{"_Arg", ARG},
{"_Retval", FRAMEWORKOP},
{"Bitcast", BITCAST},
{"Snapshot", SNAPSHOT},
};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, string> tensorflow_train_op_map = {
{"BroadcastGradientArgs", BROADCASTGRADIENTARGS},
{"StopGradient", STOPGRADIENT},
{"ExpandDims", EXPANDDIMS},
{"DestroyTemporaryVariable", DESTROYTEMPORARYVARIABLE},
{"TemporaryVariable", TEMPORARYVARIABLE},
{"GuaranteeConst", GUARANTEECONST},
{"BroadcastArgs", BROADCASTARGS},
{"PreventGradient", PREVENTGRADIENT},
{"Empty", EMPTY},
{"ControlTrigger", CONTROLTRIGGER},
{"_Arg", ARG},
{"_ParallelConcatStart", PARALLELCONCATSTART},
{"Const", CONSTANTOP},
{"VariableV2", VARIABLE},
{"VarHandleOp", VARHANDLEOP},
{"VarIsInitializedOp", VARISINITIALIZEDOP},
{"IsVariableInitialized", ISVARIABLEINITIALIZED},
{"ReadVariableOp", READVARIABLEOP},
{"Reshape", RESHAPE},
{"Squeeze", SQUEEZE},
{"NoOp", NOOP},
{"Enter", ENTER},
{"RefEnter", REFENTER},
{"Exit", EXIT},
{"RefExit", REFEXIT},
{"LoopCond", LOOPCOND},
{"NextIteration", NEXTITERATION},
{"RefNextIteration", REFNEXTITERATION},
{"Identity", IDENTITY},
{"IdentityN", IDENTITYN},
{"PlaceholderWithDefault", PLACEHOLDERWITHDEFAULT},
{"Size", SIZE},
{"Shape", SHAPE},
{"ShapeN", SHAPEN},
{"Rank", RANK},
{"Merge", MERGE},
{"Switch", SWITCH},
{"LayerNorm", LAYERNORM},
{"LayerNormGrad", LAYERNORMGRAD},
{"Dropout", DROPOUT},
{"Bitcast", BITCAST},
{"Snapshot", SNAPSHOT},
};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, int32_t> op_output_tensor_num = {
{SSDDETECTIONOUTPUT, 3},
{REFINEDETDETECTIONOUTPUT, 3},
{FSRDETECTIONOUTPUT, 2},
{FASTERRCNNFIRSTSTAGEPOSTPROCESSOR, 4},
{FASTERRCNNSECONDSTAGEPOSTPROCESSOR, 4},
{YOLODETECTIONOUTPUT, 2},
{FASTRCNNPREDICTIONS, 4},
{RPNPROPOSALS, 3},
{MAXPOOLWITHARGMAX, 2},
{REGION, 3},
{TOPKV2, 2},
{LogTimeStamp, 0},
/* training op */
{MAXPOOLWITHARGMAX, 2},
{FUSEDBATCHNORM, 5},
{FUSEDBATCHNORMGRAD, 3},
{SHAPEN, 0},
{SSDPOSTPROCESSOR, 4},
{LAYERNORM, 3},
{LAYERNORMGRAD, 3},
{SPARSESOFTMAXCROSSENTROPYWITHLOGITS, 2},
};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<string> local_framework_op_vec = {
"TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2",
"IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY vector<string> is_dataset_op_vec = {
"TensorDataset", "QueueDataset", "DeviceQueueDataset", "ParallelMapDataset", "BatchDatasetV2",
"IteratorV2", "MakeIterator", "IteratorGetNext", "FilterDataset", "MapAndBatchDatasetV2"};
} // namespace ge

+ 45
- 0
parser/common/op_map.h View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_COMMON_OP_MAP_H_
#define GE_COMMON_OP_MAP_H_

#include <map>
#include <string>
#include <vector>

/*lint -e1073*/
namespace ge {
// the operator type mapping table of caffe and mindspore
extern std::map<std::string, std::string> caffe_op_map;

// the operator type mapping table of TensorFlow and mindspore
extern std::map<std::string, std::string> tensorflow_op_map;

// the network training operator type mapping table of TensorFlow and mindspore
extern std::map<std::string, std::string> tensorflow_train_op_map;

// local framework op vec
extern std::vector<std::string> local_framework_op_vec;

// dataset op vec
extern std::vector<std::string> is_dataset_op_vec;

// output tensor num
extern std::map<std::string, int32_t> op_output_tensor_num;
} // namespace ge
/*lint +e1073*/
#endif // GE_COMMON_OP_MAP_H_

+ 117
- 0
parser/common/op_parser_factory.cc View File

@@ -0,0 +1,117 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/op_parser_factory.h"
#include "common/debug/log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/type_utils.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry *CustomParserAdapterRegistry::Instance() {
static CustomParserAdapterRegistry instance;
return &instance;
}

FMK_FUNC_HOST_VISIBILITY void CustomParserAdapterRegistry::Register(const domi::FrameworkType framework,
CustomParserAdapterRegistry::CREATOR_FUN fun) {
if (funcs_.find(framework) != funcs_.end()) {
GELOGW("Framework type %s has already registed.", TypeUtils::FmkTypeToSerialString(framework).c_str());
return;
}
funcs_[framework] = fun;
GELOGI("Register %s custom parser adapter success.", TypeUtils::FmkTypeToSerialString(framework).c_str());
return;
}
FMK_FUNC_HOST_VISIBILITY CustomParserAdapterRegistry::CREATOR_FUN
CustomParserAdapterRegistry::GetCreateFunc(const domi::FrameworkType framework) {
if (funcs_.find(framework) == funcs_.end()) {
GELOGW("Framework type %s has not registed.", TypeUtils::FmkTypeToSerialString(framework).c_str());
return nullptr;
}
return funcs_[framework];
}

FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParserFactory> OpParserFactory::Instance(
const domi::FrameworkType framework) {
// Each framework corresponds to one op parser factory,
// If instances are static data members of opparserfactory, the order of their construction is uncertain.
// Instances cannot be a member of a class because they may be used before initialization, resulting in a run error.
static std::map<domi::FrameworkType, std::shared_ptr<OpParserFactory>> instances;

auto iter = instances.find(framework);
if (iter == instances.end()) {
std::shared_ptr<OpParserFactory> instance(new (std::nothrow) OpParserFactory());
if (instance == nullptr) {
GELOGE(INTERNAL_ERROR, "Create op parser factory failed.");
return nullptr;
}
instances[framework] = instance;
return instance;
}

return iter->second;
}

FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateOpParser(const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser.
auto iter = op_parser_creator_map_.find(op_type);
if (iter != op_parser_creator_map_.end()) {
return iter->second();
}

GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported type: %s", op_type.c_str());
return nullptr;
}

FMK_FUNC_HOST_VISIBILITY std::shared_ptr<OpParser> OpParserFactory::CreateFusionOpParser(const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create OpParser.
auto iter = fusion_op_parser_creator_map_.find(op_type);
if (iter != fusion_op_parser_creator_map_.end()) {
return iter->second();
}

GELOGE(FAILED, "OpParserFactory::CreateOpParser: Not supported fusion op type: %s", op_type.c_str());
return nullptr;
}

// This function is only called within the constructor of the global opparserregisterar object,
// and does not involve concurrency, so there is no need to lock it
FMK_FUNC_HOST_VISIBILITY void OpParserFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun,
bool is_fusion_op) {
std::map<std::string, CREATOR_FUN> *op_parser_creator_map = &op_parser_creator_map_;
if (is_fusion_op) {
op_parser_creator_map = &fusion_op_parser_creator_map_;
}

GELOGD("OpParserFactory::RegisterCreator: op type:%s, is_fusion_op:%d.", type.c_str(), is_fusion_op);
(*op_parser_creator_map)[type] = fun;
}

FMK_FUNC_HOST_VISIBILITY bool OpParserFactory::OpParserIsRegistered(const std::string &op_type, bool is_fusion_op) {
if (is_fusion_op) {
auto iter = fusion_op_parser_creator_map_.find(op_type);
if (iter != fusion_op_parser_creator_map_.end()) {
return true;
}
} else {
auto iter = op_parser_creator_map_.find(op_type);
if (iter != op_parser_creator_map_.end()) {
return true;
}
}
return false;
}
} // namespace ge

+ 198
- 0
parser/common/op_parser_factory.h View File

@@ -0,0 +1,198 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_OP_PARSER_FACTORY_H_
#define PARSER_COMMON_OP_PARSER_FACTORY_H_

#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "common/ge/ge_util.h"
#include "framework/omg/parser/parser_types.h"
#include "framework/common/debug/ge_log.h"
#include "omg/omg_inner_types.h"
#include "external/register/register.h"

using domi::CAFFE;

namespace ge {
class OpParser;

/**
* @ingroup domi_omg
* @brief Used to create OpParser
*
*/
class OpParserFactory {
public:
/**
* @ingroup domi_omg
* @brief Returns the OpParserFactory instance corresponding to the Framework
* @return OpParserFactory object
*/
static std::shared_ptr<OpParserFactory> Instance(const domi::FrameworkType framework);

/**
* @ingroup domi_omg
* @brief Create OpParser based on input type
* @param [in] op_type Op type
* @return Created OpParser
*/
std::shared_ptr<OpParser> CreateOpParser(const std::string &op_type);

/**
* @ingroup domi_omg
* @brief Create fusion OpParser based on input type
* @param [in] op_type Op type
* @return Created OpParser
*/
std::shared_ptr<OpParser> CreateFusionOpParser(const std::string &op_type);

// The Factory instance is automatically released by shared_ptr.
// The shared_ptr internally calls the destructor indirectly.
// If the destructor is not public, it will generate a compilation error.
// Another solution is to specify the deleter for shared_ptr, and set the deleter as a friend of the current class.
// But this method is more complicated to implement.
~OpParserFactory() {}

bool OpParserIsRegistered(const std::string &op_type, bool is_fusion_op = false);

protected:
/**
* @ingroup domi_omg
* @brief OpParser creation function
* @return Created OpParser
*/
// typedef shared_ptr<OpParser> (*CREATOR_FUN)(void);
using CREATOR_FUN = std::function<std::shared_ptr<OpParser>(void)>;

/**
* @ingroup domi_omg
* @brief Factory instances can only be created automatically, not new methods, so the constructor is not public.
*/
OpParserFactory() {}

/**
* @ingroup domi_omg
* @brief Register creation function
* @param [in] type Op type
* @param [in] fun OpParser creation function
*/
void RegisterCreator(const std::string &type, CREATOR_FUN fun, bool is_fusion_op = false);

private:
/**
* @ingroup domi_omg
* @brief Each Op corresponds to a Creator function
*/
std::map<std::string, CREATOR_FUN> op_parser_creator_map_; // lint !e1073
std::map<std::string, CREATOR_FUN> fusion_op_parser_creator_map_;

friend class OpParserRegisterar;
friend class domi::OpRegistrationData;
friend class OpRegistrationTbe;
};

/**
* @ingroup domi_omg
* @brief For registering Creator functions for different types of Op
*
*/
class OpParserRegisterar {
public:
/**
* @ingroup domi_omg
* @brief Constructor
* @param [in] framework Framework type
* @param [in] op_type Op type
* @param [in] fun Creator function corresponding to Op
*/
OpParserRegisterar(const domi::FrameworkType framework, const std::string &op_type, OpParserFactory::CREATOR_FUN fun,
bool is_fusion_op = false) {
OpParserFactory::Instance(framework)->RegisterCreator(op_type, fun, is_fusion_op);
}
~OpParserRegisterar() {}
};

// Used to save the functions created by the xxxCustomParserAdapter class
class CustomParserAdapterRegistry {
public:
static CustomParserAdapterRegistry *Instance();
using CREATOR_FUN = std::function<std::shared_ptr<OpParser>(void)>;
void Register(const domi::FrameworkType framework, CREATOR_FUN fun);
CREATOR_FUN GetCreateFunc(const domi::FrameworkType framework);

private:
map<domi::FrameworkType, CREATOR_FUN> funcs_;

friend class CustomParserAdapterRegistrar;
};

// Register Creator function for the custom custom operator ParserAdapter
class CustomParserAdapterRegistrar {
public:
CustomParserAdapterRegistrar(const domi::FrameworkType framework, CustomParserAdapterRegistry::CREATOR_FUN fun) {
CustomParserAdapterRegistry::Instance()->Register(framework, fun);
}
~CustomParserAdapterRegistrar() {}
};

/**
* @ingroup domi_omg
* @brief OpParser Registration Macro
* @param [in] framework Framework type
* @param [in] op_type Op type
* @param [in] clazz OpParser implementation class
*/
#define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Op_Parser() { \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
return std::shared_ptr<OpParser>(ptr); \
} \
ge::OpParserRegisterar g_##framework##_##op_type##_Op_Parser_Creator(framework, op_type, \
Creator_##framework##_##op_type##_Op_Parser)

#define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Fusion_Op_Parser() { \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
return std::shared_ptr<OpParser>(ptr); \
} \
OpParserRegisterar g_##framework##_##op_type##_Fusion_Op_Parser_Creator( \
framework, op_type, Creator_##framework##_##op_type##_Fusion_Op_Parser, true)

/// @brief xxxCustomParserAdapter Registration Macro
/// @param [in] framework Framework type
/// @param [in] clazz CaffeCustomParserAdapter adaptation class
#define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \
std::shared_ptr<OpParser> Creator_##framework##_Op_Parser_Adapter() { \
std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
return std::shared_ptr<OpParser>(ptr); \
} \
CustomParserAdapterRegistrar g_##framework##_Op_Parser_Creator(framework, Creator_##framework##_Op_Parser_Adapter)
} // namespace ge
#endif // PARSER_COMMON_OP_PARSER_FACTORY_H_

+ 76
- 0
parser/common/parser_api.cc View File

@@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "framework/omg/parser/parser_api.h"
#include "common/debug/log.h"

#include "tbe_plugin_loader.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "external/ge/ge_api_types.h"

namespace ge {
static bool parser_initialized = false;
// Initialize PARSER, load custom op plugin
// options will be used later for parser decoupling
Status ParserInitialize(const std::map<std::string, std::string> &options) {
GELOGT(TRACE_INIT, "ParserInitialize start");
// check init status
if (parser_initialized) {
GELOGW("ParserInitialize is called more than once");
return SUCCESS;
}

// load custom op plugin
TBEPluginLoader::Instance().LoadPluginSo(options);

std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas;
GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size());
for (OpRegistrationData &reg_data : registrationDatas) {
(void)OpRegistrationTbe::Instance()->Finalize(reg_data, true);
}

auto iter = options.find(ge::OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES);
if (iter != options.end()) {
ge::GetParserContext().enable_scope_fusion_passes = iter->second;
}

// set init status
if (!parser_initialized) {
// Initialize success, first time calling initialize
parser_initialized = true;
}

GELOGT(TRACE_STOP, "ParserInitialize finished");
return SUCCESS;
}

Status ParserFinalize() {
GELOGT(TRACE_INIT, "ParserFinalize start");
// check init status
if (!parser_initialized) {
GELOGW("ParserFinalize is called before ParserInitialize");
return SUCCESS;
}

GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize());
if (parser_initialized) {
parser_initialized = false;
}
return SUCCESS;
}
} // namespace ge

+ 81
- 0
parser/common/parser_factory.cc View File

@@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "omg/parser/parser_factory.h"
#include "common/debug/log.h"
#include "framework/common/debug/ge_log.h"

namespace domi {
FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() {
static WeightsParserFactory instance;
return &instance;
}

std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const domi::FrameworkType type) {
std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
return iter->second();
}

GELOGE(FAILED, "WeightsParserFactory::CreateWeightsParser: Not supported Type: %d", type);
return nullptr;
}

FMK_FUNC_HOST_VISIBILITY void WeightsParserFactory::RegisterCreator(const domi::FrameworkType type,
WEIGHTS_PARSER_CREATOR_FUN fun) {
std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
GELOGW("WeightsParserFactory::RegisterCreator: %d creator already exist", type);
return;
}

creator_map_[type] = fun;
}

WeightsParserFactory::~WeightsParserFactory() {
creator_map_.clear();
}

FMK_FUNC_HOST_VISIBILITY ModelParserFactory *ModelParserFactory::Instance() {
static ModelParserFactory instance;
return &instance;
}

std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::FrameworkType type) {
std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
return iter->second();
}

GELOGE(FAILED, "ModelParserFactory::CreateModelParser: Not supported Type: %d", type);
return nullptr;
}

FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::FrameworkType type,
MODEL_PARSER_CREATOR_FUN fun) {
std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type);
if (iter != creator_map_.end()) {
GELOGW("ModelParserFactory::RegisterCreator: %d creator already exist", type);
return;
}

creator_map_[type] = fun;
}

ModelParserFactory::~ModelParserFactory() {
creator_map_.clear();
}
} // namespace domi

+ 1270
- 0
parser/common/parser_fp16_t.cc
File diff suppressed because it is too large
View File


+ 653
- 0
parser/common/parser_fp16_t.h View File

@@ -0,0 +1,653 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_FP16_T_H_
#define PARSER_COMMON_FP16_T_H_

#include <algorithm>
#include <cmath>
#include <cstdint>

namespace ge {
namespace parser {
using DimIndex = enum {
kDim0 = 0,
kDim1,
kDim2,
kDim3,
kDim4,
kDim5,
kDim6,
kDim7,
kDim8,
kDim9,
kDim10,
kDim11,
kDim12,
kDim13,
kDim14,
kDim15,
kDim16,
};

using BitShift = enum {
kBitShift2 = 2,
kBitShift3 = 3,
kBitShift4 = 4,
kBitShift5 = 5,
kBitShift6 = 6,
kBitShift7 = 7,
kBitShift8 = 8,
kBitShift9 = 9,
kBitShift10 = 10,
kBitShift11 = 11,
kBitShift12 = 12,
kBitShift13 = 13,
kBitShift14 = 14,
kBitShift15 = 15,
kBitShift16 = 16,
kBitShift20 = 20,
kBitShift24 = 24,
kBitShift27 = 27,
kBitShift28 = 28,
kBitShift31 = 31,
kBitShift32 = 32,
kBitShift36 = 36,
kBitShift40 = 40,
kBitShift44 = 44,
kBitShift48 = 48,
kBitShift52 = 52,
kBitShift56 = 56,
kBitShift59 = 59,
kBitShift60 = 60,
kBitShift63 = 63,
kBitShift64 = 64,
kBitShift128 = 128,
kBitShift255 = 255,
kBitShift256 = 256,
kBitShift512 = 512,
kBitShift768 = 768,
kBitShift784 = 784,
kBitShift1020 = 1020,
kBitShift1024 = 1024,
kBitShift3136 = 3136,
kBitShift4096 = 4096,
kBitShift6144 = 6144,
kBitShift10240 = 10240,
kBitShift65536 = 65536
};
/// @ingroup fp16 basic parameter
/// @brief fp16 exponent bias
constexpr uint16_t kFp16ExpBias = 15;
/// @ingroup fp16 basic parameter
/// @brief the exponent bit length of fp16 is 5
constexpr uint16_t kFp16ExpLen = 5;
/// @ingroup fp16 basic parameter
/// @brief the mantissa bit length of fp16 is 10
constexpr uint16_t kFp16ManLen = 10;
/// @ingroup fp16 basic parameter
/// @brief bit index of sign in fp16
constexpr uint16_t kFp16SignIndex = 15;
/// @ingroup fp16 basic parameter
/// @brief sign mask of fp16 (1 00000 00000 00000)
constexpr uint16_t kFp16SignMask = 0x8000;
/// @ingroup fp16 basic parameter
/// @brief exponent mask of fp16 ( 11111 00000 00000)
constexpr uint16_t kFp16ExpMask = 0x7C00;
/// @ingroup fp16 basic parameter
/// @brief mantissa mask of fp16 ( 11111 11111)
constexpr uint16_t kFp16ManMask = 0x03FF;
/// @ingroup fp16 basic parameter
/// @brief hide bit of mantissa of fp16( 1 00000 00000)
constexpr uint16_t kFp16ManHideBit = 0x0400;
/// @ingroup fp16 basic parameter
/// @brief maximum value (0111 1011 1111 1111)
constexpr uint16_t kFp16Max = 0x7BFF;
/// @ingroup fp16 basic parameter
/// @brief minimum value (1111 1011 1111 1111)
constexpr uint16_t kFp16Min = 0xFBFF;
/// @ingroup fp16 basic parameter
/// @brief absolute maximum value (0111 1111 1111 1111)
constexpr uint16_t kFp16AbsMax = 0x7FFF;
/// @ingroup fp16 basic parameter
/// @brief maximum exponent value of fp16 is 15(11111)
constexpr uint16_t kFp16MaxExp = 0x001F;
/// @ingroup fp16 basic parameter
/// @brief maximum valid exponent value of fp16 is 14(11110)
constexpr uint16_t kFp16MaxValidExp = 0x001E;
/// @ingroup fp16 basic parameter
/// @brief maximum mantissa value of fp16(11111 11111)
constexpr uint16_t kFp16MaxMan = 0x03FF;
/// @ingroup fp16 basic parameter
/// @brief absolute minimum normal value of fp16
/// (E=1,M=0 D=2^(-14)=0.00006103515625)
constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14);
/// @ingroup fp16 basic operator
/// @brief get sign of fp16
#define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1)
/// @ingroup fp16 basic operator
/// @brief get exponent of fp16
#define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp)
/// @ingroup fp16 basic operator
/// @brief get mantissa of fp16
#define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400))
/// @ingroup fp16 basic operator
/// @brief constructor of fp16 from sign exponent and mantissa
#define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is zero
#define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0)
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is a denormalized value
#define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is infinite
#define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask)
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is NaN
#define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask))
/// @ingroup fp16 special value judgment
/// @brief whether a fp16 is invalid
#define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask)
/// @ingroup fp32 basic parameter
/// @brief fp32 exponent bias
constexpr uint16_t kFp32ExpBias = 127;
/// @ingroup fp32 basic parameter
/// @brief the exponent bit length of float/fp32 is 8
constexpr uint16_t kFp32ExpLen = 8;
/// @ingroup fp32 basic parameter
/// @brief the mantissa bit length of float/fp32 is 23
constexpr uint16_t kFp32ManLen = 23;
/// @ingroup fp32 basic parameter
/// @brief bit index of sign in float/fp32
constexpr uint16_t kFp32SignIndex = 31;
/// @ingroup fp32 basic parameter
/// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000)
constexpr uint32_t kFp32SignMask = 0x80000000u;
/// @ingroup fp32 basic parameter
/// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000)
constexpr uint32_t kFp32ExpMask = 0x7F800000u;
/// @ingroup fp32 basic parameter
/// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111)
constexpr uint32_t kFp32ManMask = 0x007FFFFFu;
/// @ingroup fp32 basic parameter
/// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000)
constexpr uint32_t kFp32ManHideBit = 0x00800000u;
/// @ingroup fp32 basic parameter
/// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111)
constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu;
/// @ingroup fp32 basic parameter
/// @brief maximum exponent value of fp32 is 255(1111 1111)
constexpr uint32_t kFp32MaxExp = 0xFF;
/// @ingroup fp32 basic parameter
/// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111)
constexpr uint32_t kFp32MaxMan = 0x7FFFFF;
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is NaN
#define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask))
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is infinite
#define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask)))
/// @ingroup fp32 special value judgment
/// @brief whether a fp32 is a denormalized value
#define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0))
/// @ingroup fp32 basic operator
/// @brief get sign of fp32
#define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1)
/// @ingroup fp32 basic operator
/// @brief get exponent of fp16
#define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen)
/// @ingroup fp32 basic operator
/// @brief get mantissa of fp16
#define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit))
/// @ingroup fp32 basic operator
/// @brief constructor of fp32 from sign exponent and mantissa
#define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan))
/// @ingroup fp64 basic parameter
/// @brief fp64 exponent bias
constexpr uint16_t kFp64ExpBias = 1023;
/// @ingroup fp64 basic parameter
/// @brief the exponent bit length of double/fp64 is 11
constexpr uint16_t kFp64ExpLen = 11;
/// @ingroup fp64 basic parameter
/// @brief the mantissa bit length of double/fp64 is 52
constexpr uint16_t kFp64ManLen = 52;
/// @ingroup fp64 basic parameter
/// @brief bit index of sign in double/fp64 is 63
constexpr uint16_t kFp64SignIndex = 63;
/// @ingroup fp64 basic parameter
/// @brief sign mask of fp64 (1 000 (total 63bits 0))
constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief exponent mask of fp64 (0 1 11111 11111 0000?-?-(total 52bits 0))
constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief mantissa mask of fp64 ( 1111?-?-(total 52bits 1))
constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu;
/// @ingroup fp64 basic parameter
/// @brief hide bit of mantissa of fp64 ( 1 0000?-?-(total 52bits 0))
constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu;
/// @ingroup fp64 basic parameter
/// @brief absolute maximum value (0 111?-?-(total 63bits 1))
constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu;
/// @ingroup fp64 basic parameter
/// @brief maximum exponent value of fp64 is 2047(1 11111 11111)
constexpr uint64_t kFp64MaxExp = 0x07FF;
/// @ingroup fp64 basic parameter
/// @brief maximum mantissa value of fp64 (111?-?-(total 52bits 1))
constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu;
/// @ingroup fp64 special value judgment
/// @brief whether a fp64 is NaN
#define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask))
/// @ingroup fp64 special value judgment
/// @brief whether a fp64 is infinite
#define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask)))
/// @ingroup integer special value judgment
/// @brief maximum positive value of int8_t (0111 1111)
constexpr int8_t kInt8Max = 0x7F;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 8 bits length (1111 111)
constexpr uint8_t kBitLen8Max = 0xFF;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int16_t (0111 1111 1111 1111)
constexpr int16_t kInt16Max = 0x7FFF;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111)
constexpr uint16_t kBitLen16Max = 0xFFFF;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111)
constexpr int32_t kInt32Max = 0x7FFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum positive value of int64_t
/// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)
constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu;
/// @ingroup integer special value judgment
/// @brief maximum value of a data with 64 bits length
/// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111)
constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu;

/// @ingroup fp16_t enum
/// @brief round mode of last valid digital
enum TagFp16RoundMode {
kRoundToNearest = 0, // < round to nearest even
kRoundByTruncated, // < round by truncated
kRoundModeReserved,
};

/// @ingroup fp16_t
/// @brief Half precision float
/// bit15: 1 bit SIGN +---+-----+------------+
/// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM|
/// bit0-9: 10bit MAN +---+-----+------------+
using fp16_t = struct TagFp16 {
uint16_t val;

public:
/// @ingroup fp16_t constructor
/// @brief Constructor without any param(default constructor)
TagFp16(void) { val = 0x0u; }

/// @ingroup fp16_t constructor
/// @brief Constructor with an uint16_t value
TagFp16(const uint16_t &ui_val) : val(ui_val) {}

/// @ingroup fp16_t constructor
/// @brief Constructor with a fp16_t object(copy constructor)
TagFp16(const TagFp16 &fp) : val(fp.val) {}

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be added
/// @brief Override addition operator to performing fp16_t addition
/// @return Return fp16_t result of adding this and fp
TagFp16 operator+(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be subtracted
/// @brief Override addition operator to performing fp16_t subtraction
/// @return Return fp16_t result of subtraction fp from this
TagFp16 operator-(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be multiplied
/// @brief Override multiplication operator to performing fp16_t multiplication
/// @return Return fp16_t result of multiplying this and fp
TagFp16 operator*(const TagFp16 fp);

/// @ingroup fp16_t math operator divided
/// @param [in] fp fp16_t object to be divided
/// @brief Override division operator to performing fp16_t division
/// @return Return fp16_t result of division this by fp
TagFp16 operator/(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be added
/// @brief Override addition operator to performing fp16_t addition
/// @return Return fp16_t result of adding this and fp
TagFp16 operator+=(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be subtracted
/// @brief Override addition operator to performing fp16_t subtraction
/// @return Return fp16_t result of subtraction fp from this
TagFp16 operator-=(const TagFp16 fp);

/// @ingroup fp16_t math operator
/// @param [in] fp fp16_t object to be multiplied
/// @brief Override multiplication operator to performing fp16_t multiplication
/// @return Return fp16_t result of multiplying this and fp
TagFp16 operator*=(const TagFp16 fp);

/// @ingroup fp16_t math operator divided
/// @param [in] fp fp16_t object to be divided
/// @brief Override division operator to performing fp16_t division
/// @return Return fp16_t result of division this by fp
TagFp16 operator/=(const TagFp16 fp);

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t if-equal comparison
/// @return Return boolean result of if-equal comparison of this and fp.
bool operator==(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t not-equal comparison
/// @return Return boolean result of not-equal comparison of this and fp.
bool operator!=(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t greater-than comparison
/// @return Return boolean result of greater-than comparison of this and fp.
bool operator>(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t greater-equal comparison
/// @return Return boolean result of greater-equal comparison of this and fp.
bool operator>=(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t less-than comparison
/// @return Return boolean result of less-than comparison of this and fp.
bool operator<(const TagFp16 &fp) const;

/// @ingroup fp16_t math compare operator
/// @param [in] fp fp16_t object to be compared
/// @brief Override basic comparison operator to performing fp16_t less-equal comparison
/// @return Return boolean result of less-equal comparison of this and fp.
bool operator<=(const TagFp16 &fp) const;

/// @ingroup fp16_t math evaluation operator
/// @param [in] fp fp16_t object to be copy to fp16_t
/// @brief Override basic evaluation operator to copy fp16_t to a new fp16_t
/// @return Return fp16_t result from fp
TagFp16 &operator=(const TagFp16 &fp);

/// @ingroup fp16_t math evaluation operator
/// @param [in] f_val float object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert float to fp16_t
/// @return Return fp16_t result from f_val
TagFp16 &operator=(const float &f_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] d_val double object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert double to fp16_t
/// @return Return fp16_t result from d_val
TagFp16 &operator=(const double &d_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val float object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert float to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int8_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint8_t object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert uint8_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint8_t &ui_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int16_t object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert int16_t to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int16_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint16_t object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert uint16_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint16_t &ui_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] i_val int32_t object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert int32_t to fp16_t
/// @return Return fp16_t result from i_val
TagFp16 &operator=(const int32_t &i_val);

/// @ingroup fp16_t math evaluation operator
/// @param [in] ui_val uint32_t object to be converted to fp16_t
/// @brief Override basic evaluation operator to convert uint32_t to fp16_t
/// @return Return fp16_t result from ui_val
TagFp16 &operator=(const uint32_t &ui_val);

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
operator float() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
operator double() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int8_t
/// @return Return int8_t value of fp16_t
operator int8_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
operator uint8_t() const;

/// @ingroup fp16_t conversion
/// @brief Override convert operator to convert fp16_t to int16_t
/// @return Return int16_t value of fp16_t
operator int16_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
operator uint16_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int32_t
/// @return Return int32_t value of fp16_t
operator int32_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
operator uint32_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to int64_t
/// @return Return int64_t value of fp16_t
operator int64_t() const;

/// @ingroup fp16_t math conversion
/// @brief Override convert operator to convert fp16_t to uint64_t
/// @return Return uint64_t value of fp16_t
operator uint64_t() const;

/// @ingroup fp16_t judgment method
/// @param [in] fp fp16_t object to be judgement
/// @brief whether a fp16_t is inifinite
/// @return Returns 1:+INF -1:-INF 0:not INF
int IsInf();

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to float/fp32
/// @return Return float/fp32 value of fp16_t
float ToFloat() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to double/fp64
/// @return Return double/fp64 value of fp16_t
double ToDouble() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int8_t
/// @return Return int8_t value of fp16_t
int8_t ToInt8() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint8_t
/// @return Return uint8_t value of fp16_t
uint8_t ToUInt8() const;

/// @ingroup fp16_t conversion
/// @brief Convert fp16_t to int16_t
/// @return Return int16_t value of fp16_t
int16_t ToInt16() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint16_t
/// @return Return uint16_t value of fp16_t
uint16_t ToUInt16() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to int32_t
/// @return Return int32_t value of fp16_t
int32_t ToInt32() const;

/// @ingroup fp16_t math conversion
/// @brief Convert fp16_t to uint32_t
/// @return Return uint32_t value of fp16_t
uint32_t ToUInt32() const;
};

/// @ingroup fp16_t public method
/// @param [in] val signature is negative
/// @param [in|out] s sign of fp16_t object
/// @param [in|out] e exponent of fp16_t object
/// @param [in|out] m mantissa of fp16_t object
/// @brief Extract the sign, exponent and mantissa of a fp16_t object
void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m);

/// @ingroup fp16_t public method
/// @param [in] negative sign is negative
/// @param [in|out] man mantissa to be reverse
/// @brief Calculate a mantissa's complement (add ont to it's radix-minus-one complement)
/// @return Return complement of man
template<typename T>
void ReverseMan(bool negative, T &man) {
if (negative) {
man = (~(man)) + 1;
}
}

/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one fp16_t/float number
/// @param [in] m_a mantissa of one fp16_t/float number
/// @param [in] e_b exponent of another fp16_t/float number
/// @param [in] m_b mantissa of another fp16_t/float number
/// @brief choose mantissa to be shift right whoes exponent is less than another one
/// @return Return mantissawhoes exponent is less than another one
template<typename T>
T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) {
return (e_a > e_b) ? m_b : m_a;
}

/// @ingroup fp16_t public method
/// @param [in] man mantissa to be operate
/// @param [in] shift right shift bits
/// @brief right shift a mantissa
/// @return Return right-shift mantissa
template<typename T>
T RightShift(T man, int16_t shift) {
int bits = sizeof(T) * 8; // one byte have 8 bits
T mask = (((T) 1u) << ((unsigned int) (bits - 1)));
for (int i = 0; i < shift; i++) {
man = ((man & mask) | (man >> 1));
}
return man;
}

/// @ingroup fp16_t public method
/// @param [in] e_a exponent of one temp fp16_t number
/// @param [in] m_a mantissa of one temp fp16_t number
/// @param [in] e_b exponent of another temp fp16_t number
/// @param [in] m_b mantissa of another temp fp16_t number
/// @brief Get mantissa sum of two temp fp16_t numbers, T support types: uint16_t/uint32_t/uint64_t
/// @return Return mantissa sum
template<typename T>
T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) {
T sum = 0;
if (e_a != e_b) {
T m_tmp = 0;
int16_t e_tmp = std::abs(e_a - e_b);
if (e_a > e_b) {
m_tmp = m_b;
m_tmp = RightShift(m_tmp, e_tmp);
sum = m_a + m_tmp;
} else {
m_tmp = m_a;
m_tmp = RightShift(m_tmp, e_tmp);
sum = m_tmp + m_b;
}
} else {
sum = m_a + m_b;
}
return sum;
}

/// @ingroup fp16_t public method
/// @param [in] bit0 whether the last preserved bit is 1 before round
/// @param [in] bit1 whether the abbreviation's highest bit is 1
/// @param [in] bitLeft whether the abbreviation's bits which not contain highest bit grater than 0
/// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t
/// @param [in] shift abbreviation bits
/// @brief Round fp16_t or float mantissa to nearest value
/// @return Returns true if round 1,otherwise false;
template<typename T>
T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) {
man = (man >> shift) + ((bit1 && (bitLeft || bit0)) ? 1 : 0);
return man;
}

/// @ingroup fp16_t public method
/// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t
/// @brief Get bit length of a uint32_t number
/// @return Return bit length of man
template<typename T>
int16_t GetManBitLength(T man) {
int16_t len = 0;
while (man) {
man >>= 1;
len++;
}
return len;
}
} // namespace parser
} // namespace ge
#endif // GE_PARSER_COMMON_FP16_T_H_

+ 24
- 0
parser/common/parser_inner_ctx.cc View File

@@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "framework/omg/parser/parser_inner_ctx.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserContext &GetParserContext() {
static ParserContext context;
return context;
}
} // namespace domi

+ 494
- 0
parser/common/parser_types.cc View File

@@ -0,0 +1,494 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "framework/omg/parser/parser_types.h"


namespace ge{
namespace parser {
const char *DATA = "Data";
const char *AIPPDATA = "AippData";
const char *CONVOLUTION = "Convolution";
const char *CORRELATION = "Correlation";
const char *CORRELATIONV2 = "Correlation_V2";
const char *DECONVOLUTION = "Deconvolution";
const char *POOLING = "Pooling";
const char *ELTWISE = "Eltwise";
const char *RELU = "ReLU";
const char *RELU6 = "ReLU6";
const char *SIGMOID = "Sigmoid";
const char *ABSVAL = "AbsVal";
const char *TANH = "TanH";
const char *PRELU = "PReLU";
const char *BATCHNORM = "BatchNorm";
const char *FUSIONBATCHNORM = "FusionBatchNorm";
const char *SCALE = "Scale";
const char *FULL_CONNECTION = "FullConnection";
const char *SOFTMAX = "Softmax";
const char *PLUS = "Plus";
const char *ACTIVATION = "Activation";
const char *FLATTEN = "Flatten";
const char *ADD = "Add";
const char *SUB = "Sub";
const char *MUL = "Mul";
const char *MATMUL = "MatMul";
const char *RSQRT = "Rsqrt";
const char *BIASADD = "BiasAdd";
const char *RESHAPE = "Reshape";
const char *REFORMAT = "ReFormat";
const char *DEPCONVOLUTION = "ConvolutionDepthwise";
const char *DROPOUT = "Dropout";
const char *DROPOUTGENMASK = "DropOutGenMask";
const char *DROPOUTDOMASK = "DropOutDoMask";
const char *CONCAT = "Concat";
const char *ROIPOOLING = "ROIPooling";
const char *PROPOSAL = "Proposal";
const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput";
const char *DETECTIONPOSTPROCESS = "Detectpostprocess";
const char *LRN = "LRN";
const char *TRANSDATA = "TransData";
const char *PERMUTE = "Permute";
const char *SSDNORMALIZE = "SSDNormalize";
const char *SSDPRIORBOX = "SSDPriorBox";
const char *NETOUTPUT = "NetOutput";
const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput";
const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput";
const char *CHANNELAXPY = "ChannelAxpy";
const char *PSROIPOOLING = "PSROIPooling";
const char *POWER = "Power";
const char *POW = "Pow";
const char *ROIALIGN = "ROIAlign";
const char *PYTHON = "Python";
const char *FREESPACEEXTRACT = "FreespaceExtract";
const char *SPATIALTF = "SpatialTransform";
const char *SHAPE = "Shape";
const char *SHAPEN = "ShapeN";
const char *ARGMAX = "ArgMax";
const char *GATHERND = "GatherNd";
const char *GATHER = "Gather";
const char *REALDIV = "RealDiv";
const char *PACK = "Pack";
const char *SLICE = "Slice";
const char *SLICED = "SliceD";
const char *FLOORDIV = "FloorDiv";
const char *SQUEEZE = "Squeeze";
const char *UNSQUEEZE = "Unsqueeze";
const char *STRIDEDSLICE = "StridedSlice";
const char *RANGE = "Range";
const char *RPNPROPOSALS = "RpnProposals";
const char *DECODEBBOX = "DecodeBbox";
const char *PAD = "Pad";
const char *PADV2 = "PadV2";
const char *MIRRORPAD = "MirrorPad";
const char *TILE = "Tile";
const char *SIZE = "Size";
const char *CLIPBOXES = "ClipBoxes";
const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions";
const char *SPLIT = "Split";
const char *SPLITV = "SplitV";
const char *EXPANDDIMS = "ExpandDims";
const char *EMPTY = "Empty";
const char *MEAN = "Mean";
const char *GREATER = "Greater";
const char *SWITCH = "Switch";
const char *SWITCHN = "SwitchN";
const char *MERGE = "Merge";
const char *SYMBOLICGRADIENT = "SymbolicGradient";
const char *REMOTECALL = "RemoteCall";
const char *_IF = "_If";
const char *STATELESSIF = "StatelessIf";
const char *IF = "If";
const char *CASE = "Case";
const char *_WHILE = "_While";
const char *WHILE = "While";
const char *STATELESSWHILE = "StatelessWhile";
const char *FOR = "For";
const char *PARTITIONEDCALL = "PartitionedCall";
const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall";
const char *FAKEPARAM = "FakeParam";
const char *TRANSPOSE = "Transpose";
const char *TRANSPOSED = "TransposeD";
const char *CAST = "Cast";
const char *REGION = "Region";
const char *YOLO = "Yolo";
const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput";
const char *FILL = "Fill";
const char *REVERSE = "Reverse";
const char *UNPACK = "Unpack";
const char *YOLO2REORG = "Yolo2Reorg";
const char *REDUCESUM = "ReduceSum";
const char *SUM = "Sum";
const char *CONSTANT = "Const";
const char *RESIZEBILINEAR = "ResizeBilinear";
const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad";
const char *MAXIMUM = "Maximum";
const char *FRAMEWORKOP = "FrameworkOp";
const char *ARG = "_Arg";
const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad";
const char *LSTM = "LSTM";
const char *HIGHWAY = "HighWay";
const char *RNN = "RNN";
const char *ATTENTIONDECODER = "AttentionDecoder";
const char *LOGICAL_NOT = "LogicalNot";
const char *LOGICAL_AND = "LogicalAnd";
const char *LOGICAL_OR = "LogicalOr";
const char *EQUAL = "Equal";
const char *NOTEQUAL = "NotEqual";
const char *INTERP = "Interp";
const char *SHUFFLECHANNEL = "ShuffleChannel";
const char *AIPP = "Aipp";
const char *MULTISHAPE = "MultiShape";
const char *RECIPROCAL = "Reciprocal";
const char *SELU = "Selu";
const char *ELU = "Elu";
const char *ACOSH = "Acosh";
const char *ASINH = "Asinh";
const char *MINIMUM = "Minimum";
const char *CLIP = "Clip";
const char *L2NORMALIZE = "L2Normalize";
const char *CROPANDRESIZE = "CropAndResize";
const char *UNUSEDCONST = "UnusedConst";
const char *SPARSETODENSE = "SparseToDense";
const char *NONMAXSUPPRESSION = "NonMaxSuppression";
const char *TOPKV2 = "TopKV2";
const char *INVERTPERMUTATION = "InvertPermutation";
const char *MULTINOMIAL = "Multinomial";
const char *REVERSESEQUENCE = "ReverseSequence";
const char *REDUCEPROD = "ReduceProd";
const char *REDUCEMAX = "ReduceMax";
const char *REDUCEMIN = "ReduceMin";
const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches";
const char *SQRT = "Sqrt";
const char *REDUCEALL = "ReduceAll";
const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor";
const char *SPACETOBATCHND = "SpaceToBatchND";
const char *BATCHTOSPACEND = "BatchToSpaceND";
const char *ASSERT = "Assert";
const char *GREATEREQUAL = "GreaterEqual";
const char *FLOOR = "Floor";
const char *RANDOMUNIFORM = "RandomUniform";
const char *BATCHMATMUL = "BatchMatMul";
const char *SPACETODEPTH = "SpaceToDepth";
const char *DEPTHTOSPACE = "DepthToSpace";
const char *RINT = "Rint";
const char *ATAN = "Atan";
const char *ATAN2 = "Atan2";
const char *ATANH = "Atanh";
const char *ACOS = "Acos";
const char *ASIN = "Asin";
const char *NEG = "Neg";
const char *LOG = "Log";
const char *TAN = "Tan";
const char *ROUND = "Round";
const char *UPSAMPLE = "Upsample";
const char *FLOORMOD = "FloorMod";
const char *LESS = "Less";
const char *LESSEQUAL = "LessEqual";
const char *ONEHOT = "OneHot";
const char *REFSWITCH = "RefSwitch";
const char *REFMERGE = "RefMerge";
const char *ENTER = "Enter";
const char *REFENTER = "RefEnter";
const char *LOOPCOND = "LoopCond";
const char *NEXTITERATION = "NextIteration";
const char *REFNEXTITERATION = "RefNextIteration";
const char *EXIT = "Exit";
const char *REFEXIT = "RefExit";
const char *CONTROLTRIGGER = "ControlTrigger";
const char *ZEROSLIKE = "ZerosLike";
const char *EXP = "Exp";
const char *WHERE = "Where";
const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars";
const char *SOFTPLUS = "Softplus";
const char *SOFTSIGN = "Softsign";
const char *COSH = "Cosh";
const char *SINH = "Sinh";
const char *SQUAREDDIFFERENCE = "SquaredDifference";
const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion
const char *SSDPOSTPROCESSOR = "SSDPostProcessor";
const char *RETINANETBOXES = "RetinanetBoxes";
const char *RETINAMULTIANCHORS = "RetinaMultiAnchor";
const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes";
const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections";
const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor";
const char *RETINANETANCHORS = "RetinanetAnchors";
const char *FASTERRCNNMAP = "FasterRCNNMap";
const char *FASTERRCNNMAP1 = "FasterRCNNMap1";
const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor";
const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling";
const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor";
const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator";
const char *ROIINTERPOOLING = "ROIInterPooling";
const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow";
const char *EMBEDLOOKUP = "EmbedLookup";
const char *HASHLOOKUP = "HashLookup";
const char *LSH_PROJ = "LshProject";
const char *SVDF = "SVDF";
const char *SSDANCHORGENERATOR = "SSDAnchorGenerator";
const char *IDENTITY = "Identity";
const char *IDENTITYN = "IdentityN";
const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault";
const char *SELECT = "Select";
const char *GETSPAN = "GetSpan";
const char *STOPGRADIENT = "StopGradient";
const char *PREVENTGRADIENT = "PreventGradient";
const char *GUARANTEECONST = "GuaranteeConst";
const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs";
const char *BROADCASTARGS = "BroadcastArgs";
const char *CONFUSIONMATRIX = "ConfusionMatrix";
const char *RANK = "Rank";
const char *PLACEHOLDER = "PlaceHolder";
const char *END = "End";
const char *BASICLSTMCELL = "BasicLSTMCell";
const char *GETNEXT = "GetNext";
const char *INITDATA = "InitData";
const char *REFIDENTITY = "RefIdentity";
const char *BITCAST = "Bitcast";

/***************Ann special operator*************************/
const char *ANN_MEAN = "AnnMean";
const char *ANN_CONVOLUTION = "AnnConvolution";
const char *ANN_DEPCONVOLUTION = "AnnDepthConv";
const char *ANN_FULLCONNECTION = "AnnFullConnection";
const char *ANN_NETOUTPUT = "AnnNetOutput";
const char *ANN_DATA = "AnnData";
const char *ANN_RESHAPE = "AnnReshape";
const char *ANN_ADD = "AnnAdd";
const char *ANN_MUL = "AnnMul";
const char *ANN_SUB = "AnnSub";
const char *ANN_DIV = "AnnDiv";
const char *ANN_DEQUANTIZE = "AnnDequant";
const char *ANN_QUANTIZE = "AnnQuant";
const char *ANN_PAD = "AnnPad";
const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear";

/***************************************************/
/******************Training operator*************************/
const char *GATHERV2 = "GatherV2";
const char *CONVGRADFILTER = "Conv2DBackpropFilter";
const char *CONV2D = "Conv2D";
const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput";
const char *FUSEDBATCHNORM = "FusedBatchNorm";
const char *BIASADDGRAD = "BiasAddGrad";
const char *ACTIVATIONGRAD = "ReluGrad";
const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax";
const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax";
const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits";
const char *SNAPSHOT = "Snapshot";
const char *VAR = "Var";
const char *MEANGRAD = "MeanGrad";
const char *TRANSLATE = "Translate";
const char *ADDN = "AddN";
const char *L2LOSS = "L2Loss";
const char *MULTIPLY = "Multiply";
const char *HUBERLOSSGRAD = "HuberLossGrad";
const char *HUBERLOSS = "HuberLoss";
const char *NEGATIVE = "Negative";
const char *SSDCAST = "SSDCast";
const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy";
const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad";
const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion";
const char *CONCATFOUR2FIVE = "ConcatFour2Five";
const char *CONCATFIVE2FOUR = "ConcatFive2Four";
const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul";
const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean";

const char *VARIABLEV2 = "VariableV2";
const char *VARHANDLEOP = "VarHandleOp";
const char *TEMPORARYVARIABLE = "TemporaryVariable";
const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable";
const char *VARIABLE = "Variable";
const char *ASSIGN = "Assign";
const char *ASSIGNVARIABLEOP = "AssignVariableOp";
const char *ASSIGNADD = "AssignAdd";
const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp";
const char *ASSIGNSUB = "AssignSub";
const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp";
const char *APPLYMOMENTUM = "ApplyMomentum";
const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum";
const char *SGD = "SGD";
const char *NOOP = "NoOp";
const char *READVARIABLEOP = "ReadVariableOp";
const char *PARALLELCONCATSTART = "_ParallelConcatStart";
const char *CONSTANTOP = "Constant";
const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter";
const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput";
const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative";
const char *DROPOUTGRAD = "DropOutGrad";
const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision";
const char *APPLYRMSPROP = "ApplyRMSProp";
const char *RELU6GRAD = "Relu6Grad";
const char *AVGPOOLGRAD = "AvgPoolGrad";
const char *CONCATV2 = "ConcatV2";
const char *CONCATOFFSET = "ConcatOffset";
const char *LAYERNORMGRAD = "LayerNormGrad";
const char *LAYERNORM = "LayerNorm";
const char *LARS = "Lars";
const char *DYNAMICSTITCH = "DynamicStitch";

/***************************************************/
const char *SQUARE = "Square";
const char *HCOMBROADCAST = "HcomBroadcast";
const char *HCOMALLGATHER = "HcomAllGather";
const char *HCOMALLREDUCE = "HcomAllReduce";
const char *HCOMREDUCESCATTER = "HcomReduceScatter";
const char *HCOMSEND = "HcomSend";
const char *HCOMRECEIVE = "HcomReceive";
const char *HCOMREMOTEREAD = "HcomRemoteRead";
const char *HCOMREMOTEWRITE = "HcomRemoteWrite";

const char *VARASSIGN = "VarAssign";
const char *VARISINITIALIZEDOP = "VarIsInitializedOp";
const char *LogTimeStamp = "LogTimeStamp";
const char *ISVARIABLEINITIALIZED = "IsVariableInitialized";
const char *STREAMSWITCH = "StreamSwitch";
const char *STREAMSWITCHN = "StreamSwitchN";
const char *STREAMACTIVE = "StreamActive";
const char *MEMCPYASYNC = "MemcpyAsync";
const char *MEMCPYADDRASYNC = "MemcpyAddrAsync";
const char *STREAMMERGE = "StreamMerge";
const char *ENDGRAPH = "EndGraph";
const char *SEND = "Send";
const char *RECV = "Recv";
const char *ENDOFSEQUENCE = "EndOfSequence";

const char *LABELSET = "LabelSet";
const char *LABELGOTO = "LabelGoto";
const char *LABELGOTOEX = "LabelGotoEx";
const char *LABELSWITCH = "LabelSwitch";
const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex";

const char *ATOMICADDRCLEAN = "AtomicAddrClean";

const char *ABS_GRAD = "AbsGrad";
const char *ACCUMULATE_N_V2 = "AccumulateNV2";
const char *ACOS_GRAD = "AcosGrad";
const char *ACOSH_GRAD = "AcoshGrad";
const char *ANY = "Any";
const char *APPROXIMATE_EQUAL = "ApproximateEqual";
const char *ASIN_GRAD = "AsinGrad";
const char *ASINH_GRAD = "AsinhGrad";
const char *ATAN_GRAD = "AtanGrad";
const char *BROADCAST_TO = "BroadcastTo";
const char *ELU_GRAD = "EluGrad";
const char *ADD_V2 = "AddV2";
const char *DATAFORMATDIMMAP = "DataFormatDimMap";
const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute";
const char *BESSELI0E = "BesselI0e";
const char *BESSELI1E = "BesselI1e";
const char *APPLYADADELTA = "ApplyAdadelta";
const char *APPLYADAGRAD = "ApplyAdagrad";
const char *APPLYADAGRADDA = "ApplyAdagradDA";
const char *APPLYADAM = "ApplyAdam";
const char *APPLYADAMAX = "ApplyAdaMax";
const char *APPLYADDSIGN = "ApplyAddSign";
const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp";
const char *APPLYFTRL = "ApplyFtrl";
const char *APPLYFTRLV2 = "ApplyFtrlV2";
const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent";
const char *APPLYPOWERSIGN = "ApplyPowerSign";
const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad";
const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent";
const char *DEQUANTIZE = "Dequantize";

const char *FOCAL_LOSS = "FocalLoss";
const char *FOCAL_LOSS_GRAD = "FocalLossGrad";
const char *SMOOTHL1_LOSS = "SmoothL1Loss";
const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad";
const char *REDUCEMEAN = "ReduceMean";
const char *CONCAT_V2 = "ConcatV2";
const char *ONEHOT_V2 = "OneHotV2";
const char *SLICE_V2 = "SliceV2";
const char *TILE_V2 = "TileV2";
const char *SUM_V2 = "SumV2";
// Common type when the operator has the same name
const char *DETECTIONOUTPUT = "DetectionOutput";
// Custom operator
const char *CUSTOMOP = "CustomOp";
const char *CUSTOMOP_NCHW = "CustomOpNchw";
const char *CUSTOMOP_NHWC = "CustomOpNhwc";
const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0";

// Depthwise 4d_2_6d,6d_2_4d
const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d";
const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d";

const char *SQRTGRAD = "SqrtGrad";
const char *SIGMOIDGRAD = "SigmoidGrad";

const char *TRANSSHAPE = "TransShape";

// Horovod operator
const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce";
const char *HVDCALLBACKALLGATHER = "HorovodAllgather";
const char *HVDCALLBACKBROADCAST = "HorovodBroadcast";
const char *HVDWAIT = "HorovodWait";

///
/// @brief Magic number of model file
///
const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number

///
/// @brief Model head length
///
const uint32_t MODEL_FILE_HEAD_LEN = 256;

const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0///

///
/// @ingroup domi_omg
/// @brief alpha default value
///
const float ALPHA_DEFAULT_VALUE = 1.0;

///
/// @ingroup domi_omg
/// @brief beta default value
///
const float BETA_DEFAULT_VALUE = 0.0;

///
/// @ingroup domi_omg
/// @brief Input node type
///
const std::string INPUT_TYPE = "Input";
const std::string DUMMY_DATA = "DummyData";

// for fusion op plugin
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";

const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc";
const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc";

///
/// @ingroup domi_omg
/// @brief DATA node type
///
const std::string DATA_TYPE = "Data";

///
/// @ingroup domi_omg
/// @brief Frame operator type
///
const std::string FRAMEWORK_OP_TYPE = "FrameworkOp";

///
/// @ingroup domi_omg
/// @brief Convolution node type
///
const std::string NODE_NAME_NET_OUTPUT = "Node_Output";
} // namespace parser
} // namespace ge

+ 83
- 0
parser/common/pass_manager.cc View File

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/pass_manager.h"
#include "framework/omg/parser/parser_types.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/debug/log.h"
#include "graph/utils/node_utils.h"
#include "omg/omg_inner_types.h"

namespace ge {
namespace parser {
const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; }

Status PassManager::AddPass(const string &pass_name, GraphPass *pass) {
GE_CHECK_NOTNULL(pass);
names_to_graph_passes_.emplace_back(pass_name, pass);
return SUCCESS;
}

Status PassManager::Run(const ComputeGraphPtr &graph) {
GE_CHECK_NOTNULL(graph);
return Run(graph, names_to_graph_passes_);
}

Status PassManager::Run(const ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &names_to_passes) {
GE_CHECK_NOTNULL(graph);
bool not_changed = true;

for (auto &pass_pair : names_to_passes) {
const auto &pass = pass_pair.second;
const auto &pass_name = pass_pair.first;
GE_CHECK_NOTNULL(pass);

PARSER_TIMESTAMP_START(PassRun);
Status status = pass->Run(graph);
if (status == SUCCESS) {
not_changed = false;
} else if (status != NOT_CHANGED) {
GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str());
return status;
}
for (const auto &subgraph :graph->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(subgraph);
GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str());
string subgraph_pass_name = pass_name + "::" + graph->GetName();
PARSER_TIMESTAMP_START(PassRunSubgraph);
status = pass->Run(subgraph);
PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str());
if (status == SUCCESS) {
not_changed = false;
} else if (status != NOT_CHANGED) {
GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str());
return status;
}
}
PARSER_TIMESTAMP_END(PassRun, pass_name.c_str());
}

return not_changed ? NOT_CHANGED : SUCCESS;
}

PassManager::~PassManager() {
for (auto &pass_pair : names_to_graph_passes_) {
auto &pass = pass_pair.second;
GE_DELETE_NEW_SINGLE(pass);
}
}
} // namespace parser
} // namespace ge

+ 76
- 0
parser/common/pass_manager.h View File

@@ -0,0 +1,76 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_PASS_MANAGER_H_
#define PARSER_COMMON_PASS_MANAGER_H_

#include <vector>

#include "inc/graph_pass.h"

using std::vector;

namespace ge {
namespace parser {
///
/// @ingroup domi_omg
/// @brief pass manager
/// @author
///
class PassManager {
public:
///
/// get graph passes
/// @author
///
const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const;

///
/// Add graph pass
/// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys.
/// @author
///
Status AddPass(const string &pass_name, GraphPass *pass);

///
/// Optimize graph with added pass
/// @param [inout] graph graph to be optimized
/// @return SUCCESS optimize successfully
/// @return NOT_CHANGED not optimized
/// @return others optimize failed
/// @author
///
Status Run(const ge::ComputeGraphPtr &graph);

///
/// Optimize graph with specified pass
/// @param [inout] graph graph to be optimized
/// @param [in] passes passes to be used
/// @return SUCCESS optimize successfully
/// @return NOT_CHANGED not optimized
/// @return others optimized failed
/// @author
///
static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes);

~PassManager();

private:
vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_;
};
} // namespace parser
} // namespace ge
#endif // PARSER_COMMON_PASS_MANAGER_H_

+ 287
- 0
parser/common/pre_checker.cc View File

@@ -0,0 +1,287 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/pre_checker.h"
#include <nlohmann/json.hpp>
#include "common/model_saver.h"
#include "common/op_map.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "omg/omg.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/model_saver.h"
#include "register/op_registry.h"

namespace ge {
// Keys in JSON file
namespace {
const char *const kKeyName = "name";
const char *const kKeyResult = "result";
const char *const kKeyTotal = "total";
const char *const kKeyPass = "pass";
const char *const kKeyFail = "fail";
const char *const kKeyOp = "op";
const char *const kKeyOpName = "name";
const char *const kKeyOpType = "type";
const char *const kKeyOpResult = "result";
const char *const kKeyCause = "cause";
const char *const kKeyCauseCode = "code";
const char *const kKeyCauseMessage = "message";

// Checking result and support warning later
const char *const kResultSuccess = "success";
const char *const kResultFailed = "failed";
} // namespace

PreChecker::PreChecker() : fmk_op_types_(nullptr) { Init(); }

void PreChecker::Init() {
model_name_.clear();
op_map_.clear();
ops_.clear();
fmk_op_types_ = nullptr;

// Currently only Caffe and tensorflow are supported
domi::FrameworkType fmk_type = GetParserContext().type;
if (fmk_type == domi::CAFFE)
fmk_op_types_ = &caffe_op_map;
else if (fmk_type == domi::TENSORFLOW)
fmk_op_types_ = &tensorflow_op_map;
else
return;
}

PreChecker::~PreChecker() {}

FMK_FUNC_HOST_VISIBILITY PreChecker &PreChecker::Instance() {
static PreChecker instance;
return instance;
}

FMK_FUNC_HOST_VISIBILITY void PreChecker::SetModelName(const string &name) { model_name_ = name; }

FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddOp(OpId id, const string &name, const string &type) {
GE_RETURN_WITH_LOG_IF_TRUE(op_map_.find(id) != op_map_.end(), "Id already exists.");

Info info;
info.id = id;
info.name = name;
info.type = type;
op_map_[id] = info;
ops_.push_back(id);

return SUCCESS;
}

Status PreChecker::CheckName(OpId id) {
auto iter = op_map_.find(id);
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist.");

Info &info = iter->second;
for (auto &v : op_map_) {
// If the name is duplicate, an error is logged
if (id != v.first && info.name == v.second.name) {
Cause cause;
cause.code = NAME_REPEATED;
cause.message = "The name is repeated.";

GELOGI("Name %s repeated.", info.name.c_str());
ErrorManager::GetInstance().ATCReportErrMessage("E19009", {"opname"}, {info.name});
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed.");
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(v.first, cause), "Add cause failed.");
break;
}
}

return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflow) {
auto iter = op_map_.find(id);
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist.");

Info &info = iter->second;
string type = info.type;

// If the user explicitly specifies the mapping relationship of the operator type through
// the -- OP_name_map parameter, the type specified by the user is used.
auto op_map_iter = GetParserContext().op_conf_map.find(type);
if (op_map_iter != GetParserContext().op_conf_map.end()) {
type = op_map_iter->second;
}

// Judge whether the type is supported
GE_RETURN_WITH_LOG_IF_ERROR(
CheckTypeSupported(info.id, type, info.name, is_tensorflow), "Check type supported failed.");

return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY Status PreChecker::AddCause(OpId id, ErrorCode code, const string &msg) {
Cause cause;
cause.code = code;
cause.message = msg;
return AddCause(id, cause);
}

FMK_FUNC_HOST_VISIBILITY void PreChecker::RefreshErrorMessageByName(const string &op_name, ErrorCode code,
const string &msg) {
for (const auto &op : op_map_) {
if (op.second.name == op_name) {
AddCause(op.second.id, code, msg);
return;
}
}
GELOGW("Node [%s] not founded in prechecking list.", op_name.c_str());
}

Status PreChecker::AddCause(OpId id, const Cause &cause) {
auto iter = op_map_.find(id);
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist.");

Info &info = iter->second;

// Avoid adding repeatedly
for (Cause &c : info.causes) {
if (c.code == cause.code && c.message == cause.message) {
return SUCCESS;
}
}

info.causes.push_back(cause);

return SUCCESS;
}

void PreChecker::Clear() { Init(); }

Status PreChecker::Clear(OpId id, const string &message) {
auto iter = op_map_.find(id);
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist.");

Info &info = iter->second;
info.causes.clear();

// Set additional information
if (message != "") {
Cause cause;
cause.code = ErrorCode::OK;
cause.message = message;
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed.");
}

return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY bool PreChecker::HasError() {
for (auto id : ops_) {
if (HasError(id)) {
return true;
}
}

return false;
}

Status PreChecker::Save(string file) {
uint32_t fail_num = 0;
for (auto id : ops_) {
if (HasError(id)) {
fail_num++;
}
}

// Initialization model related JSON information
nlohmann::json model;
model[kKeyName] = model_name_;
model[kKeyResult] = HasError() ? kResultFailed : kResultSuccess;
model[kKeyTotal] = ops_.size();
model[kKeyPass] = ops_.size() - fail_num;
model[kKeyFail] = fail_num;

// Constructing JSON information of operators in order of network
for (auto id : ops_) {
auto iter = op_map_.find(id);
GE_CHK_BOOL_RET_STATUS(iter != op_map_.end(), FAILED, "don't find this op.");
Info &info = iter->second;

// Initialization operator general information
nlohmann::json op = {{kKeyOpName, info.name}, {kKeyOpType, info.type}};
op[kKeyOpResult] = HasError(id) ? kResultFailed : kResultSuccess;

// handle causes
for (const Cause &cause : info.causes) {
nlohmann::json cause_j = {{kKeyCauseCode, cause.code}, {kKeyCauseMessage, cause.message}};
op[kKeyCause].push_back(cause_j);
}

model[kKeyOp].push_back(op);
}

// Save JSON data to a file
GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed.");

return SUCCESS;
}

Status PreChecker::CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow) {
// Currently only partial framework type checking is supported
if (fmk_op_types_ == nullptr) {
std::string op_type;
if (!domi::OpRegistry::Instance()->GetOmTypeByOriOpType(type, op_type)) {
Cause cause;
cause.code = TYPE_UNSUPPORTED;
cause.message = "The type is not supported.";
GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
if (!is_tensorflow) {
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type});
}
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed.");
}
return SUCCESS;
}

// Log error if type not found
if (fmk_op_types_->find(type) == fmk_op_types_->end()) {
Cause cause;
cause.code = TYPE_UNSUPPORTED;
cause.message = "The type is not supported.";

GELOGI("Check op[%s]'s type[%s] failed, it is not supported.", name.c_str(), type.c_str());
if (!is_tensorflow) {
ErrorManager::GetInstance().ATCReportErrMessage("E19010", {"opname", "optype"}, {name, type});
}
GE_RETURN_WITH_LOG_IF_ERROR(AddCause(id, cause), "Add cause failed.");
}

return SUCCESS;
}

bool PreChecker::HasError(OpId id) {
auto iter = op_map_.find(id);
GE_RETURN_WITH_LOG_IF_TRUE(iter == op_map_.end(), "Id does not exist.");

Info &info = iter->second;
for (const Cause &cause : info.causes) {
if (cause.code != ErrorCode::OK) {
return true;
}
}

return false;
}
} // namespace ge

+ 194
- 0
parser/common/pre_checker.h View File

@@ -0,0 +1,194 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_PRE_CHECKER_H_
#define PARSER_COMMON_PRE_CHECKER_H_

#include <string>
#include <vector>
#include "framework/omg/parser/parser_types.h"
#include "omg/omg_inner_types.h"

namespace ge {
using std::map;
using std::string;
using std::vector;
using Status = domi::Status;
/**
* @ingroup domi_omg
* @brief pre_check
* @author
*/
class PreChecker {
public:
/**
* @ingroup domi_omg
* @brief Operator unique identification
*/
using OpId = const void *;

/**
* @ingroup domi_omg
* @brief error code, 1~99:Error, 100~199:Waring。
*/
enum ErrorCode {
// no error
OK = 0,

// type unsupported
TYPE_UNSUPPORTED = 1,

// param invalid
PARAM_INVALID = 2,

// type ambiguous
TYPE_AMBIGUOUS = 8,

// name repeated
NAME_REPEATED = 9
};

/**
* @ingroup domi_omg
* @brief Operator error description
*/
struct Cause {
// error code
ErrorCode code;

// error message
string message;
};

public:
/**
* @ingroup domi_omg
* @brief instance interface
*/
static PreChecker &Instance();

/**
* @ingroup domi_omg
* @brief set model name
*/
void SetModelName(const string &name);

/**
* @ingroup domi_omg
* @brief add op information
*/
Status AddOp(OpId id, const string &name, const string &type);

/**
* @ingroup domi_omg
* @brief Judge whether the operator name is duplicate
*/
Status CheckName(OpId id);

/**
* @ingroup domi_omg
* @brief check operation type
* 1、Check whether the operator type supports according to the global frameworktype
* 2、Check if the operator type is ambiguous
*/
Status CheckType(OpId id, bool is_tensorflow = false);

void RefreshErrorMessageByName(const string &op_name, ErrorCode code, const string& msg);

/**
* @ingroup domi_omg
* @brief Add custom error description
*/
Status AddCause(OpId id, ErrorCode code, const string &msg);

/**
* @ingroup domi_omg
* @brief Add custom error description
*/
Status AddCause(OpId id, const Cause &cause);

/**
* @ingroup domi_omg
* @brief Clear all operator information
*/
void Clear();

/**
* @ingroup domi_omg
* @brief Clear the error information of the specified operator
*/
Status Clear(OpId id, const string &message = "");

/**
* @ingroup domi_omg
* @brief Determine if an error has been detected
*/
bool HasError();

/**
* @ingroup domi_omg
* @brief Save inspection results(JSON)
*/
Status Save(string file);

private:
/**
* @ingroup domi_omg
* @brief operation information
*/
struct Info {
// Operator identifier
OpId id;

// Operator name
string name;

// Operator type
string type;

// Error description, which may contain multiple (for example, both name and type are illegal)
vector<Cause> causes;
};

PreChecker();
~PreChecker();
PreChecker(const PreChecker &);
PreChecker &operator=(const PreChecker &);

// Initialize internal data
void Init();

// Judge whether the type is supported
Status CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow);

// Determine if an error has been detected
bool HasError(OpId id);

private:
// model name
string model_name_;

// Save operator check results
map<OpId, Info> op_map_;

// Save operator list in original order
vector<OpId> ops_;

// save frame related operator types
map<string, string> *fmk_op_types_;
};
} // namespace ge
#endif // PARSER_COMMON_PRE_CHECKER_H_

+ 190
- 0
parser/common/proto/ge_ir.proto View File

@@ -0,0 +1,190 @@
syntax = "proto3";

package ge.proto;

enum DataType
{
DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set.
DT_FLOAT = 1; // float type
DT_FLOAT16 = 2; // fp16 type
DT_INT8 = 3; // int8 type
DT_UINT8 = 4; // uint8 type
DT_INT16 = 5; // int16 type
DT_UINT16 = 6; // uint16 type
DT_INT32 = 7; //
DT_INT64 = 8; // int64 type
DT_UINT32 = 9; // unsigned int32
DT_UINT64 = 10; // unsigned int64
DT_BOOL = 11; // bool type
DT_DOUBLE = 12; // double type
DT_STRING = 13; // string type
DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */
DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */
DT_COMPLEX64 = 16; // complex64 type
DT_COMPLEX128 = 17; // complex128 type
DT_QINT8 = 18; // qint8 type
DT_QINT16 = 19; // qint16 type
DT_QINT32 = 20; // qint32 type
DT_QUINT8 = 21; // quint8 type
DT_QUINT16 = 22; // quint16 type
DT_RESOURCE = 23; // resource type
DT_STRING_REF = 24; // string_ref type
DT_DUAL = 25; /**< dual output type */
}

message AttrDef
{
message ListValue
{
enum ListValueType{
VT_LIST_NONE = 0;
VT_LIST_STRING = 1;
VT_LIST_INT = 2;
VT_LIST_FLOAT = 3;
VT_LIST_BOOL = 4;
VT_LIST_BYTES = 5;
VT_LIST_TENSOR_DESC = 6;
VT_LIST_TENSOR = 7;
VT_LIST_GRAPH = 8;
VT_LIST_NAMED_ATTRS = 9;
VT_LIST_DATA_TYPE = 10;
}
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3; // "list(int)"
repeated float f = 4; // "list(float)"
repeated bool b = 5; // "list(bool)"
repeated bytes bt = 7;
repeated TensorDescriptor td = 8;
repeated TensorDef t = 9;
repeated GraphDef g = 10;
repeated NamedAttrs na = 11;
repeated int64 dt = 12; // list ge::DataType

ListValueType val_type = 20;
}

message ListListInt{
message ListInt{
repeated int64 list_i = 1; // list int
}
repeated ListInt list_list_i = 1; // list list int
}

oneof value
{
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10; // Used to support attr nesting
TensorDescriptor td = 11; // GeTensorDesc type
TensorDef t = 12; // GeTensor type
GraphDef g = 13; // Graph type
ListListInt list_list_int = 14; // List List Int type
int64 dt = 15; // ge::DataType
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs
{
string name = 1;
map<string, AttrDef> attr = 2;
}

// Shape / dimension description, using row-major order
message ShapeDef
{
repeated int64 dim = 1; // Size of each dimension
}

// Multidimensional data description
message TensorDescriptor
{
string name = 1; // Optional parameter, tensor name

DataType dtype = 2; // tensor datatype
ShapeDef shape = 3; // Shape / dimension
string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND"

bool has_out_attr = 9;
int64 size = 10;
int64 weight_size = 11;
bool reuse_input = 12;
bool output_tensor = 13;
string device_type = 14;
bool input_tensor =15;
int64 real_dim_cnt = 16;
int64 reuse_input_index = 17;
int64 data_offset = 18;
int64 cmps_size = 19;
string cmps_tab = 20;
int64 cmps_tab_offset = 21;

map<string, AttrDef> attr = 5; // Set of extra parameter fields
}

// GeTensor definition
message TensorDef
{
TensorDescriptor desc = 1; // Tensor description
bytes data = 2; // Tensor data
}


// Operator description
message OpDef
{
string name = 1; // name
string type = 2; // type

repeated string input = 5; // input original op name + outgoing index. op_name:index

map<string, AttrDef> attr = 10; // Set of operator parameter fields

bool has_out_attr = 20;
int64 id = 21;
int64 stream_id =22;
repeated string input_name = 23;
repeated string src_name = 24;
repeated int64 src_index = 25;
repeated string dst_name = 26;
repeated int64 dst_index = 27;
repeated int64 input_i = 28;
repeated int64 output_i = 29;
repeated int64 workspace = 30;
repeated int64 workspace_bytes = 31;
repeated bool is_input_const = 32;
repeated TensorDescriptor input_desc = 33;
repeated TensorDescriptor output_desc = 34;
repeated string subgraph_name = 35;
}

// Graph definition
message GraphDef
{
string name = 1; // name

repeated string input = 4; // Graph input
repeated string output = 5; // Graph output

repeated OpDef op = 6; // List of operators

map<string, AttrDef> attr = 11; // Extended field
}

// model definition
message ModelDef
{
string name = 1; // name
uint32 version = 2; // IR Proto verion
string custom_version = 3; // User model version number, passed in by user

repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef

map<string, AttrDef> attr = 11; // Extended field
}


+ 136
- 0
parser/common/proto/insert_op.proto View File

@@ -0,0 +1,136 @@
syntax = "proto3";

package domi;

message InsertNewOps {
repeated AippOpParams aipp_op = 1;
repeated MultiShapeOpParams multi_shape_op = 2;
}

message AippOpParams {
enum InputFormat {
UNDEFINED = 0;
YUV420SP_U8 = 1;
XRGB8888_U8 = 2;
RGB888_U8 = 3;
YUV400_U8 = 4;
NC1HWC0DI_FP16 = 5;
NC1HWC0DI_S8 = 6;
ARGB8888_U8 = 7;
YUYV_U8 = 8;
YUV422SP_U8 = 9;
AYUV444_U8 = 10;
RAW10 = 11;
RAW12 = 12;
RAW16 = 13;
RAW24 = 14;
RGB16 = 15;
RGB20 = 16;
RGB24 = 17;
RGB8_IR = 18;
RGB16_IR = 19;
RGB24_IR = 20;
}

enum AippMode {
undefined = 0;
static = 1;
dynamic = 2;
}

// AIPP模式,区分静态AIPP和动态AIPP
AippMode aipp_mode = 1;

// related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。
uint32 related_input_rank = 2;

// input_edge_idx参数为可选,类型为整型,配置范围为>=0。
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。
// 配置值 <= Data算子输出边的个数。
repeated uint32 input_edge_idx = 3;

// [Begin] 动态AIPP参数,配置静态AIPP时无效
uint32 max_src_image_size = 4;

// 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失
bool support_rotation = 5;

// [End] 动态AIPP参数


// [Begin] 静态AIPP参数,配置动态AIPP时无效
InputFormat input_format = 51;
bool csc_switch = 52;
float cpadding_value = 53;
bool rbuv_swap_switch = 54;
bool ax_swap_switch = 55;
bool single_line_mode = 56;

int32 src_image_size_w = 57;
int32 src_image_size_h = 58;

bool crop = 59;
int32 load_start_pos_w = 60;
int32 load_start_pos_h = 61;
int32 crop_size_w = 62;
int32 crop_size_h = 63;

bool resize = 64;
int32 resize_output_w = 65;
int32 resize_output_h = 66;

bool padding = 67;
int32 left_padding_size = 68;
int32 right_padding_size = 69;
int32 top_padding_size = 70;
int32 bottom_padding_size = 71;

int32 mean_chn_0 = 10;
int32 mean_chn_1 = 11;
int32 mean_chn_2 = 12;
int32 mean_chn_3 = 19;
float min_chn_0 = 13;
float min_chn_1 = 14;
float min_chn_2 = 15;
float min_chn_3 = 20;
repeated float var_reci_chn_0 = 16;
repeated float var_reci_chn_1 = 17;
repeated float var_reci_chn_2 = 18;
repeated float var_reci_chn_3 = 21;

repeated int32 matrix_r0c0 = 30;
repeated int32 matrix_r0c1 = 31;
repeated int32 matrix_r0c2 = 32;
repeated int32 matrix_r1c0 = 33;
repeated int32 matrix_r1c1 = 34;
repeated int32 matrix_r1c2 = 35;
repeated int32 matrix_r2c0 = 36;
repeated int32 matrix_r2c1 = 37;
repeated int32 matrix_r2c2 = 38;
repeated int32 output_bias_0 = 39;
repeated int32 output_bias_1 = 40;
repeated int32 output_bias_2 = 41;
repeated int32 input_bias_0 = 42;
repeated int32 input_bias_1 = 43;
repeated int32 input_bias_2 = 44;

// [End] 静态AIPP参数

// The n number that is used for raw/rgbir data into f16 transformation.
// The transformation equation is x/(2^n). If set to 0, no transform is performed.
uint32 raw_rgbir_to_f16_n = 45;
}

message MultiShapeOpParams {
enum MultiShapeMode {
batch = 0; //动态batch
resolution = 1; //动态分辨率,扩展用
}

MultiShapeMode mode = 1; //算子模式
uint32 related_input_rank = 2; //新增算子插入到哪个输入


repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间
}

+ 396
- 0
parser/common/proto/om.proto View File

@@ -0,0 +1,396 @@
/* Copyright (C) 2018. Huawei Technologies Co., Ltd. All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the Apache License Version 2.0.You may not use this file except in compliance with the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Apache License for more details at
* http://www.apache.org/licenses/LICENSE-2.0
*/
syntax = "proto3";

package domi;

enum TargetType
{
MINI = 0;
TINY = 1;
LITE = 2;
}

// offline model
message ModelDef {
string name = 1;
uint32 version = 2;

uint64 memory_size = 10;
uint32 stream_num = 11;
uint32 event_num = 12;
uint64 weight_size = 13;
uint32 label_num = 15;
repeated OpDef op = 20;
TargetType target_type = 23;

map<string, AttrDef> attr = 30;
};

// operator define
message OpDef {
string name = 1;
string type = 2;

uint32 id = 3;
uint32 stream_id = 4;

repeated string input_name = 5;

repeated string src_name = 8;
repeated int32 src_index = 9;
repeated int64 input = 10;
repeated int64 output = 11;
repeated TensorDescriptor input_desc = 12;
repeated TensorDescriptor output_desc = 13;
repeated WeightDef weights = 14;
repeated string dst_name = 15;
repeated int32 dst_index = 16;

repeated int64 workspace = 20;
repeated uint32 workspace_bytes = 21;

repeated string weight_name = 22;
repeated bool is_input_const = 23;

map<string, AttrDef> attr = 30;

QuantizeFactorParams quantize_factor = 31;

oneof op_params {
// start at 100 here
SendOpParams sender_param = 100;
RecvOpParams receiver_param = 200;
ConvolutionOpParams convolution_param = 300;
PoolingOpParams pooling_param = 400;
EltwiseOpParams eltwise_param = 500;
BatchNormOpParams batchnorm_param = 600;
ScaleOpParams scale_param = 700;
FullConnectionOpParams full_connection_param = 800;
SoftmaxOpParams softmax_param = 900;
ActivationOpParams activation_param = 1000;
ReshapeOpParams reshape_param = 1100;
}
};

message SendOpParams {
uint32 event_id = 1;
};

message RecvOpParams {
uint32 event_id = 1;
};

enum QuantizeScaleType
{
VECTOR_SCALE = 0;
SCALAR_SCALE = 1;
}

enum QuantizeScaleMode
{
NORMAL_MODE = 0;
SQRT_MODE = 1;
}

enum QuantizeAlgorithm
{
NON_OFFSET_ALGO = 0;
HALF_OFFSET_ALGO = 1;
ALL_OFFSET_ALGO = 2;
}
message QuantizeFactor
{
QuantizeScaleMode scale_mode = 1;
bytes scale_value = 2;
int64 scale_offset = 3;
bytes offset_data_value = 4;
int64 offset_data_offset = 5;
bytes offset_weight_value = 6;
int64 offset_weight_offset = 7;
bytes offset_pad_value = 8;
int64 offset_pad_offset = 9;
};

message QuantizeCalcFactor
{
bytes offsetw = 1;
int64 offsetw_offset = 2;
bytes offsetd = 3;
int64 offsetd_offset = 4;
bytes scalereq = 5;
int64 scaledreq_offset = 6;
bytes offsetdnext = 7;
int64 offsetdnext_offset = 8;
}

message QuantizeFactorParams
{
QuantizeAlgorithm quantize_algo = 1;
QuantizeScaleType scale_type = 2;
QuantizeFactor quantize_param = 3;
QuantizeFactor dequantize_param = 4;
QuantizeFactor requantize_param = 5;
QuantizeCalcFactor quantizecalc_param = 6;
};

message ConvolutionOpParams {
int32 mode = 1;
int32 algo = 2;
int32 pad_mode = 3;
uint32 group = 4;
uint32 num_output = 5;

repeated uint32 pad = 10;
repeated uint32 stride = 11;
repeated uint32 dilation = 12;
repeated uint32 kernel = 13;

float alpha = 20;
float beta = 21;

WeightDef filter = 40;
WeightDef bias = 41;

bool relu_flag = 62;
repeated uint32 adj = 70;
repeated uint32 target_shape = 71;
repeated uint32 before_pad = 72;
};

message PoolingOpParams {
int32 mode = 1;
int32 nan_opt = 2;
int32 pad_mode = 3;
bool global_pooling = 4;

repeated uint32 window = 10;
repeated uint32 pad = 11;
repeated uint32 stride = 12;
bool ceil_mode = 13;
int32 data_mode = 14;

float alpha = 20;
float beta = 21;
repeated uint32 before_pad = 22;
};

message EltwiseOpParams {
int32 mode = 1;
repeated float coeff = 2;
float alpha = 3;
float beta = 4;
repeated WeightDef weight = 5;
bool relu_flag = 6;
};

message ActivationOpParams {
int32 mode = 1;
float coef = 2;
float alpha = 3;
float beta = 4;
};

message BatchNormOpParams {
int32 mode = 1;

float alpha = 2;
float beta = 3;
double epsilon = 4;//optinal,[default = 1e-5]
bool use_global_stats = 5; //optinal,by default true,testing mode
float moving_average_fraction = 6; //optinal,[default = .999];

WeightDef estimated_mean = 7;
WeightDef estimated_variance = 8;

WeightDef scale = 9;
WeightDef bias = 10;
};

message ScaleOpParams {
WeightDef scale = 1;
WeightDef bias = 2;
};

message ReshapeOpParams {
float alpha = 1;
float beta = 2;
ShapeDef shape = 3;
int32 axis = 4;
int32 num_axes = 5;
int32 format = 6;
};

message SoftmaxOpParams {
int32 algo = 1;
int32 mode = 2;
float alpha = 3;
float beta = 4;
};

message FullConnectionOpParams {
WeightDef filter = 1;
WeightDef bias = 2;
uint32 num_output = 3;
bool relu_flag = 12;
};

message FlattenOpParams {
float alpha = 1;
float beta = 2;
int32 start_axis = 3;
int32 end_axis = 4;
}

message AddLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message MulLimitedOpParams {
float alpha = 1;
float beta = 2;
int32 axis = 3;
bool broadcast = 4;

repeated WeightDef weight = 10;
};

message AddOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message MulOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message SubOpParams {
float alpha = 1;
float beta = 2;

repeated WeightDef weight = 10;
};

message BiasAddOpParams {
float alpha = 1;
float beta = 2;

WeightDef bias = 10;
};

message MatMulOpParams {
float alpha = 1;
float beta = 2;
bool transposeX = 3;
bool transposeW = 4;

WeightDef filter = 10;
WeightDef bias = 12;
};

message RsqrtOpParams {
float alpha = 1;
float beta = 2;
};


message WeightDef {
int32 format = 1;
int32 data_type = 2;
ShapeDef shape = 3;
bytes data = 4;
int64 data_offset = 5;
uint32 cmps_size = 6;
bytes cmps_tab = 7;
int64 cmps_tab_offset = 10;
CompressInfo cmps_info = 8;
AllOffsetQuantizeInfo alloffset_quantize_info = 11;
}

message ShapeDef {
repeated int64 dim = 1;
}

enum DeviceType {
NPU = 0; // In default, we will use NPU.
CPU = 1; // CPU
}

message AllOffsetQuantizeInfo {
float scale = 1;
int32 offset = 2;
}

message TensorDescriptor {
int32 format = 1;
int32 data_type = 2;
repeated int64 dim = 3;
uint32 size = 4;
bool reuse_input = 5;
bool output_tensor = 7;
DeviceType device_type = 8;
bool input_tensor = 9;
uint32 real_dim_cnt = 10;
uint32 reuse_input_index = 11;
AllOffsetQuantizeInfo alloffset_quantize_info = 12;
}

message CompressInfo {
int32 blockRow = 1; // block row
int32 blockCol = 2; // block col
int32 fractalK = 3; // fractal K
int32 fractalN = 4; // fractal N
int32 lastFractalK = 5; // K of last fractal
int32 lastFractalN = 6; // N of last fractal
int32 cubeSize = 7; // cube's length
int32 loadDir = 8; // data load directtiono 0:col load 1:row load
}

message AttrDef {
message ListValue {
repeated string s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated uint32 u = 6 [packed = true]; // "list(uint)"
repeated bytes bt = 7;
}

oneof value {
string s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
uint32 u = 6; // "uint32"
bytes bt = 7;
ListValue list = 1; // any "list(...)"
NamedAttrs func = 10;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NamedAttrs {
string name = 1;
map<string, AttrDef> attr = 2;
}


+ 62
- 0
parser/common/proto/tensorflow/attr_value.proto View File

@@ -0,0 +1,62 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"

// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;

// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}

+ 100
- 0
parser/common/proto/tensorflow/function.proto View File

@@ -0,0 +1,100 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list

// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.

// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;

// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}

+ 56
- 0
parser/common/proto/tensorflow/graph.proto View File

@@ -0,0 +1,56 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;

// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;

// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use Tuple() mechanism to
// ensure all inputs are ready in the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};

+ 63
- 0
parser/common/proto/tensorflow/node_def.proto View File

@@ -0,0 +1,63 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;

// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;

// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;

// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;

// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};

+ 164
- 0
parser/common/proto/tensorflow/op_def.proto View File

@@ -0,0 +1,164 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};

// Description of the input(s).
repeated ArgDef input_arg = 2;

// Description of the output(s).
repeated ArgDef output_arg = 3;

// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;

// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;

// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;

// Human-readable description.
string description = 4;


// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.

// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;

// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;

// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;

// One-line human-readable description of what the Op does.
string summary = 5;

// Additional, longer human-readable description of what the Op does.
string description = 6;

// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.

// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;

// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
bool is_aggregate = 16; // for things like add

// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

// -------------------------------------------------------------------------
// Optimization constraints.

// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue

// -------------------------------------------------------------------------
// Non-standard options.

// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)

// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;

// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};

// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};

+ 29
- 0
parser/common/proto/tensorflow/resource_handle.proto View File

@@ -0,0 +1,29 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;

// Container in which this resource is placed.
string container = 2;

// Unique name of this resource.
string name = 3;

// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;

// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};

+ 94
- 0
parser/common/proto/tensorflow/tensor.proto View File

@@ -0,0 +1,94 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;

// Shape of the tensor.
TensorShapeProto tensor_shape = 2;

// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.

// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;

// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;

// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.

// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];

// DT_FLOAT.
repeated float float_val = 5 [packed = true];

// DT_DOUBLE.
repeated double double_val = 6 [packed = true];

// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];

// DT_STRING
repeated bytes string_val = 8;

// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];

// DT_INT64
repeated int64 int64_val = 10 [packed = true];

// DT_BOOL
repeated bool bool_val = 11 [packed = true];

// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];

// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;

// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;

// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];

// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}

+ 45
- 0
parser/common/proto/tensorflow/tensor_shape.proto View File

@@ -0,0 +1,45 @@
// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;
option java_outer_classname = "TensorShapeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

package domi.tensorflow;

// Dimensions of a tensor.
message TensorShapeProto {
// One dimension of the tensor.
message Dim {
// Size of the tensor in that dimension.
// This value must be >= -1, but values of -1 are reserved for "unknown"
// shapes (values of -1 mean "unknown" dimension). Certain wrappers
// that work with TensorShapeProto may fail at runtime when deserializing
// a TensorShapeProto containing a dim value of -1.
int64 size = 1;

// Optional name of the tensor dimension.
string name = 2;
};

// Dimensions of the tensor, such as {"input", 30}, {"output", 40}
// for a 30 x 40 2D tensor. If an entry has size -1, this
// corresponds to a dimension of unknown size. The names are
// optional.
//
// The order of entries in "dim" matters: It indicates the layout of the
// values in the tensor in-memory representation.
//
// The first entry in "dim" is the outermost dimension used to layout the
// values, the last entry is the innermost dimension. This matches the
// in-memory layout of RowMajor Eigen tensors.
//
// If "dim.size()" > 0, "unknown_rank" must be false.
repeated Dim dim = 2;

// If true, the number of dimensions in the shape is unknown.
//
// If true, "dim.size()" must be 0.
bool unknown_rank = 3;
};

+ 74
- 0
parser/common/proto/tensorflow/types.proto View File

@@ -0,0 +1,74 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
DT_INVALID = 0;

// Data types that all computation devices are expected to be
// capable to support.
DT_FLOAT = 1;
DT_DOUBLE = 2;
DT_INT32 = 3;
DT_UINT8 = 4;
DT_INT16 = 5;
DT_INT8 = 6;
DT_STRING = 7;
DT_COMPLEX64 = 8; // Single-precision complex
DT_INT64 = 9;
DT_BOOL = 10;
DT_QINT8 = 11; // Quantized int8
DT_QUINT8 = 12; // Quantized uint8
DT_QINT32 = 13; // Quantized int32
DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
DT_QINT16 = 15; // Quantized int16
DT_QUINT16 = 16; // Quantized uint16
DT_UINT16 = 17;
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;

// Do not use! These are only for parameters. Every enum above
// should have a corresponding value below (verified by types_test).
DT_FLOAT_REF = 101;
DT_DOUBLE_REF = 102;
DT_INT32_REF = 103;
DT_UINT8_REF = 104;
DT_INT16_REF = 105;
DT_INT8_REF = 106;
DT_STRING_REF = 107;
DT_COMPLEX64_REF = 108;
DT_INT64_REF = 109;
DT_BOOL_REF = 110;
DT_QINT8_REF = 111;
DT_QUINT8_REF = 112;
DT_QINT32_REF = 113;
DT_BFLOAT16_REF = 114;
DT_QINT16_REF = 115;
DT_QUINT16_REF = 116;
DT_UINT16_REF = 117;
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/c/c_api.h,
// https://www.tensorflow.org/code/tensorflow/go/tensor.go,
// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.h,
// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc,
// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py,
// https://www.tensorflow.org/code/tensorflow/python/framework/function.py)

+ 31
- 0
parser/common/proto/tensorflow/versions.proto View File

@@ -0,0 +1,31 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "VersionsProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Version information for a piece of serialized data
//
// There are different types of versions for each type of data
// (GraphDef, etc.), but they all have the same common shape
// described here.
//
// Each consumer has "consumer" and "min_producer" versions (specified
// elsewhere). A consumer is allowed to consume this data if
//
// producer >= min_producer
// consumer >= min_consumer
// consumer not in bad_consumers
//
message VersionDef {
// The version of the code that produced this data.
int32 producer = 1;

// Any consumer below this version is not allowed to consume this data.
int32 min_consumer = 2;

// Specific consumer versions which are disallowed (e.g. due to bugs).
repeated int32 bad_consumers = 3;
};

+ 528
- 0
parser/common/proto_file_parser.cc View File

@@ -0,0 +1,528 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/proto_file_parser.h"

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <random>
#include <sys/types.h>
#include <unistd.h>
#include "common/string_util.h"
#include "common/types.h"
#include "common/util.h"
#include "common/debug/log.h"
#include "parser/common/acl_graph_parser_util.h"
#include "ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"

using std::ifstream;
using std::vector;
using std::string;

namespace {
const char kMinNum = '0';
const char kMaxNum = '9';
const int kMinLineWordSize = 3;
const int kMinMessageLineWords = 2;
const int kMaxIdentifier = 536870912; // 2^29 - 1
const int kTmpFileNameLen = 16;
const int kMinRandomNum = 0;
const int kMaxRandomNum = 9;
const int kDecimalMulti = 10;
const int kOpenRetValue = 0;
const int kMessageNameIndex = 2;
const char *const kTmpPath = "/tmp";
const char *const kMessage = "message";
const char *const kLayerParameter = "LayerParameter";
const char *const kNetParameter = "NetParameter";
const char *const kStartBrace = "{";
const char *const kCloseBrace = "}";
const char *const kOptional = "optional";
const char *const kRepeated = "repeated";
const char *const kRequired = "required";

bool GetIdentifier(const std::string &line, int &identifier) {
int size = line.size();
auto pos = line.find("=");
if (pos == std::string::npos) {
return false;
}
for (int i = pos + 1; i < size; i++) {
if (line[i] == ';') {
break;
}
if (line[i] >= kMinNum && line[i] <= kMaxNum) {
identifier = identifier * kDecimalMulti + line[i] - kMinNum;
}
if (identifier > kMaxIdentifier || identifier < 0) {
return false;
}
}
if (identifier == 0) {
return false;
}
return true;
}

void GetName(const std::string &op_info, string &op_name) {
op_name.assign(op_info);
auto pos = op_name.find("=");
if (pos != string::npos) {
op_name = op_name.substr(0, pos);
}
}

void GetOpParamInfo(const std::string &line, std::vector<std::string> &op_param_info) {
std::istringstream string_stream(line);
std::string temp;
while (std::getline(string_stream, temp, ' ')) {
if (temp.empty()) {
continue;
}
op_param_info.emplace_back(std::move(temp));
}
}

string GetMessageName(const std::string &line) {
std::vector<std::string> op_param_info;
GetOpParamInfo(line, op_param_info);
string message_name;
if (op_param_info.size() < kMinMessageLineWords) {
message_name = "";
return message_name;
}
message_name = op_param_info[1];
auto pos = message_name.find(kStartBrace);
if (pos != string::npos) {
message_name = message_name.substr(0, pos);
}
return message_name;
}

string CreatTmpName(int len) {
std::uniform_int_distribution<int> u(kMinRandomNum, kMaxRandomNum);
std::default_random_engine e;
e.seed(time(0));
string tmp_name = "";
for (int i = 0; i < len; i++) {
tmp_name += std::to_string(u(e));
}
return tmp_name;
}

bool SaveIdentifierOpMapInfo(const string &line, std::map<int, std::pair<string, string>> &identifier_op_map,
std::map<std::string, std::pair<int, string>> &op_identifier_map) {
std::vector<std::string> op_param_info;
GetOpParamInfo(line, op_param_info);
int info_size = op_param_info.size();
if (info_size < kMinLineWordSize) {
GELOGE(ge::FAILED, "Words size of line[%s] is less than kMinLineWordSize[%d].", line.c_str(), kMinLineWordSize);
return false;
}

if (op_param_info[0] != kOptional && op_param_info[0] != kRepeated && op_param_info[0] != kRequired) {
GELOGE(ge::FAILED, "Split line[%s] failed.", line.c_str());
return false;
}

// get identifier
int identifier = 0;
bool ret = GetIdentifier(line, identifier);
if (!ret) {
GELOGE(ge::FAILED, "Get identifier of line[%s] failed.", line.c_str());
return false;
}

// get op_name
string name;
GetName(op_param_info[kMessageNameIndex], name);

identifier_op_map[identifier] = std::make_pair(op_param_info[1], name);
op_identifier_map[name] = std::make_pair(identifier, op_param_info[1]);
return true;
}

bool CheckRealPath(const char *file_path) {
string dest_path = ge::parser::RealPath(file_path);
if (dest_path.empty()) {
GELOGW("Path [%s] is not real existed.", file_path);
return false;
}
return true;
}
} // namespace

namespace ge {
ProtoFileParser::~ProtoFileParser() {
if (!fusion_proto_path.empty() && CheckRealPath(fusion_proto_path.c_str())) {
(void)remove(fusion_proto_path.c_str());
}
}

std::string ProtoFileParser::GetFusionProtoFile() {
return fusion_proto_path;
}

Status ProtoFileParser::CreatProtoFile() {
if (fusion_proto_path.empty()) {
fusion_proto_path.assign(kTmpPath);
fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen);
}

int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP);
if (fd < kOpenRetValue) {
GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str());
return FAILED;
}
close(fd);
return SUCCESS;
}

Status ProtoFileParser::ParseProtoFile(const string &proto_file,
std::map<int, std::pair<string, string>> &identifier_op_map,
std::map<std::string, std::pair<int, string>> &op_identifier_map) {
ifstream read_file;
read_file.open(proto_file, std::ios::in);
if (read_file.fail()) {
GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str());
return FAILED;
}

std::string line;
bool save_flag = false;
while (std::getline(read_file, line)) {
if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) {
save_flag = true;
continue;
}

if (save_flag && line.find(kCloseBrace) != std::string::npos) {
save_flag = false;
break;
}

if (save_flag) {
if (line.find(kRepeated) == std::string::npos && line.find(kOptional) == std::string::npos &&
line.find(kRequired) == std::string::npos) {
continue;
}
bool ret = SaveIdentifierOpMapInfo(line, identifier_op_map, op_identifier_map);
if (!ret) {
read_file.close();
return FAILED;
}
}
}
read_file.close();
return SUCCESS;
}

Status ProtoFileParser::AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) {
ifstream read_custom;
read_custom.open(custom_proto_file, std::ios::in);
if (read_custom.fail()) {
GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file);
return FAILED;
}

std::string line_custom;
bool custom_in_layer = false;
while (std::getline(read_custom, line_custom)) {
if (line_custom.find(kMessage) != std::string::npos && line_custom.find(kLayerParameter) != std::string::npos) {
custom_in_layer = true;
continue;
}

if (!custom_in_layer) {
continue;
}

if (line_custom.find(kCloseBrace) != std::string::npos) {
custom_in_layer = false;
break;
}
// exclude remark lines
if (line_custom.find(kRepeated) == std::string::npos && line_custom.find(kOptional) == std::string::npos &&
line_custom.find(kRequired) == std::string::npos) {
continue;
}
// exclude repeated lines
if (custom_repeat_line_map_.count(line_custom) == 0) {
write_tmp << line_custom << '\n';
}
}
read_custom.close();
return SUCCESS;
}

Status ProtoFileParser::AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp) {
ifstream read_custom;
read_custom.open(custom_proto_file, std::ios::in);
if (read_custom.fail()) {
GELOGE(FAILED, "ifsream open custom proto file[%s] failed.", custom_proto_file);
return FAILED;
}

std::string line_custom;
bool custom_in_message = false;
while (std::getline(read_custom, line_custom)) {
if (line_custom.find(kMessage) != std::string::npos) {
std::string message_name = GetMessageName(line_custom);
if (message_name != kLayerParameter && message_name != kNetParameter) {
custom_in_message = true;
write_tmp << line_custom << '\n';
} else {
custom_in_message = false;
}
continue;
}

// exclude repeated messages
if (custom_in_message) {
write_tmp << line_custom << '\n';
}
}
read_custom.close();
return SUCCESS;
}

Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file,
std::ifstream &read_caffe,
std::ofstream &write_tmp) {
std::string line_caffe;
bool caffe_in_layer = false;
bool caffe_in_unrepeated_message = true;
string tmp_message_name;
while (std::getline(read_caffe, line_caffe)) {
if (line_caffe.find(kMessage) != std::string::npos) {
tmp_message_name.assign(GetMessageName(line_caffe));
if (custom_repeat_message_map_.count(tmp_message_name) > 0) {
caffe_in_unrepeated_message = false;
} else {
caffe_in_unrepeated_message = true;
if (tmp_message_name == kLayerParameter) {
caffe_in_layer = true;
}
}
}
if (!caffe_in_unrepeated_message) {
continue;
}
if (caffe_in_layer && line_caffe.find(kCloseBrace) != std::string::npos) {
if (AddCustomAndConflictLayer(custom_proto_file, write_tmp) != SUCCESS) {
GELOGE(FAILED, "Add conflict and new layer line from custom proto to dest proto failed.");
return FAILED;
}
caffe_in_layer = false;
}

// exclude conflict lines
if (caffe_in_layer && caffe_conflict_line_map_.count(line_caffe) > 0) {
GELOGD("pass line: %s", line_caffe.c_str());
continue;
}
write_tmp << line_caffe << '\n';
}
return SUCCESS;
}

Status ProtoFileParser::WriteProtoFile(const char *caffe_proto_file,
const char *custom_proto_file) {
std::ifstream read_caffe;
std::ofstream write_tmp;
read_caffe.open(caffe_proto_file, std::ios::in);
if (read_caffe.fail()) {
GELOGE(FAILED, "ifsream open proto file[%s] failed.", caffe_proto_file);
return FAILED;
}
write_tmp.open(fusion_proto_path, std::ios::out);
if (write_tmp.fail()) {
GELOGE(FAILED, "ofstream open proto file[%s] failed.", fusion_proto_path.c_str());
read_caffe.close();
return FAILED;
}

if (WriteCaffeProtoFile(custom_proto_file, read_caffe, write_tmp) != SUCCESS) {
read_caffe.close();
write_tmp.close();
return FAILED;
}

if (AddCustomAndConflictMessage(custom_proto_file, write_tmp) != SUCCESS) {
GELOGE(FAILED, "Add conflict and new message from custom proto to dest proto failed.");
read_caffe.close();
write_tmp.close();
return FAILED;
}

read_caffe.close();
write_tmp.close();
return SUCCESS;
}

Status ProtoFileParser::FindConflictLine(const char *proto_file, int identifier,
std::string &dest_line) {
ifstream read_file;
read_file.open(proto_file, std::ios::in);
if (read_file.fail()) {
GELOGE(FAILED, "open file[%s] failed.", proto_file);
return FAILED;
}

std::string line;
bool save_flag = false;
while (std::getline(read_file, line)) {
if (line.find(kMessage) != std::string::npos && line.find(kLayerParameter) != std::string::npos) {
save_flag = true;
continue;
}

if (save_flag && line.find(kCloseBrace) != std::string::npos) {
save_flag = false;
break;
}

int tmp_identifier = 0;
if (save_flag && GetIdentifier(line, tmp_identifier) && tmp_identifier == identifier) {
dest_line.assign(line);
read_file.close();
return SUCCESS;
}
}
read_file.close();
GELOGE(FAILED, "find line according to identifier[%d] failed.", identifier);
return FAILED;
}

void ProtoFileParser::CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file,
std::map<std::string, std::pair<int, string>> &caffe_op_identifier_map,
std::map<std::string, std::pair<int, string>> &custom_op_identifier_map) {
for (auto iter = custom_op_identifier_map.begin(); iter != custom_op_identifier_map.end(); ++iter) {
if (caffe_op_identifier_map.count(iter->first) > 0) {
string message_name = iter->first;
auto caffe_pair = caffe_op_identifier_map[iter->first];
auto custom_pair = custom_op_identifier_map[iter->first];
if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) {
// consider conflict op and name and type;
GELOGD("Find conflict op: caffe_identifier[%d], custom_identifier[%d], op_name[%s].",
caffe_pair.first, custom_pair.first, message_name.c_str());
std::string caffe_conflict_line;
(void)FindConflictLine(caffe_proto_file, caffe_pair.first, caffe_conflict_line);
GELOGD("conflict: %s", caffe_conflict_line.c_str());
caffe_conflict_line_map_[caffe_conflict_line]++;
} else {
// consider repeat op and name and type; could be removed
std::string custom_repeat_line;
(void)FindConflictLine(custom_proto_file, caffe_pair.first, custom_repeat_line);
custom_repeat_line_map_[custom_repeat_line]++;
GELOGD("repeat: %s", custom_repeat_line.c_str());
}
}
}
}

void ProtoFileParser::CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file,
std::map<int, std::pair<string, string>> caffe_identifier_op_map,
std::map<int, std::pair<string, string>> custom_identifier_op_map) {
for (auto iter = custom_identifier_op_map.begin(); iter != custom_identifier_op_map.end(); ++iter) {
if (caffe_identifier_op_map.count(iter->first) > 0) {
int identifier = iter->first;
auto caffe_pair = caffe_identifier_op_map[iter->first];
auto custom_pair = custom_identifier_op_map[iter->first];
if (caffe_pair.first != custom_pair.first || caffe_pair.second != custom_pair.second) {
// consider conflict op and name and type;
GELOGD("Find conflict op: caffe_op[%s], custom_op[%s], identifier[%d].",
caffe_pair.first.c_str(), custom_pair.first.c_str(),
identifier);
std::string caffe_conflict_line;
(void)FindConflictLine(caffe_proto_file, identifier, caffe_conflict_line);
GELOGD("conflict: %s", caffe_conflict_line.c_str());
caffe_conflict_line_map_[caffe_conflict_line]++;
} else {
// consider repeat op and name and type;
std::string custom_repeat_line;
(void)FindConflictLine(custom_proto_file, identifier, custom_repeat_line);
custom_repeat_line_map_[custom_repeat_line]++;
GELOGD("repeat: %s", custom_repeat_line.c_str());
}
}
}
}

Status ProtoFileParser::RecordProtoMessage(const string &proto_file) {
ifstream read_file;
read_file.open(proto_file, std::ios::in);
if (read_file.fail()) {
GELOGE(FAILED, "ifsream open proto file[%s] failed.", proto_file.c_str());
return FAILED;
}

std::string line;
while (std::getline(read_file, line)) {
if (line.find(kMessage) != std::string::npos) {
std::string message_name = GetMessageName(line);
if (message_name != kLayerParameter && message_name != kNetParameter) {
custom_repeat_message_map_[message_name]++;
}
}
}
read_file.close();
return SUCCESS;
}

Status ProtoFileParser::CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file,
std::string &dest_proto_file) {
GE_CHECK_NOTNULL(caffe_proto_file);
GE_CHECK_NOTNULL(custom_proto_file);

if (!CheckRealPath(caffe_proto_file) || !CheckRealPath(custom_proto_file)) {
GELOGE(FAILED, "caffe proto[%s] and custom proto[%s] are not all existed.",
caffe_proto_file, custom_proto_file);
return FAILED;
}

GELOGI("Start fusion custom and caffe proto to file.");
std::map<int, std::pair<string, string>> caffe_identifier_op_map;
std::map<int, std::pair<string, string>> custom_identifier_op_map;
std::map<std::string, std::pair<int, string>> caffe_op_identifier_map;
std::map<std::string, std::pair<int, string>> custom_op_identifier_map;

(void)ParseProtoFile(caffe_proto_file, caffe_identifier_op_map, caffe_op_identifier_map);
(void)ParseProtoFile(custom_proto_file, custom_identifier_op_map, custom_op_identifier_map);
(void)RecordProtoMessage(custom_proto_file);

// check identifier or op_type is same
CheckConflictIdentifier(caffe_proto_file, custom_proto_file,
caffe_identifier_op_map, custom_identifier_op_map);
CheckConflictOp(caffe_proto_file, custom_proto_file,
caffe_op_identifier_map, custom_op_identifier_map);

if (CreatProtoFile() != SUCCESS) {
return FAILED;
}

if (WriteProtoFile(caffe_proto_file, custom_proto_file) != SUCCESS) {
GELOGE(FAILED, "Combine caffe proto and custom proto to dest proto file failed.");
return FAILED;
}
dest_proto_file.assign(fusion_proto_path);
GELOGI("Fusion custom and caffe proto to file[%s] success.", dest_proto_file.c_str());
return SUCCESS;
}
} // namespace ge

+ 63
- 0
parser/common/proto_file_parser.h View File

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd

* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at

* http://www.apache.org/licenses/LICENSE-2.0

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PROTO_FILE_PARSE_UTIL_
#define PROTO_FILE_PARSE_UTIL_

#include <map>
#include <string>
#include "common/types.h"
#include "ge/ge_api_types.h"

namespace ge {
class ProtoFileParser {
public:
ProtoFileParser(){};
ProtoFileParser(const char *dest_path){
fusion_proto_path = dest_path;
}
~ProtoFileParser();
Status CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file,
std::string &dest_proto_file);
std::string GetFusionProtoFile();
private:
Status CreatProtoFile();
Status ParseProtoFile(const std::string &proto_file,
std::map<int, std::pair<std::string, std::string> > &identifier_op_map,
std::map<std::string, std::pair<int, std::string> > &op_identifier_map);
Status WriteCaffeProtoFile(const char *custom_proto_file,
std::ifstream &read_caffe,
std::ofstream &write_tmp);
Status WriteProtoFile(const char *caffe_proto_file, const char *custom_proto_file);
Status FindConflictLine(const char *proto_file, int identifier,
std::string &dest_line);
Status AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp);
Status AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp);
void CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file,
std::map<std::string, std::pair<int, std::string>> &caffe_op_identifier_map,
std::map<std::string, std::pair<int, std::string>> &custom_op_identifier_map);
void CheckConflictIdentifier(const char *caffe_proto_file, const char *custom_proto_file,
std::map<int, std::pair<std::string, std::string>> caffe_identifier_op_map,
std::map<int, std::pair<std::string, std::string>> custom_identifier_op_map);
Status RecordProtoMessage(const std::string &proto_file);
std::map<std::string, int> caffe_conflict_line_map_;
std::map<std::string, int> custom_repeat_line_map_;
std::map<std::string, int> custom_repeat_message_map_;
std::string fusion_proto_path;
};
} // namespace ge

#endif // PROTO_FILE_PARSE_UTIL_

+ 132
- 0
parser/common/register_tbe.cc View File

@@ -0,0 +1,132 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "parser/common/register_tbe.h"
#include <map>
#include <memory>
#include <string>
#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"
#include "common/op_map.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_custom_parser_adapter.h"
#include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"

namespace ge {
using PARSER_CREATOR_FN = std::function<std::shared_ptr<OpParser>(void)>;

FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() {
static OpRegistrationTbe instance;
return &instance;
}

bool OpRegistrationTbe::Finalize(const OpRegistrationData &reg_data, bool is_train) {
static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}};
if (is_train) {
op_map[domi::TENSORFLOW] = &tensorflow_train_op_map;
} else {
op_map[domi::TENSORFLOW] = &tensorflow_op_map;
}

if (op_map.find(reg_data.GetFrameworkType()) != op_map.end()) {
std::map<std::string, std::string> *fmk_op_map = op_map[reg_data.GetFrameworkType()];
auto ori_optype_set = reg_data.GetOriginOpTypeSet();
for (auto &tmp : ori_optype_set) {
if ((*fmk_op_map).find(tmp) != (*fmk_op_map).end()) {
GELOGW("Op type does not need to be changed, om_optype:%s, orignal type:%s.", (*fmk_op_map)[tmp].c_str(),
tmp.c_str());
continue;
} else {
(*fmk_op_map)[tmp] = reg_data.GetOmOptype();
GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(),
reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str());
}
}
}

bool ret = RegisterParser(reg_data);
return ret;
}

bool OpRegistrationTbe::RegisterParser(const OpRegistrationData &reg_data) {
if (reg_data.GetFrameworkType() == domi::TENSORFLOW) {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
if (factory == nullptr) {
GELOGE(INTERNAL_ERROR, "Get op parser factory for tf failed.");
return false;
}
if (reg_data.GetParseParamFn() != nullptr || reg_data.GetParseParamByOperatorFn() != nullptr) {
bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype());
if (is_registed) {
GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str());
return false;
}
std::shared_ptr<TensorFlowCustomParserAdapter> tf_parser_adapter =
ge::MakeShared<TensorFlowCustomParserAdapter>();
if (tf_parser_adapter == nullptr) {
GELOGE(PARAM_INVALID, "Create tf parser adapter failed.");
return false;
}
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
domi::TENSORFLOW, reg_data.GetOmOptype(), [=]() -> std::shared_ptr<OpParser> { return tf_parser_adapter; });
}
if (reg_data.GetFusionParseParamFn() != nullptr || reg_data.GetFusionParseParamByOpFn() != nullptr) {
bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype(), true);
if (is_registed) {
GELOGW("Parse param func has already register for fusion op:%s.", reg_data.GetOmOptype().c_str());
return false;
}
GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str());
std::shared_ptr<TensorFlowFusionCustomParserAdapter> tf_fusion_parser_adapter =
ge::MakeShared<TensorFlowFusionCustomParserAdapter>();
if (tf_fusion_parser_adapter == nullptr) {
GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed.");
return false;
}
OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar(
domi::TENSORFLOW, reg_data.GetOmOptype(),
[=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true);
}
} else {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType());
if (factory == nullptr) {
GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.",
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
return false;
}
bool is_registed = factory->OpParserIsRegistered(reg_data.GetOmOptype());
if (is_registed) {
GELOGW("Parse param func has already register for op:%s.", reg_data.GetOmOptype().c_str());
return false;
}

PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType());
if (func == nullptr) {
GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.",
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
return false;
}
OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func);
GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(),
TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str());
}
return true;
}
} // namespace ge

+ 34
- 0
parser/common/register_tbe.h View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_REGISTER_TBE_H_
#define PARSER_COMMON_REGISTER_TBE_H_

#include "register/op_registry.h"

namespace ge {
class OpRegistrationTbe {
public:
static OpRegistrationTbe *Instance();

bool Finalize(const OpRegistrationData &reg_data, bool is_train = false);

private:
bool RegisterParser(const OpRegistrationData &reg_data);
};
} // namespace ge

#endif // PARSER_COMMON_REGISTER_TBE_H_

+ 212
- 0
parser/common/tbe_plugin_loader.cc View File

@@ -0,0 +1,212 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tbe_plugin_loader.h"

#include <dirent.h>
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <string>

#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/string_util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/utils/type_utils.h"
#include "parser/common/acl_graph_parser_util.h"

namespace ge {
std::map<string, string> TBEPluginLoader::options_ = {};

namespace {
const std::string FRAMEWORK_TYPE = "ge.frameworkType";
}

// Get Singleton Instance
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() {
static TBEPluginLoader instance_ptr_;
return instance_ptr_;
}

Status TBEPluginLoader::ClearHandles_() {
Status ret = SUCCESS;
for (const auto &handle : handles_vec_) {
if (dlclose(handle) != 0) {
ret = FAILED;
GELOGW("Failed to close handle: %s", dlerror());
}
}
handles_vec_.clear();
return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() {
Status ret = ClearHandles_();
return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo(
const std::map<string, string> &options) {
vector<string> file_list;
string caffe_parser_path;
std::string plugin_path;

options_ = options;
GetCustomOpPath(plugin_path);

// Whether there are files in the plugin so path
GetPluginSoFileList(plugin_path, file_list, caffe_parser_path);

// No file
if (file_list.empty()) {
// Print log
GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str());
}

GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted.");

// Load other so files except lib_caffe_parser.so in the plugin so path
for (auto elem : file_list) {
StringUtils::Trim(elem);

void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE);
if (handle == nullptr) {
GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror());
} else if (find(handles_vec_.begin(), handles_vec_.end(), handle) == handles_vec_.end()) {
// Close dl when the program exist, not close here
GELOGI("Plugin load %s success.", elem.c_str());
handles_vec_.push_back(handle);
} else {
GELOGI("Plugin so has already been loaded, no need to load again.");
}
}
}

void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) {
GELOGI("Enter get custom op path schedule");
std::string fmk_type;
domi::FrameworkType type = domi::TENSORFLOW;
auto it = options_.find(FRAMEWORK_TYPE);
if (it != options_.end()) {
type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10));
}
fmk_type = ge::TypeUtils::FmkTypeToSerialString(type);
GELOGI("Framework type is %s.", fmk_type.c_str());

const char *path_env = std::getenv("ASCEND_OPP_PATH");
if (path_env != nullptr) {
std::string path = path_env;
customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type);
GELOGI("Get custom so path from env : %s", path_env);
return;
}
std::string path_base = GetPath();
GELOGI("path_base is %s", path_base.c_str());
path_base = path_base.substr(0, path_base.rfind('/'));
path_base = path_base.substr(0, path_base.rfind('/') + 1);
customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type);
}

string TBEPluginLoader::GetPath() {
Dl_info dl_info;
if (dladdr(reinterpret_cast<void *>(&TBEPluginLoader::GetPath), &dl_info) == 0) {
GELOGW("Failed to read so path!");
return string();
} else {
string so_path = dl_info.dli_fname;
char path[PATH_MAX] = {0};
if (so_path.length() >= PATH_MAX) {
GELOGW("File path is too long!");
return string();
}
if (realpath(so_path.c_str(), path) == nullptr) {
GELOGW("Failed to get realpath of %s", so_path.c_str());
return string();
}

so_path = path;
so_path = so_path.substr(0, so_path.rfind('/') + 1);
return so_path;
}
}

void TBEPluginLoader::GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path) {
// Support to split multiple so directories by ":"
vector<string> v_path = StringUtils::Split(path, ':');
for (size_t i = 0; i < v_path.size(); ++i) {
FindParserSo(v_path[i], file_list, caffe_parser_path);
GELOGI("CustomOpLib full name = %s", v_path[i].c_str());
}
}

void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) {
// Path, change to absolute path
string real_path = ge::parser::RealPath(path.c_str());
// Plugin path does not exist
if (real_path.empty()) {
GELOGW("RealPath is empty.");
return;
}
struct stat stat_buf;
if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) {
GELOGW("%s is not a dir.", real_path.c_str());
return;
}
struct dirent *dent(0);
DIR *dir = opendir(real_path.c_str());
// Plugin path does not exist
if (dir == nullptr) {
GELOGW("Open directory %s failed.", real_path.c_str());
return;
}

while ((dent = readdir(dir)) != nullptr) {
if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue;
string name = dent->d_name;
string full_name = real_path + "/" + name;
const string so_suff = ".so";
const string caffe_parser_so_suff = "lib_caffe_parser.so";
const string aicpu_so_suff = "_aicpu.so";
const string aicpu_host_so_suff = "_online.so";
if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) {
ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff,
aicpu_host_so_suff);
} else {
FindParserSo(full_name, file_list, caffe_parser_path);
}
}
closedir(dir);
}

void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
const string &caffe_parser_so_suff, const string &aicpu_so_suff,
const string &aicpu_host_so_suff) {
if (full_name.size() >= caffe_parser_so_suff.size() &&
full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(),
caffe_parser_so_suff) == 0) {
caffe_parser_path = full_name;
} else {
// Save parser so path into file_list vector
file_list.push_back(full_name);
}
}
} // namespace ge

+ 62
- 0
parser/common/tbe_plugin_loader.h View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_
#define PARSER_COMMON_TBE_PLUGIN_LOADER_H_

#include <dlfcn.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>

#include "external/ge/ge_api_error_codes.h"
#include "external/register/register.h"

namespace ge {
using SoHandlesVec = std::vector<void *>;
class TBEPluginLoader {
public:
Status Finalize();

// Get TBEPluginManager singleton instance
static TBEPluginLoader& Instance();

void LoadPluginSo(const std::map<string, string> &options);

static string GetPath();

private:
TBEPluginLoader() = default;
~TBEPluginLoader() = default;
Status ClearHandles_();
static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name,
const string &caffe_parser_so_suff, const string &aicpu_so_suff,
const string &aicpu_host_so_suff);
static void GetCustomOpPath(std::string &customop_path);
static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path);

SoHandlesVec handles_vec_;
static std::map<string, string> options_;
};
} // namespace ge

#endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_

+ 78
- 0
parser/common/thread_pool.cc View File

@@ -0,0 +1,78 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/thread_pool.h"

#include <atomic>
#include <functional>
#include <queue>
#include <stdexcept>
#include <utility>
#include <vector>

#include "register/register_types.h"

namespace ge {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) {
idle_thrd_num_ = size < 1 ? 1 : size;

for (uint32_t i = 0; i < idle_thrd_num_; ++i) {
pool_.emplace_back(ThreadFunc, this);
}
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() {
is_stoped_.store(true);
{
std::unique_lock<std::mutex> lock{m_lock_};
cond_var_.notify_all();
}

for (std::thread &thd : pool_) {
if (thd.joinable()) {
try {
thd.join();
} catch (const std::system_error &) {
GELOGW("system_error");
} catch (...) {
GELOGW("exception");
}
}
}
}

void ThreadPool::ThreadFunc(ThreadPool *thread_pool) {
if (thread_pool == nullptr) {
return;
}
while (!thread_pool->is_stoped_) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock{thread_pool->m_lock_};
thread_pool->cond_var_.wait(
lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); });
if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) {
return;
}
task = std::move(thread_pool->tasks_.front());
thread_pool->tasks_.pop();
}
--thread_pool->idle_thrd_num_;
task();
++thread_pool->idle_thrd_num_;
}
}
} // namespace ge

+ 83
- 0
parser/common/thread_pool.h View File

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_COMMON_THREAD_POOL_H_
#define PARSER_COMMON_THREAD_POOL_H_

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "external/ge/ge_api_error_codes.h"
#include "graph/types.h"
#include "common/ge/ge_util.h"

namespace ge {
using ThreadTask = std::function<void()>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool {
public:
explicit ThreadPool(uint32_t size = 4);
~ThreadPool();

template <class Func, class... Args>
auto commit(Func &&func, Args &&... args) -> std::future<decltype(func(args...))> {
GELOGD("commit run task enter.");
using retType = decltype(func(args...));
std::future<retType> fail_future;
if (is_stoped_.load()) {
GELOGE(ge::FAILED, "thread pool has been stopped.");
return fail_future;
}

auto bindFunc = std::bind(std::forward<Func>(func), std::forward<Args>(args)...);
auto task = ge::MakeShared<std::packaged_task<retType()>>(bindFunc);
if (task == nullptr) {
GELOGE(ge::FAILED, "Make shared failed.");
return fail_future;
}
std::future<retType> future = task->get_future();
{
std::lock_guard<std::mutex> lock{m_lock_};
tasks_.emplace([task]() { (*task)(); });
}
cond_var_.notify_one();
GELOGD("commit run task end");
return future;
}

static void ThreadFunc(ThreadPool *thread_pool);

private:
std::vector<std::thread> pool_;
std::queue<ThreadTask> tasks_;
std::mutex m_lock_;
std::condition_variable cond_var_;
std::atomic<bool> is_stoped_;
std::atomic<uint32_t> idle_thrd_num_;
};
} // namespace ge

#endif // PARSER_COMMON_THREAD_POOL_H_

+ 307
- 0
parser/common/tuple.h View File

@@ -0,0 +1,307 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_COMMON_TUPLE_H_
#define GE_COMMON_TUPLE_H_

#include <algorithm>
#include <iostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "framework/common/debug/log.h"

namespace ge {
template <typename ValueType>
class Tuple {
public:
Tuple() = default;
inline ~Tuple() {
delete[] data_heap_;
data_heap_ = nullptr;
}
///
/// @brief copy constructor from another tuple
/// @param s the source tuple
///
inline Tuple(const Tuple<ValueType> &s) { this->assign(s.begin(), s.end()); }
///
/// @brief constructor from initializer list
/// @param init the initializer_list
///
inline Tuple(const std::initializer_list<ValueType> &init) { this->assign(init.begin(), init.end()); }
///
/// @brief constructor from vector
/// @param init the vector
///
inline Tuple(const std::vector<ValueType> &init) { // NOLINT(runtime/explicit)
this->assign(init.begin(), init.end());
}
///
/// @brief move constructor from Tuple
/// @param src the source shape
///
inline Tuple(Tuple<ValueType> &&src) { // NOLINT(runtime/explicit)
this->swap(src);
}
///
/// @brief construct the Tuple from content of iterator
/// @param begin the beginning of iterator
/// @param end end the end of the iterator
/// @tparam RandomAccessIterator iterator type
///
template <typename RandomAccessIterator>
inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) {
this->assign(begin, end);
}
///
/// @brief Assign content to tuple from iterator.
/// @param begin the beginning of iterator
/// @param end end the end of the iterator
/// @tparam RandomAccessIterator iterator type
///
template <typename RandomAccessIterator>
inline void assign(const RandomAccessIterator &begin, const RandomAccessIterator &end) {
this->SetDim(end - begin);
(void)std::copy(begin, end, this->begin());
}
///
/// @brief Swap current object with other
/// @param other another object to be swapped.
///
inline void swap(Tuple<ValueType> &other) { // NOLINT(*)
std::swap(ndim_, other.ndim_);
std::swap(num_heap_allocated_, other.num_heap_allocated_);
std::swap(data_stack_, other.data_stack_);
std::swap(data_heap_, other.data_heap_);
}
///
/// @brief assignment from another tuple.
/// @param src source tuple
/// @return reference of self
///
inline Tuple<ValueType> &operator=(const Tuple<ValueType> &src) {
if (&src != this) {
this->assign(src.begin(), src.end());
}
return *this;
}
///
/// @brief assignment from rvalue of another tuple.
/// @param src source tuple
/// @return reference of self
///
inline Tuple<ValueType> &operator=(Tuple<ValueType> &&src) {
if (&src != this) {
Tuple<ValueType>(std::move(src)).swap(*this);
}
return *this;
}
///
/// @brief assignment from initializer list
/// @param init the source initializer list
/// @return reference of self
///
inline Tuple<ValueType> &operator=(std::initializer_list<ValueType> init) {
this->assign(init.begin(), init.end());
return *this;
}
///
/// @return whether two tuple equals
/// @param s the tuple to compare against
///
inline bool operator==(const Tuple<ValueType> &s) const {
if (ndim_ != s.ndim_) return false;
return std::equal(begin(), end(), s.begin());
}
///
/// @return whether two tuple not equal
/// @param s the tuple to compare against
///
inline bool operator!=(const Tuple<ValueType> &s) const { return !(*this == s); }
///
/// @return the begin data pointer to content of the tuple
///
inline const ValueType *begin() const { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; }
///
/// @return the begin data pointer to content of the tuple
///
inline ValueType *begin() { return ndim_ <= STACK_CACHE_NUM ? data_stack_ : data_heap_; }
///
/// @return the data pointer to end of the tuple
///
inline const ValueType *end() const {
return ndim_ <= STACK_CACHE_NUM ? (data_stack_ + ndim_) : (data_heap_ + ndim_);
}
///
/// @return the data pointer to end the tuple
///
inline ValueType *end() { return ndim_ <= STACK_CACHE_NUM ? (data_stack_ + ndim_) : (data_heap_ + ndim_); }
///
/// @return number of dimension of the tuple
///
inline uint32_t ndim() const { return ndim_; }
///
/// @brief get corresponding index
/// @param i dimension index
/// @return the corresponding dimension size
///
inline ValueType &operator[](size_t i) { return begin()[i]; }
///
/// @brief get corresponding index
/// @param i dimension index
/// @return the corresponding dimension size
///
inline const ValueType &operator[](size_t i) const { return begin()[i]; }
///
/// @brief allow output string of tuple to ostream
/// @param os the output stream
/// @param t the tuple
/// @return the ostream
///
friend std::ostream &operator<<(std::ostream &os, const Tuple<ValueType> &t) {
os << '[';
const ValueType *begin = t.begin();
const ValueType *end = t.end();
for (const ValueType *it = begin; it != end; ++it) {
if (it != begin) os << ',';
os << *it;
}
os << ']';
return os;
}
///
/// @brief read tuple from the istream
/// @param is the input stream
/// @param t The tuple
/// @return the istream
///
friend std::istream &operator>>(std::istream &is, Tuple<ValueType> &t) {
// get (
if (!HandleLeftBracket(is, t)) {
return is;
}

// Handle empty tuple
while (isspace(is.peek())) {
(void)is.get();
}
if (IsRightBracket(is.peek())) {
(void)is.get();
return is;
}
// Handle non-empty tuple
ValueType idx;
std::vector<ValueType> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = static_cast<char>(is.get());
} while (isspace(ch));
if (std::is_integral<ValueType>::value && ch == 'L') {
ch = static_cast<char>(is.get());
}
if (ch == ',') {
while (true) {
ch = static_cast<char>(is.peek());
if (isspace(ch)) {
(void)is.get();
continue;
}
if (IsRightBracket(ch)) {
(void)is.get();
break;
}
break;
}
if (IsRightBracket(ch)) break;
} else if (IsRightBracket(ch)) {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
t.assign(tmp.begin(), tmp.end());
return is;
}

// stack cache size
static const uint32_t STACK_CACHE_NUM = 4;
// in stack space used to store shape when it is small
ValueType data_stack_[STACK_CACHE_NUM];
// space to store shape when dimension is big
ValueType *data_heap_{nullptr};
uint32_t ndim_{0};

protected:
// number of cells allocated in data_heap_
uint32_t num_heap_allocated_{0};

// internal function to change the dimension
inline void SetDim(uint32_t ndim) {
if (ndim > STACK_CACHE_NUM && ndim > num_heap_allocated_) {
if (data_heap_ != nullptr) {
delete[] data_heap_;
data_heap_ = nullptr;
}
data_heap_ = new (std::nothrow) ValueType[ndim]();
if (data_heap_ == nullptr) {
GELOGW("data_heap_ is nullptr.");
}
num_heap_allocated_ = ndim;
}
ndim_ = ndim;
}
static inline bool IsLeftBracket(char ch) { return ch == '(' || ch == '['; }

static inline bool IsRightBracket(char ch) { return ch == ')' || ch == ']'; }

friend bool HandleLeftBracket(std::istream &is, Tuple<ValueType> &t) {
while (true) {
char ch = is.peek();
if (isdigit(ch) || (ch == '-')) {
ValueType idx;
if (is >> idx) {
t.assign(&idx, &idx + 1);
}
return false;
}
(void)is.get();
if (IsLeftBracket(ch)) {
break;
}

if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return false;
}
}

return true;
}
};

using UintTuple = Tuple<uint32_t>;
using IntTuple = Tuple<int64_t>;
using FloatTuple = Tuple<float>;
using BoolTuple = Tuple<bool>;
using StringTuple = Tuple<std::string>;
} // namespace ge

#endif // GE_COMMON_TUPLE_H_

+ 53
- 0
parser/common/types_map.h View File

@@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_TYPES_MAP_H
#define GE_TYPES_MAP_H

#include "external/graph/types.h"
#include "proto/tensorflow/graph.pb.h"

namespace ge {
// Correspondence between data_type in GE and tensorflow
static map<int32_t, int32_t> GE_TENSORFLOW_DATA_TYPE_MAP = {
{ge::DataType::DT_UNDEFINED, domi::tensorflow::DT_INVALID},
{ge::DataType::DT_FLOAT, domi::tensorflow::DT_FLOAT},
{ge::DataType::DT_FLOAT16, domi::tensorflow::DT_HALF},
{ge::DataType::DT_INT8, domi::tensorflow::DT_INT8},
{ge::DataType::DT_INT16, domi::tensorflow::DT_INT16},
{ge::DataType::DT_UINT16, domi::tensorflow::DT_UINT16},
{ge::DataType::DT_UINT8, domi::tensorflow::DT_UINT8},
{ge::DataType::DT_INT32, domi::tensorflow::DT_INT32},
{ge::DataType::DT_INT64, domi::tensorflow::DT_INT64},
{ge::DataType::DT_UINT32, domi::tensorflow::DT_UINT32},
{ge::DataType::DT_UINT64, domi::tensorflow::DT_UINT64},
{ge::DataType::DT_STRING, domi::tensorflow::DT_STRING},
{ge::DataType::DT_RESOURCE, domi::tensorflow::DT_RESOURCE},
{ge::DataType::DT_BOOL, domi::tensorflow::DT_BOOL},
{ge::DataType::DT_DOUBLE, domi::tensorflow::DT_DOUBLE},
{ge::DataType::DT_COMPLEX64, domi::tensorflow::DT_COMPLEX64},
{ge::DataType::DT_COMPLEX128, domi::tensorflow::DT_COMPLEX128},
{ge::DataType::DT_QINT8, domi::tensorflow::DT_QINT8},
{ge::DataType::DT_QINT16, domi::tensorflow::DT_QINT16},
{ge::DataType::DT_QINT32, domi::tensorflow::DT_QINT32},
{ge::DataType::DT_QUINT8, domi::tensorflow::DT_QUINT8},
{ge::DataType::DT_QUINT16, domi::tensorflow::DT_QUINT16},
{ge::DataType::DT_DUAL, domi::tensorflow::DT_INVALID},
{ge::DataType::DT_DUAL_SUB_INT8, domi::tensorflow::DT_INVALID},
{ge::DataType::DT_DUAL_SUB_UINT8, domi::tensorflow::DT_INVALID},
};
} // namespace ge
#endif // GE_TYPES_MAP_H

+ 32
- 0
parser/func_to_graph/CMakeLists.txt View File

@@ -0,0 +1,32 @@
set(PROTO_LIST
"${TOP_DIR}/inc/register/proto/tensorflow/graph.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/node_def.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/tensor_shape.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/attr_value.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/function.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/op_def.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/resource_handle.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/tensor.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/types.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/versions.proto"
"${TOP_DIR}/inc/register/proto/tensorflow/graph_library.proto"
)

protobuf_generate_py(ge PROTO_SRCS ${PROTO_LIST})

include_directories(${CMAKE_CURRENT_LIST_DIR})

############ func2graph/util ############
add_custom_target(util ALL
DEPENDS ${PROTO_SRCS}
COMMAND mkdir -p ${CMAKE_CURRENT_BINARY_DIR}/util
&& cp -r ${PROTO_SRCS} ${CMAKE_CURRENT_BINARY_DIR}/util
)

set(INSTALL_BASE_DIR "")
set(INSTALL_LIBRARY_DIR lib)

install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/util OPTIONAL
DESTINATION ${INSTALL_LIBRARY_DIR}/func2graph
)


+ 279
- 0
parser/func_to_graph/func2graph.py View File

@@ -0,0 +1,279 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import sys
import getopt

from google.protobuf import text_format
import tensorflow as tf
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework.errors_impl import NotFoundError
from tensorflow.python.platform import gfile

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions

sys.path.append(os.path.join(os.path.split(os.path.realpath(__file__))[0], "util"))

import graph_library_pb2


def _get_num_args(arg_def, node_def):
if arg_def.number_attr:
return node_def.attr[arg_def.number_attr].i
elif arg_def.type_list_attr:
return len(node_def.attr[arg_def.type_list_attr].list.type)
elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
return 1
else:
raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))


def is_function(fname):
"""Checks for a function definition with `fname` in the current context."""
if context.executing_eagerly():
return context.context().has_function(fname)
else:
return ops.get_default_graph()._is_function(fname)

def create_arg_for_input_nodes(fdef, graph_def, input_shapes):
for i, arg_def in enumerate(fdef.signature.input_arg):
node_def = graph_def.node.add()
node_def.name = arg_def.name
node_def.op = "_Arg"
node_def.attr["T"].type = arg_def.type
node_def.attr["index"].i = i
if input_shapes and input_shapes[i] is not None:
input_shape = input_shapes[i]
if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
input_shape = input_shape.as_proto()
node_def.attr["shape"].shape.CopyFrom(input_shape)
arg_attrs = fdef.arg_attr[i].attr
for k in arg_attrs:
# Only copy internal attributes. Normal attributes for nodes cannot be
# applied to these Arg nodes.
if k.startswith("_"):
node_def.attr[k].CopyFrom(arg_attrs[k])
return

def create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name):
for i, arg_def in enumerate(fdef.signature.output_arg):
node_def = graph_def.node.add()
node_def.name = '{}_Retval'.format(arg_def.name)
node_def.op = "_Retval"
node_def.attr["T"].type = arg_def.type
node_def.attr["index"].i = i
node_def.attr["op_def"].s = ops.get_default_graph()._get_op_def(node_def.op).SerializeToString()

ret_name = fdef.ret[arg_def.name]
node_def.input.append(nested_to_flat_tensor_name[ret_name])
return

def updat_input_index(node_def, op_def, nested_to_flat_tensor_name):
flattened_index = 0
for arg_def in op_def.output_arg:
num_args = _get_num_args(arg_def, node_def)
for i in range(num_args):
# Map tensor names from "node_name:output_arg_name:index" to
# "node_name:flattened_index".
nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
if flattened_index == 0:
flat_name = node_def.name
else:
flat_name = "{}:{}".format(node_def.name, flattened_index)
nested_to_flat_tensor_name[nested_name] = flat_name
flattened_index += 1
control_name = "^" + node_def.name
nested_to_flat_tensor_name[control_name] = control_name
return

def build_tensor_name(fdef, default_graph):
nested_to_flat_tensor_name = {}
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[arg_def.name] = arg_def.name
control_name = '^{}'.format(arg_def.name)
nested_to_flat_tensor_name[control_name] = control_name

global op_def
for node_def in fdef.node_def:
f = default_graph._functions.get(node_def.op, None)
if f is not None and hasattr(f, "signature"):
op_def = f.signature
if node_def.op not in copied_functions:
# Since this function is referenced as an op type, we have no choice but
# to copy it into the GraphDef if we want downstream tools to process
# it.
graph_def.library.function.add().CopyFrom(f.definition)
copied_functions.add(node_def.op)
else:
op_def = ops.get_default_graph()._get_op_def(node_def.op)

for attr in op_def.attr:
if attr.type == "func":
fname = node_def.attr[attr.name].func.name
if not is_function(fname):
raise ValueError("%s function not found." % fname)
elif attr.type == "list(func)":
for fn in node_def.attr[attr.name].list.func:
fname = fn.name
if not is_function(fname):
raise ValueError("%s function not found." % fname)

# Iterate over output_args in op_def to build the map.
# Index of the output tensor in the flattened list of *all* output
# tensors of the op.
updat_input_index(node_def, op_def, nested_to_flat_tensor_name)
return nested_to_flat_tensor_name

def convert_function_def_to_graph_def(fdef, input_shapes=None, copy_functions=True):
graph_def = graph_pb2.GraphDef()
graph_def.versions.CopyFrom(
versions_pb2.VersionDef(
producer=versions.GRAPH_DEF_VERSION,
min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

default_graph = ops.get_default_graph()

copied_functions = set()

# Copy *all* functions from outer graph to `graph_def` so that both direct
# and indirect references are safely handled.
if copy_functions:
default_graph._copy_functions_to_graph_def(graph_def, 0)
for function_name in default_graph._functions.keys():
copied_functions.add(function_name)

if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
raise ValueError("Length of input_shapes must match the number of " +
"input_args. len(input_shapes): {} len(input_arg): {}".
format(len(input_shapes), len(fdef.signature.input_arg)))

# 1. Create _Arg for input nodes.
create_arg_for_input_nodes(fdef, graph_def, input_shapes)

# 2. Copy all body NodeDefs to the GraphDef.
graph_def.node.extend(fdef.node_def)

# 3. Perform the renaming.

# Build the tensor name mapping then flatten the tensor names.
# See comment on `FunctionDef.node_def` on how the tensor naming in
# FunctionDefs is different from GraphDefs.
nested_to_flat_tensor_name = build_tensor_name(fdef, default_graph)

# Update inputs of all nodes in graph.
for node_def in graph_def.node:
for i in range(len(node_def.input)):
node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

# Create _Retval for output nodes.
create_retval_for_output_nodes(fdef, graph_def, nested_to_flat_tensor_name)

return graph_def, nested_to_flat_tensor_name


def convert_graphs(filename):
try:
with tf.io.gfile.GFile(filename, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
if len(graph_def.library.function) == 0:
print("INFO: The input model does not contain a functionDef and does not require conversion.")
return
try:
convert_subgraphs(graph_def, filename)
except Exception as e:
print("ERROR: Convert subgraphs failed.", e)
return
print("INFO: Convert to subgraphs successfully.")
except NotFoundError:
print('ERROR: model file {} does not exist'.format(filename))
return


def convert_subgraphs(graph_def, filename):
graph_def_library = graph_library_pb2.GraphDefLibrary()
for i, fdef in enumerate(graph_def.library.function):
sub_graph, nested_to_flat_tensor_name = convert_function_def_to_graph_def(fdef, copy_functions=False)
print("INFO: Convert FunctionDef, index:{}, name:{}".format(str(i), fdef.signature.name))
sub_graph_name = '{}.pb'.format(fdef.signature.name)
result_path = '{}/results'.format(os.path.dirname(os.path.abspath(filename)))
tf.io.write_graph(sub_graph, result_path, sub_graph_name, as_text=False)
data = sub_graph.SerializeToString()
ge_graph_def = graph_library_pb2.GeGraphDef()
ge_graph_def.name = fdef.signature.name
ge_graph_def.graph.ParseFromString(data)
graph_def_library.graph_def.append(ge_graph_def)
print(graph_def_library.graph_def[i])

# Write to prototxt
try:
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
print("graph_def_file: ", graph_def_file)
with open(graph_def_file, "w") as f:
print(graph_def_library, file=f)
except IOError:
print("Could not open file. Creating a new one.")


def usage():
print(
'''
Based on tensorflow 1.15 or later, Python 3

Convert the tensorflow functionDefs in the input model file to single GraphDefs,
and save the result to the "results" directory and graph_def_library.pbtxt in
the input file directory.
The name of the sub graph is same as the name of the corresponding functionDef.

Usage: func2grpah.py <command>

Available commands:
model (-m) Input model file.
version (-v) Prints the version of this software.
help (-h) Prints help for commands.
'''
)


if __name__ == '__main__':
model = ''
try:
opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model='])
for opt_name, opt_value in opts:
if opt_name in ('-m', '--model'):
model = opt_value
print("INFO: Input model file is", model)
convert_graphs(model)
elif opt_name in ('-h', '--help'):
usage()
break
elif opt_name in ('-v', '--version'):
print("version 1.0.0")
break
except getopt.GetoptError:
print("ERROR: Input parameters is invalid, use '--help' to view the help.")
if (len(sys.argv) == 1):
print("INFO: Please specify the input parameters, and use '--help' to view the help.")

+ 9
- 0
parser/func_to_graph/module.mk View File

@@ -0,0 +1,9 @@
LOCAL_PATH := $(call my-dir)

include $(CLEAR_VARS)

LOCAL_MODULE := func2graph/util

LOCAL_MODULE_CLASS := FOLDER

include $(LOCAL_PATH)/proto_python_rule.mk

+ 62
- 0
parser/func_to_graph/proto/attr_value.proto View File

@@ -0,0 +1,62 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "AttrValueProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "tensor.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing the value for an attr used to configure an Op.
// Comment indicates the corresponding attr type. Only the field matching the
// attr type may be filled.
message AttrValue {
// LINT.IfChange
message ListValue {
repeated bytes s = 2; // "list(string)"
repeated int64 i = 3 [packed = true]; // "list(int)"
repeated float f = 4 [packed = true]; // "list(float)"
repeated bool b = 5 [packed = true]; // "list(bool)"
repeated DataType type = 6 [packed = true]; // "list(type)"
repeated TensorShapeProto shape = 7; // "list(shape)"
repeated TensorProto tensor = 8; // "list(tensor)"
repeated NameAttrList func = 9; // "list(attr)"
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)

oneof value {
bytes s = 2; // "string"
int64 i = 3; // "int"
float f = 4; // "float"
bool b = 5; // "bool"
DataType type = 6; // "type"
TensorShapeProto shape = 7; // "shape"
TensorProto tensor = 8; // "tensor"
ListValue list = 1; // any "list(...)"

// "func" represents a function. func.name is a function's name or
// a primitive op's name. func.attr.first is the name of an attr
// defined for that function. func.attr.second is the value for
// that attr in the instantiation.
NameAttrList func = 10;

// This is a placeholder only used in nodes defined inside a
// function. It indicates the attr value will be supplied when
// the function is instantiated. For example, let us suppose a
// node "N" in function "FN". "N" has an attr "A" with value
// placeholder = "foo". When FN is instantiated with attr "foo"
// set to "bar", the instantiated node N's attr A will have been
// given the value "bar".
string placeholder = 9;
}
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
message NameAttrList {
string name = 1;
map<string, AttrValue> attr = 2;
}

+ 100
- 0
parser/func_to_graph/proto/function.proto View File

@@ -0,0 +1,100 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "FunctionProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "node_def.proto";
import "op_def.proto";

// A library is a set of named functions.
message FunctionDefLibrary {
repeated FunctionDef function = 1;
repeated GradientDef gradient = 2;
}

// A function can be instantiated when the runtime can bind every attr
// with a value. When a GraphDef has a call to a function, it must
// have binding for every attr defined in the signature.
// * device spec, etc.
message FunctionDef {
// The definition of the function's name, arguments, return values,
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21.
reserved 2;

// In both of the following fields, there is the need to specify an
// output that is used as either the input to another node (in
// `node_def`) or as a return value of the function (in `ret`).
// Unlike the NodeDefs in GraphDef, we need to be able to specify a
// list in some cases (instead of just single outputs). Also, we
// need to be able to deal with lists of unknown length (so the
// output index may not be known at function definition time). So
// we use the following format instead:
// * "fun_in" where "fun_in" is the name of a function input arg in
// the `signature` field above. This represents that input, whether
// it is a single tensor or a list.
// * "fun_in:0" gives the first element of a function input arg (a
// non-list input is considered a list of length 1 for these
// purposes).
// * "node:out" where "node" is the name of a node in `node_def` and
// "out" is the name one of its op's output arguments (the name
// comes from the OpDef of the node's op). This represents that
// node's output, whether it is a single tensor or a list.
// Note: We enforce that an op's output arguments are never
// renamed in the backwards-compatibility test.
// * "node:out:0" gives the first element of a node output arg (a
// non-list output is considered a list of length 1 for these
// purposes).
//
// NOT CURRENTLY SUPPORTED (but may be in the future):
// * "node:out:-1" gives last element in a node output list
// * "node:out:1:" gives a list with all but the first element in a
// node output list
// * "node:out::-1" gives a list with all but the last element in a
// node output list

// The body of the function. Unlike the NodeDefs in a GraphDef, attrs
// may have values of type `placeholder` and the `input` field uses
// the "output" format above.

// By convention, "op" in node_def is resolved by consulting with a
// user-defined library first. If not resolved, "func" is assumed to
// be a builtin op.
repeated NodeDef node_def = 3;

// A mapping from the output arg names from `signature` to the
// outputs from `node_def` that should be returned by the function.
map<string, string> ret = 4;
}

// GradientDef defines the gradient function of a function defined in
// a function library.
//
// A gradient function g (specified by gradient_func) for a function f
// (specified by function_name) must follow the following:
//
// The function 'f' must be a numerical function which takes N inputs
// and produces M outputs. Its gradient function 'g', which is a
// function taking N + M inputs and produces N outputs.
//
// I.e. if we have
// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// then, g is
// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
// dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
// loss function). dL/dx_i is the partial derivative of L with respect
// to x_i.
message GradientDef {
string function_name = 1; // The function name.
string gradient_func = 2; // The gradient function's name.
}

+ 56
- 0
parser/func_to_graph/proto/graph.proto View File

@@ -0,0 +1,56 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "GraphProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "node_def.proto";
import "function.proto";
import "versions.proto";

// Represents the graph of operations
message GraphDef {
repeated NodeDef node = 1;

// Compatibility versions of the graph. See core/public/version.h for version
// history. The GraphDef version is distinct from the TensorFlow version, and
// each release of TensorFlow will support a range of GraphDef versions.
VersionDef versions = 4;

// Deprecated single version field; use versions above instead. Since all
// GraphDef changes before "versions" was introduced were forward
// compatible, this field is entirely ignored.
int32 version = 3 [deprecated = true];

// EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
//
// "library" provides user-defined functions.
//
// Naming:
// * library.function.name are in a flat namespace.
// NOTE: We may need to change it to be hierarchical to support
// different orgs. E.g.,
// { "/google/nn", { ... }},
// { "/google/vision", { ... }}
// { "/org_foo/module_bar", { ... }}
// map<string, FunctionDefLib> named_lib;
// * If node[i].op is the name of one function in "library",
// node[i] is deemed as a function call. Otherwise, node[i].op
// must be a primitive operation supported by the runtime.
//
//
// Function call semantics:
//
// * The callee may start execution as soon as some of its inputs
// are ready. The caller may want to use Tuple() mechanism to
// ensure all inputs are ready in the same time.
//
// * The consumer of return values may start executing as soon as
// the return values the consumer depends on are ready. The
// consumer may want to use Tuple() mechanism to ensure the
// consumer does not start until all return values of the callee
// function are ready.
FunctionDefLibrary library = 2;
};

+ 14
- 0
parser/func_to_graph/proto/graph_library.proto View File

@@ -0,0 +1,14 @@
syntax = "proto3";

package domi.tensorflow;

import "graph.proto";

message GeGraphDef {
string name = 1;
GraphDef graph = 2;
}

message GraphDefLibrary {
repeated GeGraphDef graph_def = 1;
};

+ 63
- 0
parser/func_to_graph/proto/node_def.proto View File

@@ -0,0 +1,63 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "NodeProto";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";

message NodeDef {
// The name given to this operator. Used for naming inputs,
// logging, visualization, etc. Unique within a single GraphDef.
// Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
string name = 1;

// The operation name. There may be custom parameters in attrs.
// Op names starting with an underscore are reserved for internal use.
string op = 2;

// Each input is "node:src_output" with "node" being a string name and
// "src_output" indicating which output tensor to use from "node". If
// "src_output" is 0 the ":0" suffix can be omitted. Regular inputs
// may optionally be followed by control inputs that have the format
// "^node".
repeated string input = 3;

// A (possibly partial) specification for the device on which this
// node should be placed.
// The expected syntax for this string is as follows:
//
// DEVICE_SPEC ::= PARTIAL_SPEC
//
// PARTIAL_SPEC ::= ("/" CONSTRAINT) *
// CONSTRAINT ::= ("job:" JOB_NAME)
// | ("replica:" [1-9][0-9]*)
// | ("task:" [1-9][0-9]*)
// | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
//
// Valid values for this string include:
// * "/job:worker/replica:0/task:1/device:GPU:3" (full specification)
// * "/job:worker/device:GPU:3" (partial specification)
// * "" (no specification)
//
// If the constraints do not resolve to a single device (or if this
// field is empty or not present), the runtime will attempt to
// choose a device automatically.
string device = 4;

// Operation-specific graph-construction-time configuration.
// Note that this should include all attrs defined in the
// corresponding OpDef, including those with a value matching
// the default -- this allows the default to change and makes
// NodeDefs easier to interpret on their own. However, if
// an attr with a default is not specified in this list, the
// default will be used.
// The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
// one of the names from the corresponding OpDef's attr field).
// The values must have a type matching the corresponding OpDef
// attr's type field.
// Add some examples here showing best practices.
map<string, AttrValue> attr = 5;
};

+ 164
- 0
parser/func_to_graph/proto/op_def.proto View File

@@ -0,0 +1,164 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "OpDefProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "attr_value.proto";
import "types.proto";

// Defines an operation. A NodeDef in a GraphDef specifies an Op by
// using the "op" field which should match the name of a OpDef.
// LINT.IfChange
message OpDef {
// Op names starting with an underscore are reserved for internal use.
// Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
string name = 1;

// For describing inputs and outputs.
message ArgDef {
// Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
string name = 1;

// Human readable description.
string description = 2;

// Describes the type of one or more tensors that are accepted/produced
// by this input/output arg. The only legal combinations are:
// * For a single tensor: either the "type" field is set or the
// "type_attr" field is set to the name of an attr with type "type".
// * For a sequence of tensors with the same type: the "number_attr"
// field will be set to the name of an attr with type "int", and
// either the "type" or "type_attr" field will be set as for
// single tensors.
// * For a sequence of tensors, the "type_list_attr" field will be set
// to the name of an attr with type "list(type)".
DataType type = 3;
string type_attr = 4; // if specified, attr must have type "type"
string number_attr = 5; // if specified, attr must have type "int"
// If specified, attr must have type "list(type)", and none of
// type, type_attr, and number_attr may be specified.
string type_list_attr = 6;

// For inputs: if true, the inputs are required to be refs.
// By default, inputs can be either refs or non-refs.
// For outputs: if true, outputs are refs, otherwise they are not.
bool is_ref = 16;
};

// Description of the input(s).
repeated ArgDef input_arg = 2;

// Description of the output(s).
repeated ArgDef output_arg = 3;

// Description of the graph-construction-time configuration of this
// Op. That is to say, this describes the attr fields that will
// be specified in the NodeDef.
message AttrDef {
// A descriptive name for the argument. May be used, e.g. by the
// Python client, as a keyword argument name, and so should match
// the regexp "[a-z][a-z0-9_]+".
string name = 1;

// One of the type names from attr_value.proto ("string", "list(string)",
// "int", etc.).
string type = 2;

// A reasonable default for this attribute if the user does not supply
// a value. If not specified, the user must supply a value.
AttrValue default_value = 3;

// Human-readable description.
string description = 4;


// --- Constraints ---
// These constraints are only in effect if specified. Default is no
// constraints.

// For type == "int", this is a minimum value. For "list(___)"
// types, this is the minimum length.
bool has_minimum = 5;
int64 minimum = 6;

// The set of allowed values. Has type that is the "list" version
// of the "type" field above (uses the "list" field of AttrValue).
// If type == "type" or "list(type)" above, then the "type" field
// of "allowed_values.list" has the set of allowed DataTypes.
// If type == "string" or "list(string)", then the "s" field of
// "allowed_values.list" has the set of allowed strings.
AttrValue allowed_values = 7;
}
repeated AttrDef attr = 4;

// Optional deprecation based on GraphDef versions.
OpDeprecation deprecation = 8;

// One-line human-readable description of what the Op does.
string summary = 5;

// Additional, longer human-readable description of what the Op does.
string description = 6;

// -------------------------------------------------------------------------
// Which optimizations this operation can participate in.

// True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs)
bool is_commutative = 18;

// If is_aggregate is true, then this operation accepts N >= 2
// inputs and produces 1 output all of the same type. Should be
// associative and commutative, and produce output with the same
// shape as the input. The optimizer may replace an aggregate op
// taking input from multiple devices with a tree of aggregate ops
// that aggregate locally within each device (and possibly within
// groups of nearby devices) before communicating.
bool is_aggregate = 16; // for things like add

// Other optimizations go here, like
// can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc.

// -------------------------------------------------------------------------
// Optimization constraints.

// Ops are marked as stateful if their behavior depends on some state beyond
// their input tensors (e.g. variable reading op) or if they have
// a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops
// must always produce the same output for the same input and have
// no side-effects.
//
// By default Ops may be moved between devices. Stateful ops should
// either not be moved, or should only be moved if that state can also
// be moved (e.g. via some sort of save / restore).
// Stateful ops are guaranteed to never be optimized away by Common
// Subexpression Elimination (CSE).
bool is_stateful = 17; // for things like variables, queue

// -------------------------------------------------------------------------
// Non-standard options.

// By default, all inputs to an Op must be initialized Tensors. Ops
// that may initialize tensors for the first time should set this
// field to true, to allow the Op to take an uninitialized Tensor as
// input.
bool allows_uninitialized_input = 19; // for Assign, etc.
};
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc)

// Information about version-dependent deprecation of an op
message OpDeprecation {
// First GraphDef version at which the op is disallowed.
int32 version = 1;

// Explanation of why it was deprecated and what to use instead.
string explanation = 2;
};

// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
};

+ 29
- 0
parser/func_to_graph/proto/resource_handle.proto View File

@@ -0,0 +1,29 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "ResourceHandle";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

// Protocol buffer representing a handle to a tensorflow resource. Handles are
// not valid across executions, but can be serialized back and forth from within
// a single run.
message ResourceHandleProto {
// Unique name for the device containing the resource.
string device = 1;

// Container in which this resource is placed.
string container = 2;

// Unique name of this resource.
string name = 3;

// Hash code for the type of the resource. Is only valid in the same device
// and in the same execution.
uint64 hash_code = 4;

// For debug-only, the name of the type pointed to by this handle, if
// available.
string maybe_type_name = 5;
};

+ 94
- 0
parser/func_to_graph/proto/tensor.proto View File

@@ -0,0 +1,94 @@
syntax = "proto3";

package domi.tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";

import "resource_handle.proto";
import "tensor_shape.proto";
import "types.proto";

// Protocol buffer representing a tensor.
message TensorProto {
DataType dtype = 1;

// Shape of the tensor.
TensorShapeProto tensor_shape = 2;

// Only one of the representations below is set, one of "tensor_contents" and
// the "xxx_val" attributes. We are not using oneof because as oneofs cannot
// contain repeated fields it would require another extra set of messages.

// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;

// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;

// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.

// DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];

// DT_FLOAT.
repeated float float_val = 5 [packed = true];

// DT_DOUBLE.
repeated double double_val = 6 [packed = true];

// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];

// DT_STRING
repeated bytes string_val = 8;

// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];

// DT_INT64
repeated int64 int64_val = 10 [packed = true];

// DT_BOOL
repeated bool bool_val = 11 [packed = true];

// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];

// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;

// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;

// DT_UINT32
repeated uint32 uint32_val = 16 [packed = true];

// DT_UINT64
repeated uint64 uint64_val = 17 [packed = true];
};

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save