@@ -1 +1 @@ | |||||
Subproject commit f62cba4fdf845ffe04e5c1e37ea990d22c438910 | |||||
Subproject commit 51f76677af9299a919440416af70471f191380b8 |
@@ -21,6 +21,7 @@ | |||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "parser/common/parser_utils.h" | #include "parser/common/parser_utils.h" | ||||
#include "graph/def_types.h" | |||||
using domi::ONNX; | using domi::ONNX; | ||||
using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
@@ -29,7 +30,7 @@ using domi::ParseParamFunc; | |||||
namespace ge { | namespace ge { | ||||
Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | ||||
@@ -18,6 +18,7 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/def_types.h" | |||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
#include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
@@ -29,7 +30,7 @@ using namespace ge::parser; | |||||
namespace ge { | namespace ge { | ||||
Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
const ge::onnx::NodeProto *node_src = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); | GELOGD("Onnx op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op_type().c_str()); | ||||
if (ParseInputFromModel(op_src, op_def) != SUCCESS) { | if (ParseInputFromModel(op_src, op_def) != SUCCESS) { | ||||
@@ -73,7 +74,7 @@ int64_t OnnxDataParser::ParseInputTensor(const ge::onnx::AttributeProto &attribu | |||||
Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &op_def) { | ||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
const ge::onnx::NodeProto *node = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||||
const ge::onnx::NodeProto *node = PtrToPtr<const Message, const ge::onnx::NodeProto>(op_src); | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
// Get attr t:'input_tensor' form NodeProto | // Get attr t:'input_tensor' form NodeProto | ||||
@@ -19,6 +19,7 @@ | |||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/def_types.h" | |||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "register/register.h" | #include "register/register.h" | ||||
@@ -43,7 +44,7 @@ Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge | |||||
GELOGE(PARAM_INVALID, "Op src is null"); | GELOGE(PARAM_INVALID, "Op src is null"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); | ||||
if (op_dest == nullptr) { | if (op_dest == nullptr) { | ||||
REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid"); | ||||
@@ -31,7 +31,7 @@ Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
const std::string name = op_desc->GetName(); | const std::string name = op_desc->GetName(); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
domi::tensorflow::AttrValue str_attr; | domi::tensorflow::AttrValue str_attr; | ||||
if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | if (!TensorFlowUtil::FindAttrValue(node, ENTER_ATTR_FRAME_NAME, str_attr)) { | ||||
REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid", | ||||
@@ -21,6 +21,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
#include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
#include "graph/def_types.h" | |||||
using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
using ge::parser::MERGE; | using ge::parser::MERGE; | ||||
@@ -30,7 +31,7 @@ Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr & | |||||
GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | |||||
const NodeDef *node = PtrToPtr<const Message, const NodeDef>(op_src); | |||||
domi::tensorflow::AttrValue attr_num; | domi::tensorflow::AttrValue attr_num; | ||||
if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | if (!(TensorFlowUtil::FindAttrValue(node, ATTR_NAME_N, attr_num))) { | ||||
GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_desc->GetName().c_str(), ATTR_NAME_N.c_str()); | ||||
@@ -42,10 +42,9 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); | ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return PARAM_INVALID); | return PARAM_INVALID); | ||||
// calculate size | // calculate size | ||||
int64_t tmp_dim = 0; | |||||
int64_t real_size = 1; | int64_t real_size = 1; | ||||
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
} | } | ||||
@@ -47,9 +47,8 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
return domi::PARAM_INVALID); | return domi::PARAM_INVALID); | ||||
// calculate size | // calculate size | ||||
int64_t real_size = 1; | int64_t real_size = 1; | ||||
int64_t tmp_dim = 0; | |||||
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
PARSER_INT64_MULCHECK(real_size, tmp_dim); | PARSER_INT64_MULCHECK(real_size, tmp_dim); | ||||
real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
@@ -271,9 +271,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(ge::TypeUtils::GetDataTypeLength(data_type, size_type), PARAM_INVALID, | ||||
"dataType no define size , parse ge_desc failed."); | "dataType no define size , parse ge_desc failed."); | ||||
// get size | // get size | ||||
int64_t tmp_dim = 0; | |||||
for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
int64_t tmp_dim = ge_desc.GetShape().GetDim(j); | |||||
// The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | ||||
// Here, special treatment is given to the two operators. | // Here, special treatment is given to the two operators. | ||||