@@ -253,6 +253,7 @@ install(TARGETS _caffe_parser parser_common fmk_onnx_parser fmk_parser parser_he | |||||
install(FILES ${PARSER_DIR}/inc/external/parser/onnx_parser.h | install(FILES ${PARSER_DIR}/inc/external/parser/onnx_parser.h | ||||
${PARSER_DIR}/inc/external/parser/caffe_parser.h | ${PARSER_DIR}/inc/external/parser/caffe_parser.h | ||||
${PARSER_DIR}/inc/external/parser/tensorflow_parser.h | ${PARSER_DIR}/inc/external/parser/tensorflow_parser.h | ||||
${PARSER_DIR}/inc/external/parser/parser_common.h | |||||
DESTINATION ${INSTALL_INCLUDE_DIR}/parser/external/parser COMPONENT opensdk EXCLUDE_FROM_ALL | DESTINATION ${INSTALL_INCLUDE_DIR}/parser/external/parser COMPONENT opensdk EXCLUDE_FROM_ALL | ||||
) | ) | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -17,27 +17,14 @@ | |||||
#ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | #ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | ||||
#define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | #define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ | ||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <map> | |||||
#include "graph/ascend_string.h" | #include "graph/ascend_string.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "parser_common.h" | |||||
namespace ge { | namespace ge { | ||||
PARSER_FUNC_VISIBILITY graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | PARSER_FUNC_VISIBILITY graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -17,23 +17,11 @@ | |||||
#ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | #ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | ||||
#define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | #define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ | ||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <map> | |||||
#include "graph/ascend_string.h" | #include "graph/ascend_string.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "parser_common.h" | |||||
namespace ge { | namespace ge { | ||||
PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNX(const char *model_file, | PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNX(const char *model_file, | ||||
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_EXTERNAL_ACL_PARSER_COMMON_H_ | |||||
#define INC_EXTERNAL_ACL_PARSER_COMMON_H_ | |||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#endif // INC_EXTERNAL_ACL_PARSER_COMMON_H_ |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -17,20 +17,6 @@ | |||||
#ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | #ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | ||||
#define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | #define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ | ||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <atomic> | #include <atomic> | ||||
#include <memory> | #include <memory> | ||||
#include <string> | #include <string> | ||||
@@ -39,6 +25,7 @@ | |||||
#include "graph/ascend_string.h" | #include "graph/ascend_string.h" | ||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "parser_common.h" | |||||
namespace ge { | namespace ge { | ||||
PARSER_FUNC_VISIBILITY graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | PARSER_FUNC_VISIBILITY graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -45,7 +45,7 @@ const int kBlobIndexOne = 1; | |||||
Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src); | |||||
const LayerParameter *layer = PtrToPtr<const Message, const LayerParameter>(op_src); | |||||
GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | ||||
GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
@@ -78,12 +78,37 @@ Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, const ge::O | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeCustomParserAdapter::AddEdgeFromConstNode(const NodePtr &const_node, const int32_t index, | |||||
const bool update_in_turn, ge::NodePtr &node) const { | |||||
GE_CHECK_NOTNULL(const_node); | |||||
GE_CHECK_NOTNULL(node); | |||||
auto op = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op); | |||||
auto valid_input_name = op->GetValidInputNameByIndex(index); | |||||
if (update_in_turn || valid_input_name.empty()) { | |||||
if (node->AddLinkFrom(static_cast<const uint32_t &>(index), const_node) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "AddEdge failed of from Node %s output to Node %s input %d", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), index); | |||||
GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] AddEdge failed of from Node %s output to Node %s input %d", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), index); | |||||
} | |||||
} else { | |||||
if (node->AddLinkFrom(valid_input_name, const_node) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "AddEdge failed of from Node %s output to Node %s input %s", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), valid_input_name.c_str()); | |||||
GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] AddEdge failed of from Node %s output to Node %s input %s", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), valid_input_name.c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr &node) { | Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr &node) { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
auto op = node->GetOpDesc(); | auto op = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
const LayerParameter *layer = reinterpret_cast<const LayerParameter *>(op_src); | |||||
const LayerParameter *layer = PtrToPtr<const Message, const LayerParameter>(op_src); | |||||
GE_CHK_BOOL_RET_STATUS(layer != nullptr, FAILED, "[Convert][Type]Dynamic cast op_src to LayerParameter failed"); | GE_CHK_BOOL_RET_STATUS(layer != nullptr, FAILED, "[Convert][Type]Dynamic cast op_src to LayerParameter failed"); | ||||
GELOGI("layer: %s blobs_size: %d bottom_size: %d", layer->name().c_str(), layer->blobs_size(), layer->bottom_size()); | GELOGI("layer: %s blobs_size: %d bottom_size: %d", layer->name().c_str(), layer->blobs_size(), layer->bottom_size()); | ||||
@@ -100,11 +125,11 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr | |||||
GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), | GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), | ||||
"[Convert][Blobs] (%d) for layer %s failed", i, layer->name().c_str()); | "[Convert][Blobs] (%d) for layer %s failed", i, layer->name().c_str()); | ||||
GE_IF_BOOL_EXEC(layer->type() == kConvolution && i == kBlobIndexOne, | GE_IF_BOOL_EXEC(layer->type() == kConvolution && i == kBlobIndexOne, | ||||
const ConvolutionParameter &conv_params_src = layer->convolution_param(); | |||||
bias_en = conv_params_src.bias_term();); | |||||
bias_en = layer->convolution_param().bias_term(); | |||||
); | |||||
GE_IF_BOOL_EXEC(layer->type() == kInnerProduct && i == kBlobIndexOne, | GE_IF_BOOL_EXEC(layer->type() == kInnerProduct && i == kBlobIndexOne, | ||||
const InnerProductParameter &fc_params_src = layer->inner_product_param(); | |||||
bias_en = fc_params_src.bias_term();); | |||||
bias_en = layer->inner_product_param().bias_term(); | |||||
); | |||||
auto bias_shape = weight->MutableTensorDesc().GetShape(); | auto bias_shape = weight->MutableTensorDesc().GetShape(); | ||||
// The num 0, 1, 2, 3 represet the dim index. | // The num 0, 1, 2, 3 represet the dim index. | ||||
bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | ||||
@@ -127,24 +152,8 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr | |||||
// add edge from const to current node | // add edge from const to current node | ||||
auto const_node = owner_graph->AddNodeFront(const_opdesc); | auto const_node = owner_graph->AddNodeFront(const_opdesc); | ||||
GE_CHECK_NOTNULL(const_node); | |||||
auto index = start_pos + i; | auto index = start_pos + i; | ||||
auto valid_input_name = op->GetValidInputNameByIndex(static_cast<uint32_t>(index)); | |||||
if (update_in_turn || valid_input_name.empty()) { | |||||
if (node->AddLinkFrom(static_cast<const uint32_t &>(index), const_node) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "AddEdge failed of from Node %s output to Node %s input %d", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), index); | |||||
GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] AddEdge failed of from Node %s output to Node %s input %d", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), index); | |||||
} | |||||
} else { | |||||
if (node->AddLinkFrom(valid_input_name, const_node) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "AddEdge failed of from Node %s output to Node %s input %s", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), valid_input_name.c_str()); | |||||
GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] AddEdge failed of from Node %s output to Node %s input %s", | |||||
const_node->GetName().c_str(), node->GetName().c_str(), valid_input_name.c_str()); | |||||
} | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(AddEdgeFromConstNode(const_node, static_cast<int32_t>(index), update_in_turn, node)); | |||||
std::vector<ge::NodePtr> original_nodes; | std::vector<ge::NodePtr> original_nodes; | ||||
ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); | ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -54,6 +54,20 @@ class PARSER_FUNC_VISIBILITY CaffeCustomParserAdapter : public CaffeOpParser { | |||||
* @author | * @author | ||||
*/ | */ | ||||
Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; | Status ParseWeights(const Message *op_src, ge::NodePtr &node) override; | ||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief parse weight of the operation | |||||
* @param [in] const_node const node to add link edge | |||||
* @param [in] index index of current node to add link | |||||
* @param [in] update_in_turn flag of update in turn | |||||
* @param [out] node params after parsing | |||||
* @return SUCCESS parse successfullyparse failed | |||||
* @return FAILED | |||||
* @author | |||||
*/ | |||||
Status AddEdgeFromConstNode(const NodePtr &const_node, const int32_t index, | |||||
const bool update_in_turn, ge::NodePtr &node) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -28,6 +28,9 @@ using namespace ge::parser; | |||||
using domi::CAFFE; | using domi::CAFFE; | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const char *kData = "Data"; | |||||
} | |||||
Status CaffeDataParser::GetOutputDesc(const string &name, const std::vector<int64_t> &input_dims, | Status CaffeDataParser::GetOutputDesc(const string &name, const std::vector<int64_t> &input_dims, | ||||
const ge::OpDescPtr &op) const { | const ge::OpDescPtr &op) const { | ||||
GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
@@ -51,10 +54,10 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | ||||
if (layer->type() == ge::parser::INPUT_TYPE) { | if (layer->type() == ge::parser::INPUT_TYPE) { | ||||
GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | |||||
GE_CHK_STATUS_RET(ParseParamsForInput(*layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | |||||
"layer type= %s", layer->name().c_str(), layer->type().c_str()); | "layer type= %s", layer->name().c_str(), layer->type().c_str()); | ||||
} else if (layer->type() == ge::parser::DUMMY_DATA) { | } else if (layer->type() == ge::parser::DUMMY_DATA) { | ||||
GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | |||||
GE_CHK_STATUS_RET(ParseParamsForDummyData(*layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | |||||
"layer type= %s", layer->name().c_str(), layer->type().c_str()); | "layer type= %s", layer->name().c_str(), layer->type().c_str()); | ||||
} else { | } else { | ||||
REPORT_INNER_ERROR("E19999", "layer:%s(%s) type is not %s or %s, check invalid", | REPORT_INNER_ERROR("E19999", "layer:%s(%s) type is not %s or %s, check invalid", | ||||
@@ -66,14 +69,14 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { | |||||
if (layer->has_input_param()) { | |||||
const domi::caffe::InputParameter &input_param = layer->input_param(); | |||||
Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op) const { | |||||
if (layer.has_input_param()) { | |||||
const domi::caffe::InputParameter &input_param = layer.input_param(); | |||||
if (input_param.shape_size() == 0) { | if (input_param.shape_size() == 0) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
"E11027", {"layername", "layertype"}, {layer.name(), layer.type()}); | |||||
GELOGE(PARAM_INVALID, "[Check][Param]input_param shape size is zero, check invalid, " | GELOGE(PARAM_INVALID, "[Check][Param]input_param shape size is zero, check invalid, " | ||||
"caffe layer name [%s], layer type [%s].", layer->name().c_str(), layer->type().c_str()); | |||||
"caffe layer name [%s], layer type [%s].", layer.name().c_str(), layer.type().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
for (int i = 0; i < input_param.shape_size(); i++) { | for (int i = 0; i < input_param.shape_size(); i++) { | ||||
@@ -84,7 +87,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
for (auto &blob_shape_dim_temp : blob_shape.dim()) { | for (auto &blob_shape_dim_temp : blob_shape.dim()) { | ||||
model_dims.push_back(blob_shape_dim_temp); | model_dims.push_back(blob_shape_dim_temp); | ||||
} | } | ||||
string name = layer->name(); | |||||
string name = layer.name(); | |||||
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | ||||
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | ||||
"[Get][OutputDesc] failed in layer %s", name.c_str()); | "[Get][OutputDesc] failed in layer %s", name.c_str()); | ||||
@@ -93,13 +96,13 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
// Get from external input | // Get from external input | ||||
const ge::ParserContext &ctx = GetParserContext(); | const ge::ParserContext &ctx = GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
string name = layer->name(); | |||||
string name = layer.name(); | |||||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | ||||
if (search == input_dims.end()) { | if (search == input_dims.end()) { | ||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | |||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer.name()})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | ||||
"should set --input_shape in atc parameter, caffe layer name [%s], layer type [%s].", | "should set --input_shape in atc parameter, caffe layer name [%s], layer type [%s].", | ||||
layer->name().c_str(), layer->type().c_str()); | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::vector<int64_t> dims = search->second; | std::vector<int64_t> dims = search->second; | ||||
@@ -109,14 +112,14 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op) { | |||||
if (layer->has_dummy_data_param()) { | |||||
const domi::caffe::DummyDataParameter &dummy_data_param = layer->dummy_data_param(); | |||||
Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op) const { | |||||
if (layer.has_dummy_data_param()) { | |||||
const domi::caffe::DummyDataParameter &dummy_data_param = layer.dummy_data_param(); | |||||
if (dummy_data_param.shape_size() == 0) { | if (dummy_data_param.shape_size() == 0) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E11027", {"layername", "layertype"}, {layer->name(), layer->type()}); | |||||
"E11027", {"layername", "layertype"}, {layer.name(), layer.type()}); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] input_param shape size is zero, caffe layer name [%s], layer type [%s].", | GELOGE(PARAM_INVALID, "[Check][Param] input_param shape size is zero, caffe layer name [%s], layer type [%s].", | ||||
layer->name().c_str(), layer->type().c_str()); | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
for (int i = 0; i < dummy_data_param.shape_size(); i++) { | for (int i = 0; i < dummy_data_param.shape_size(); i++) { | ||||
@@ -129,7 +132,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
model_dims.push_back(blob_shape_dim_temp); | model_dims.push_back(blob_shape_dim_temp); | ||||
} | } | ||||
string name = layer->name(); | |||||
string name = layer.name(); | |||||
GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | ||||
GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | ||||
"[Get][OutputDesc] failed in layer %s", name.c_str()); | "[Get][OutputDesc] failed in layer %s", name.c_str()); | ||||
@@ -138,13 +141,13 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
// Get from external input | // Get from external input | ||||
const ge::ParserContext &ctx = GetParserContext(); | const ge::ParserContext &ctx = GetParserContext(); | ||||
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
string name = layer->name(); | |||||
string name = layer.name(); | |||||
std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | std::map<std::string, std::vector<int64_t>>::const_iterator search = input_dims.find(name); | ||||
if (search == input_dims.end()) { | if (search == input_dims.end()) { | ||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer->name()})); | |||||
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({layer.name()})); | |||||
GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | GELOGE(PARAM_INVALID, "[Check][Param] Caffe prototxt has no input_param or user " | ||||
"should set --input_shape in atc parameter, caffe layer name [%s], layer type [%s].", | "should set --input_shape in atc parameter, caffe layer name [%s], layer type [%s].", | ||||
layer->name().c_str(), layer->type().c_str()); | |||||
layer.name().c_str(), layer.type().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::vector<int64_t> dims = search->second; | std::vector<int64_t> dims = search->second; | ||||
@@ -153,6 +156,5 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REGISTER_OP_PARSER_CREATOR(CAFFE, DATA, CaffeDataParser); | |||||
REGISTER_OP_PARSER_CREATOR(CAFFE, kData, CaffeDataParser); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -48,8 +48,8 @@ class PARSER_FUNC_VISIBILITY CaffeDataParser : public CaffeOpParser, public Data | |||||
Status GetOutputDesc(const std::string &name, const std::vector<int64_t> &input_dims, const ge::OpDescPtr &op) const; | Status GetOutputDesc(const std::string &name, const std::vector<int64_t> &input_dims, const ge::OpDescPtr &op) const; | ||||
// caffe data layer type could be type of `Input` or `DummyData` | // caffe data layer type could be type of `Input` or `DummyData` | ||||
Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | |||||
Status ParseParamsForDummyData(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | |||||
Status ParseParamsForInput(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op) const; | |||||
Status ParseParamsForDummyData(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -26,6 +26,10 @@ using domi::caffe::BlobProto; | |||||
using domi::CAFFE; | using domi::CAFFE; | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const char *kNetOutput = "NetOutput"; | |||||
const char *kDropout = "Dropout"; | |||||
} | |||||
Status CaffeOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status CaffeOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
(void)op_src; | (void)op_src; | ||||
(void)op_dest; | (void)op_dest; | ||||
@@ -88,6 +92,18 @@ Status CaffeOpParser::ConvertWeight(const BlobProto &proto, const string &lay_na | |||||
return ParseWeightType(proto, shape, count, lay_name, weight); | return ParseWeightType(proto, shape, count, lay_name, weight); | ||||
} | } | ||||
Status CaffeOpParser::CheckSizeInvalid(const string &lay_name, const int32_t blob_size, const int32_t size) { | |||||
if (blob_size == size) { | |||||
return SUCCESS; | |||||
} | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(blob_size), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight fail, Blob size does not match shape size, " | |||||
"shape size:%d, blob size:%d, layer name:%s", size, blob_size, lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, int size, | Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, int size, | ||||
const string &lay_name, ge::GeTensorPtr &weight) { | const string &lay_name, ge::GeTensorPtr &weight) { | ||||
// Extract weight data and store it in weightdef by float type | // Extract weight data and store it in weightdef by float type | ||||
@@ -95,14 +111,7 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
ge::DataType dtype = ge::DT_FLOAT; | ge::DataType dtype = ge::DT_FLOAT; | ||||
if (proto.double_data_size() > 0) { | if (proto.double_data_size() > 0) { | ||||
// Convert by double type | // Convert by double type | ||||
if (size != proto.double_data_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(proto.double_data_size()), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight fail, Blob size does not match shape size, " | |||||
"shape size:%d, blob size:%d, layer name:%s", size, proto.double_data_size(), lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(CheckSizeInvalid(lay_name, proto.double_data_size(), size)); | |||||
std::unique_ptr<float[]> buf(new (std::nothrow) float[size]()); | std::unique_ptr<float[]> buf(new (std::nothrow) float[size]()); | ||||
GE_CHECK_NOTNULL(buf); | GE_CHECK_NOTNULL(buf); | ||||
for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
@@ -111,71 +120,39 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<float, uint8_t>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<float, uint8_t>(buf.get()), size * sizeof(float)) != ge::GRAPH_SUCCESS, | ||||
GELOGW("SetData failed for GeTensor.");); // no need to return | GELOGW("SetData failed for GeTensor.");); // no need to return | ||||
} else if (proto.int8_data().length() > 0) { | } else if (proto.int8_data().length() > 0) { | ||||
if (size != static_cast<int>(proto.int8_data().length())) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(proto.int8_data().length()), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight failed, Blob size does not match shape size, " | |||||
"shape size:%d, blob size:%ld, layer name:%s", size, proto.int8_data().length(), lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(CheckSizeInvalid(lay_name, static_cast<int32_t>(proto.int8_data().length()), size)); | |||||
const char *data_ptr = proto.int8_data().data(); | const char *data_ptr = proto.int8_data().data(); | ||||
GE_CHECK_NOTNULL(data_ptr); | GE_CHECK_NOTNULL(data_ptr); | ||||
GE_IF_BOOL_EXEC( | |||||
weight->SetData(PtrToPtr<const char, const uint8_t>(data_ptr), size * sizeof(int8_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<const char, const uint8_t>(data_ptr), size * sizeof(int8_t)) != | |||||
ge::GRAPH_SUCCESS, GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
dtype = ge::DT_INT8; | dtype = ge::DT_INT8; | ||||
} else if (proto.int32_data_size() > 0) { | } else if (proto.int32_data_size() > 0) { | ||||
if (size != proto.int32_data_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(proto.int32_data_size()), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight failed, Blob size does not match shape size, " | |||||
"shape size:%d, blob size:%d, layer name:%s", size, proto.int32_data_size(), lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(CheckSizeInvalid(lay_name, static_cast<int32_t>(proto.int32_data_size()), size)); | |||||
std::unique_ptr<int32_t[]> int32_weight_buf(new (std::nothrow) int32_t[size]()); | std::unique_ptr<int32_t[]> int32_weight_buf(new (std::nothrow) int32_t[size]()); | ||||
GE_CHECK_NOTNULL(int32_weight_buf); | GE_CHECK_NOTNULL(int32_weight_buf); | ||||
for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
int32_weight_buf[i] = proto.int32_data(i); | int32_weight_buf[i] = proto.int32_data(i); | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | |||||
weight->SetData(PtrToPtr<int32_t, uint8_t>(int32_weight_buf.get()), size * sizeof(int32_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<int32_t, uint8_t>(int32_weight_buf.get()), size * sizeof(int32_t)) != | |||||
ge::GRAPH_SUCCESS, GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
dtype = ge::DT_INT32; | dtype = ge::DT_INT32; | ||||
} else if (proto.uint64_data_size() > 0) { | } else if (proto.uint64_data_size() > 0) { | ||||
if (size != proto.uint64_data_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(proto.uint64_data_size()), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight failed, Blob size does not match shape size, " | |||||
"shape size:%d, blob size:%d, layer name:%s", size, proto.uint64_data_size(), lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(CheckSizeInvalid(lay_name, static_cast<int32_t>(proto.uint64_data_size()), size)); | |||||
std::unique_ptr<uint64_t[]> uint64_weight_buf(new (std::nothrow) uint64_t[size]()); | std::unique_ptr<uint64_t[]> uint64_weight_buf(new (std::nothrow) uint64_t[size]()); | ||||
GE_CHECK_NOTNULL(uint64_weight_buf); | GE_CHECK_NOTNULL(uint64_weight_buf); | ||||
for (int i = 0; i < size; ++i) { | for (int i = 0; i < size; ++i) { | ||||
uint64_weight_buf[i] = proto.uint64_data(i); | uint64_weight_buf[i] = proto.uint64_data(i); | ||||
} | } | ||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<uint64_t, uint8_t>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<uint64_t, uint8_t>(uint64_weight_buf.get()), size * sizeof(uint64_t)) != | ||||
ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
ge::GRAPH_SUCCESS, GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
dtype = ge::DT_UINT64; | dtype = ge::DT_UINT64; | ||||
} else { | } else { | ||||
// Convert by float type | // Convert by float type | ||||
if (size != proto.data_size()) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E11033", {"opname", "blobsize", "reason"}, | |||||
{lay_name, std::to_string(proto.data_size()), | |||||
"it does not match shape size[" + std::to_string(size) + "]"}); | |||||
GELOGE(FAILED, "[Check][Param]Convert weight fail, Blob size does not match shape size, " | |||||
"shape size:%d, blob.data_size:%d, layer name:%s", size, proto.data_size(), lay_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG(CheckSizeInvalid(lay_name, static_cast<int32_t>(proto.data_size()), size)); | |||||
const float *data_ptr = proto.data().data(); | const float *data_ptr = proto.data().data(); | ||||
GE_CHECK_NOTNULL(data_ptr); | GE_CHECK_NOTNULL(data_ptr); | ||||
GE_IF_BOOL_EXEC( | |||||
weight->SetData(PtrToPtr<const float, const uint8_t>(data_ptr), size * sizeof(float)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
GE_IF_BOOL_EXEC(weight->SetData(PtrToPtr<const float, const uint8_t>(data_ptr), size * sizeof(float)) != | |||||
ge::GRAPH_SUCCESS, GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
} | } | ||||
ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | ge::GeTensorDesc weight_desc = ge::GeTensorDesc(); | ||||
weight_desc.Update(shape, ge::FORMAT_NCHW, dtype); | weight_desc.Update(shape, ge::FORMAT_NCHW, dtype); | ||||
@@ -184,11 +161,11 @@ Status CaffeOpParser::ParseWeightType(const BlobProto &proto, const ge::GeShape | |||||
} | } | ||||
// Dropout's corresponding op_parser is registered as caffeopparser, optimized in optimization stage. | // Dropout's corresponding op_parser is registered as caffeopparser, optimized in optimization stage. | ||||
REGISTER_OP_PARSER_CREATOR(CAFFE, DROPOUT, CaffeOpParser); | |||||
REGISTER_OP_PARSER_CREATOR(CAFFE, kDropout, CaffeOpParser); | |||||
// A new operator added by framework in OM model is used to | // A new operator added by framework in OM model is used to | ||||
// collect and arrange all outputs in the order of the original model's output | // collect and arrange all outputs in the order of the original model's output | ||||
// Net output operator does not need special processing in the parse stage, | // Net output operator does not need special processing in the parse stage, | ||||
// and directly registers in the op_parser file | // and directly registers in the op_parser file | ||||
REGISTER_OP_PARSER_CREATOR(CAFFE, NETOUTPUT, CaffeOpParser); | |||||
REGISTER_OP_PARSER_CREATOR(CAFFE, kNetOutput, CaffeOpParser); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -39,7 +39,6 @@ | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/operator.h" | #include "graph/operator.h" | ||||
#include "graph/types.h" | |||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
@@ -112,6 +111,18 @@ class PARSER_FUNC_VISIBILITY CaffeOpParser : public OpParser { | |||||
*/ | */ | ||||
static Status ParseWeightType(const domi::caffe::BlobProto &proto, const ge::GeShape &shape, | static Status ParseWeightType(const domi::caffe::BlobProto &proto, const ge::GeShape &shape, | ||||
int size, const string &lay_name, ge::GeTensorPtr &weight); | int size, const string &lay_name, ge::GeTensorPtr &weight); | ||||
private: | |||||
/** | |||||
* @ingroup ge_omg | |||||
* @brief Convert blob proto to weight definition | |||||
* @param [in] lay_name op name | |||||
* @param [in] blob_size blob size | |||||
* @param [in] size input size | |||||
* @return SUCCESS check size invalid | |||||
* @return FAILED parse failed | |||||
*/ | |||||
static Status CheckSizeInvalid(const string &lay_name, const int32_t blob_size, const int32_t size); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -389,7 +389,8 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, | Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, | ||||
const string &custom_proto_name, vector<ge::Operator> &operators) { | |||||
const string &custom_proto_name, | |||||
vector<ge::Operator> &operators) const { | |||||
google::protobuf::compiler::DiskSourceTree source_tree; | google::protobuf::compiler::DiskSourceTree source_tree; | ||||
source_tree.MapPath(kProjectRoot, custom_proto_path); | source_tree.MapPath(kProjectRoot, custom_proto_path); | ||||
google::protobuf::compiler::Importer importer(&source_tree, nullptr); | google::protobuf::compiler::Importer importer(&source_tree, nullptr); | ||||
@@ -1926,7 +1927,7 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me | |||||
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, | Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, | ||||
const google::protobuf::Message *message, | const google::protobuf::Message *message, | ||||
const google::protobuf::FieldDescriptor *field, | const google::protobuf::FieldDescriptor *field, | ||||
google::protobuf::Message *layer) { | |||||
google::protobuf::Message *layer) const { | |||||
GELOGD("Start to parse field: %s.", field->name().c_str()); | GELOGD("Start to parse field: %s.", field->name().c_str()); | ||||
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); | ||||
string filed_name = field->name(); | string filed_name = field->name(); | ||||
@@ -181,7 +181,7 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||||
* @return FAILED parse failed | * @return FAILED parse failed | ||||
*/ | */ | ||||
Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, | Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path, | ||||
const string &custom_proto_name, std::vector<ge::Operator> &operators); | |||||
const string &custom_proto_name, std::vector<ge::Operator> &operators) const; | |||||
/* | /* | ||||
* @ingroup domi_omg | * @ingroup domi_omg | ||||
@@ -401,7 +401,7 @@ class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser { | |||||
Status ParseLayerField(const google::protobuf::Reflection *reflection, | Status ParseLayerField(const google::protobuf::Reflection *reflection, | ||||
const google::protobuf::Message *message, | const google::protobuf::Message *message, | ||||
const google::protobuf::FieldDescriptor *field, | const google::protobuf::FieldDescriptor *field, | ||||
google::protobuf::Message *layer); | |||||
google::protobuf::Message *layer) const; | |||||
Status ConvertBlobsProto(const google::protobuf::Message *message, | Status ConvertBlobsProto(const google::protobuf::Message *message, | ||||
google::protobuf::Message *blobs) const; | google::protobuf::Message *blobs) const; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -36,6 +36,7 @@ const int kAnchorIndexZero = 0; | |||||
const int kAnchorIndexOne = 1; | const int kAnchorIndexOne = 1; | ||||
const int32_t RESHAPE_AXIS_DEFAULT_VALUE = 0; | const int32_t RESHAPE_AXIS_DEFAULT_VALUE = 0; | ||||
const int32_t RESHAPE_NUM_AXES_DEFAULT_VALUE = -1; | const int32_t RESHAPE_NUM_AXES_DEFAULT_VALUE = -1; | ||||
const char *kReshape = "Reshape"; | |||||
} // namespace | } // namespace | ||||
Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | ||||
@@ -127,9 +128,8 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||||
for (size_t i = 0; i < dims_size; ++i) { | for (size_t i = 0; i < dims_size; ++i) { | ||||
data[i] = attr_shape[i]; | data[i] = attr_shape[i]; | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | |||||
constTensor->SetData(PtrToPtr<int64_t, uint8_t>(data.get()), dims_size * sizeof(int64_t)) != ge::GRAPH_SUCCESS, | |||||
GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
GE_IF_BOOL_EXEC(constTensor->SetData(PtrToPtr<int64_t, uint8_t>(data.get()), dims_size * sizeof(int64_t)) != | |||||
ge::GRAPH_SUCCESS, GELOGW("SetData failed for GeTensor.");); // no need to return | |||||
// construct const node and add edge | // construct const node and add edge | ||||
auto const_opdesc = ge::OpDescUtils::CreateConstOp(constTensor); | auto const_opdesc = ge::OpDescUtils::CreateConstOp(constTensor); | ||||
@@ -151,5 +151,5 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REGISTER_OP_PARSER_CREATOR(CAFFE, RESHAPE, CaffeReshapeParser); | |||||
REGISTER_OP_PARSER_CREATOR(CAFFE, kReshape, CaffeReshapeParser); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -33,6 +33,9 @@ using GeTensorDesc = ge::GeTensorDesc; | |||||
using namespace ge::parser; | using namespace ge::parser; | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const char *kConstant = "Const"; | |||||
} | |||||
Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | ||||
int64_t data_type = tensor_proto.data_type(); | int64_t data_type = tensor_proto.data_type(); | ||||
if (ge::OnnxUtil::ConvertOnnxDataType(data_type) == ge::DataType::DT_UNDEFINED) { | if (ge::OnnxUtil::ConvertOnnxDataType(data_type) == ge::DataType::DT_UNDEFINED) { | ||||
@@ -47,27 +50,22 @@ Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_ | |||||
} | } | ||||
std::map<uint32_t, int32_t> datatype_val_size_map = { | std::map<uint32_t, int32_t> datatype_val_size_map = { | ||||
// for int32, uint8, int8, uint16, int16, bool, and float16 values | |||||
{OnnxDataType::INT32, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::UINT8, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::INT8, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::UINT16, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::INT16, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::BOOL, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::FLOAT16, tensor_proto.int32_data_size()}, | |||||
// for int64 values | |||||
{OnnxDataType::INT64, tensor_proto.int64_data_size()}, | |||||
// for string values | |||||
{OnnxDataType::STRING, tensor_proto.string_data_size()}, | |||||
// for float and complex64 values | |||||
{OnnxDataType::FLOAT, tensor_proto.float_data_size()}, | |||||
{OnnxDataType::COMPLEX64, tensor_proto.float_data_size()}, | |||||
// for double and complex128 values | |||||
{OnnxDataType::DOUBLE, tensor_proto.double_data_size()}, | |||||
{OnnxDataType::COMPLEX128, tensor_proto.double_data_size()}, | |||||
// for uint64 and uint32 values | |||||
{OnnxDataType::UINT64, tensor_proto.uint64_data_size()}, | |||||
{OnnxDataType::UINT32, tensor_proto.uint64_data_size()}, | |||||
// for int32, uint8, int8, uint16, int16, bool, and float16 values | |||||
{OnnxDataType::INT32, tensor_proto.int32_data_size()}, {OnnxDataType::UINT8, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::INT8, tensor_proto.int32_data_size()}, {OnnxDataType::UINT16, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::INT16, tensor_proto.int32_data_size()}, {OnnxDataType::BOOL, tensor_proto.int32_data_size()}, | |||||
{OnnxDataType::FLOAT16, tensor_proto.int32_data_size()}, | |||||
// for int64 values | |||||
{OnnxDataType::INT64, tensor_proto.int64_data_size()}, | |||||
// for string values | |||||
{OnnxDataType::STRING, tensor_proto.string_data_size()}, | |||||
// for float and complex64 values | |||||
{OnnxDataType::FLOAT, tensor_proto.float_data_size()}, {OnnxDataType::COMPLEX64, tensor_proto.float_data_size()}, | |||||
// for double and complex128 values | |||||
{OnnxDataType::DOUBLE, tensor_proto.double_data_size()}, | |||||
{OnnxDataType::COMPLEX128, tensor_proto.double_data_size()}, | |||||
// for uint64 and uint32 values | |||||
{OnnxDataType::UINT64, tensor_proto.uint64_data_size()}, {OnnxDataType::UINT32, tensor_proto.uint64_data_size()}, | |||||
}; | }; | ||||
int32_t datatype_val_size = 0; | int32_t datatype_val_size = 0; | ||||
@@ -199,7 +197,7 @@ Status OnnxConstantParser::ParseConvertDataType(const ge::onnx::TensorProto &ten | |||||
Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def) { | Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const NodeProto *node = reinterpret_cast<const NodeProto *>(op_src); | |||||
const NodeProto *node = PtrToPtr<const ge::onnx::NodeProto, const NodeProto>(op_src); | |||||
// Get const Tensor from node | // Get const Tensor from node | ||||
Tensor tensor; | Tensor tensor; | ||||
@@ -226,7 +224,7 @@ Status OnnxConstantParser::ParseConstFromInput(const ge::onnx::NodeProto *op_src | |||||
Status OnnxConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | Status OnnxConstantParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
const ge::onnx::NodeProto *node = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | GELOGD("Onnx op node name = %s, op type= %s, parse params", node->name().c_str(), node->op_type().c_str()); | ||||
@@ -237,5 +235,5 @@ Status OnnxConstantParser::ParseParams(const Message *op_src, ge::Operator &op_d | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REGISTER_OP_PARSER_CREATOR(ONNX, CONSTANT, OnnxConstantParser); | |||||
REGISTER_OP_PARSER_CREATOR(ONNX, kConstant, OnnxConstantParser); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -28,6 +28,9 @@ using domi::ONNX; | |||||
using namespace ge::parser; | using namespace ge::parser; | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const char *kData = "Data"; | |||||
} | |||||
Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | ||||
@@ -140,5 +143,5 @@ Status OnnxDataParser::ParseInputFromUser(const ge::Operator &op_def) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REGISTER_OP_PARSER_CREATOR(ONNX, DATA, OnnxDataParser); | |||||
REGISTER_OP_PARSER_CREATOR(ONNX, kData, OnnxDataParser); | |||||
} // namespace ge | } // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2022 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -147,4 +147,4 @@ Status OnnxFileConstantParser::SetPathAttr(const ge::onnx::StringStringEntryProt | |||||
} | } | ||||
REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser); | REGISTER_OP_PARSER_CREATOR(ONNX, kFileConstant, OnnxFileConstantParser); | ||||
} // namespace ge | |||||
} // namespace ge |
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2022 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -48,7 +48,7 @@ | |||||
#include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
namespace { | namespace { | ||||
const std::string kLocation = "location"; | |||||
const char *kLocation = "location"; | |||||
} | } | ||||
namespace ge { | namespace ge { | ||||
@@ -671,13 +671,12 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
} | } | ||||
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | ||||
if (input_node_names_.empty()) { | |||||
// subgraph might not have input, we use constant nodes as the start nodes of graph | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
if (node->op_type() == kOpTypeConstant) { | |||||
input_node_names_.emplace_back(node->name()); | |||||
} | |||||
// subgraph might not have input, or isolated const nodes exist in the graph, | |||||
// we use constant nodes as the start nodes of graph | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
if (node->op_type() == kOpTypeConstant) { | |||||
input_node_names_.emplace_back(node->name()); | |||||
} | } | ||||
} | } | ||||
for (auto in_name : input_node_names_) { | for (auto in_name : input_node_names_) { | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -45,7 +45,7 @@ namespace ge { | |||||
class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | ||||
public: | public: | ||||
OnnxModelParser() {} | OnnxModelParser() {} | ||||
virtual ~OnnxModelParser() {} | |||||
~OnnxModelParser() override {} | |||||
Status Parse(const char *file, ge::Graph &graph) override; | Status Parse(const char *file, ge::Graph &graph) override; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright (c) Huawei Technologies Co., Ltd. 2020~2022. All rights reserved. | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -292,18 +292,18 @@ TEST_F(STestCaffeParser, caffe_parser_ParseParamsForDummyData_test) | |||||
domi::caffe::NetParameter net; | domi::caffe::NetParameter net; | ||||
ge::OpDescPtr op = std::make_shared<ge::OpDesc>("conv", "Convolution"); | ge::OpDescPtr op = std::make_shared<ge::OpDesc>("conv", "Convolution"); | ||||
domi::caffe::LayerParameter *lay = net.add_layer(); | domi::caffe::LayerParameter *lay = net.add_layer(); | ||||
Status ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
Status ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
ret = caffe_parser.ParseParamsForInput(lay, op); | |||||
ret = caffe_parser.ParseParamsForInput(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
domi::caffe::DummyDataParameter *dummyData = lay->mutable_dummy_data_param(); | domi::caffe::DummyDataParameter *dummyData = lay->mutable_dummy_data_param(); | ||||
ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
domi::caffe::BlobShape* dummpShape = dummyData->add_shape(); | domi::caffe::BlobShape* dummpShape = dummyData->add_shape(); | ||||
ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -352,18 +352,18 @@ TEST_F(UtestCaffeParser, caffe_parser_ParseParamsForDummyData_test) | |||||
domi::caffe::NetParameter net; | domi::caffe::NetParameter net; | ||||
ge::OpDescPtr op = std::make_shared<ge::OpDesc>("conv", "Convolution"); | ge::OpDescPtr op = std::make_shared<ge::OpDesc>("conv", "Convolution"); | ||||
domi::caffe::LayerParameter *lay = net.add_layer(); | domi::caffe::LayerParameter *lay = net.add_layer(); | ||||
Status ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
Status ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
ret = caffe_parser.ParseParamsForInput(lay, op); | |||||
ret = caffe_parser.ParseParamsForInput(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
domi::caffe::DummyDataParameter *dummyData = lay->mutable_dummy_data_param(); | domi::caffe::DummyDataParameter *dummyData = lay->mutable_dummy_data_param(); | ||||
ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, FAILED); | EXPECT_EQ(ret, FAILED); | ||||
domi::caffe::BlobShape* dummpShape = dummyData->add_shape(); | domi::caffe::BlobShape* dummpShape = dummyData->add_shape(); | ||||
ret = caffe_parser.ParseParamsForDummyData(lay, op); | |||||
ret = caffe_parser.ParseParamsForDummyData(*lay, op); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -703,4 +703,23 @@ TEST_F(UtestOnnxParser, onnx_test_TransNodeToOperator_SetTensorData) | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
TEST_F(UtestOnnxParser, onnx_test_const_input_op) | |||||
{ | |||||
ge::onnx::ModelProto model_proto; | |||||
ge::onnx::GraphProto* graph = model_proto.mutable_graph(); | |||||
ge::onnx::NodeProto *node_proto = graph->add_node(); | |||||
node_proto->set_op_type("Constant"); | |||||
node_proto->set_domain("const.onnx"); | |||||
node_proto->set_name("const_11"); | |||||
ge::OpDescPtr op_desc_src = std::make_shared<ge::OpDesc>("Constant", "const.onnx"); | |||||
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc_src); | |||||
std::string op_type = "Constant"; | |||||
OnnxModelParser onnx_parser; | |||||
std::vector<ge::Operator> input_ops; | |||||
onnx_parser.name_operator_["const_11"] = op; | |||||
Status ret = onnx_parser.GetGraphInputs(*graph, input_ops); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
EXPECT_EQ(input_ops.size() > 0, true); | |||||
} | |||||
} // namespace ge | } // namespace ge |