Browse Source

Pre Merge pull request !580 from 万学磊/development

pull/580/MERGE
万学磊 Gitee 4 years ago
parent
commit
fdbb9aeb54
2 changed files with 28 additions and 0 deletions
  1. +22
    -0
      ge/offline/single_op_parser.cc
  2. +6
    -0
      ge/offline/single_op_parser.h

+ 22
- 0
ge/offline/single_op_parser.cc View File

@@ -26,6 +26,8 @@
#include "common/util/error_manager/error_manager.h" #include "common/util/error_manager/error_manager.h"
#include "common/ge_inner_error_codes.h" #include "common/ge_inner_error_codes.h"
#include "framework/common/util.h" #include "framework/common/util.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h" #include "graph/utils/op_desc_utils.h"
#include "graph/operator_factory_impl.h" #include "graph/operator_factory_impl.h"
@@ -52,6 +54,7 @@ constexpr char const *kKeyOriginFormat = "origin_format";
constexpr char const *kFileSuffix = ".om"; constexpr char const *kFileSuffix = ".om";
constexpr char const *kKeyDynamicInput = "dynamic_input"; constexpr char const *kKeyDynamicInput = "dynamic_input";
constexpr char const *kKeyDynamicOutput = "dynamic_output"; constexpr char const *kKeyDynamicOutput = "dynamic_output";
constexpr char const *kSingleOpTensorDescValid = "_ge_single_op_tensor_desc_valid";
constexpr int kDumpJsonIndent = 2; constexpr int kDumpJsonIndent = 2;
constexpr int kShapeRangePairSize = 2; constexpr int kShapeRangePairSize = 2;
constexpr int kShapeRangeLow = 0; constexpr int kShapeRangeLow = 0;
@@ -176,6 +179,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
} }


void from_json(const Json &j, SingleOpTensorDesc &desc) { void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>(); desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange); auto it = j.find(kKeyShapeRange);
if (it != j.end()) { if (it != j.end()) {
@@ -187,11 +191,14 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
} }
string format_str = j.at(kKeyFormat).get<string>(); string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>(); string type_str = j.at(kKeyType).get<string>();
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str);
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str);
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
it = j.find(kKeyOriginFormat); it = j.find(kKeyOriginFormat);
if (it != j.end()) { if (it != j.end()) {
string origin_format_str = j.at(kKeyOriginFormat).get<string>(); string origin_format_str = j.at(kKeyOriginFormat).get<string>();
is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str);
desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED); desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
} }
auto tensor_name = j.find(kKeyName); auto tensor_name = j.find(kKeyName);
@@ -202,6 +209,9 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (dynamic_input_name != j.end()) { if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>(); desc.dynamic_input_name = dynamic_input_name->get<string>();
} }
if (!is_tensor_valid) {
desc.SetValidFlag(is_tensor_valid);
}
} }


void from_json(const Json &j, SingleOpAttr &attr) { void from_json(const Json &j, SingleOpAttr &attr) {
@@ -305,6 +315,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {


int index = 0; int index = 0;
for (auto &tensor_desc : op_desc.input_desc) { for (auto &tensor_desc : op_desc.input_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"intput", "datatype or format", std::to_string(index)});
GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index);
return false;
}
if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
(tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
@@ -317,6 +333,12 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {


index = 0; index = 0;
for (auto &tensor_desc : op_desc.output_desc) { for (auto &tensor_desc : op_desc.output_desc) {
if (!tensor_desc.GetValidFlag()) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)});
GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index);
return false;
}
if (tensor_desc.type == DT_UNDEFINED) { if (tensor_desc.type == DT_UNDEFINED) {
ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"}, ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "type", "index"},
{"output", "datatype", std::to_string(index)}); {"output", "datatype", std::to_string(index)});


+ 6
- 0
ge/offline/single_op_parser.h View File

@@ -28,6 +28,10 @@


namespace ge { namespace ge {
struct SingleOpTensorDesc { struct SingleOpTensorDesc {
public:
bool GetValidFlag() const { return is_valid_; }
void SetValidFlag(bool is_valid) { is_valid_ = is_valid; }
public:
std::string name; std::string name;
std::vector<int64_t> dims; std::vector<int64_t> dims;
std::vector<int64_t> ori_dims; std::vector<int64_t> ori_dims;
@@ -36,6 +40,8 @@ struct SingleOpTensorDesc {
ge::Format ori_format = ge::FORMAT_RESERVED; ge::Format ori_format = ge::FORMAT_RESERVED;
ge::DataType type = ge::DT_UNDEFINED; ge::DataType type = ge::DT_UNDEFINED;
std::string dynamic_input_name; std::string dynamic_input_name;
private:
bool is_valid_ = true;
}; };


struct SingleOpAttr { struct SingleOpAttr {


Loading…
Cancel
Save