From 6907537a214141a12d5b7020fb14952f606378e7 Mon Sep 17 00:00:00 2001 From: wxl Date: Sat, 28 Nov 2020 14:11:32 +0800 Subject: [PATCH] bugfix --- ge/ir_build/atc_ir_common.cc | 5 +++++ ge/offline/single_op_parser.cc | 29 +++++++++++++++++++++++++++++ ge/offline/single_op_parser.h | 10 ++++++++++ 3 files changed, 44 insertions(+) diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc index eaff928b..3ac6928b 100755 --- a/ge/ir_build/atc_ir_common.cc +++ b/ge/ir_build/atc_ir_common.cc @@ -19,6 +19,7 @@ #include "framework/common/string_util.h" #include "framework/common/types.h" #include "framework/common/util.h" +#include "graph/utils/type_utils.h" using std::pair; using std::string; @@ -106,6 +107,10 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> bool CheckDynamicImagesizeInputShapeValid(unordered_map> shape_map, const std::string input_format, std::string &dynamic_image_size) { + if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(const_cast(input_format)) { + GELOGE(ge::PARAM_INVALID, "user input format [%s] is nod found!", input_format.c_str()); + return false; + } int32_t size = 0; for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { vector shape = iter->second; diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index d4b9c1c9..68621942 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -28,6 +28,7 @@ #include "framework/common/util.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" #include "graph/operator_factory_impl.h" using Json = nlohmann::json; @@ -176,6 +177,7 @@ T GetValue(const map &dict, string &key, T default_val) { } void from_json(const Json &j, SingleOpTensorDesc &desc) { + JsonTensorVeriry tensor_verify_result; desc.dims = j.at(kKeyShape).get>(); auto it = j.find(kKeyShapeRange); if (it != j.end()) { @@ -187,6 +189,11 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { } string format_str = j.at(kKeyFormat).get(); string type_str = j.at(kKeyType).get(); + tensor_verify_result.is_format_valid = TypeUtils::IsFormatValid(format_str); + tensor_verify_result.format = format_str; + tensor_verify_result.is_dtype_valid = TypeUtils::IsFormatValid(type_str); + tensor_verify_result.dtye = type_str; + desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); it = j.find(kKeyOriginFormat); @@ -202,6 +209,7 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { if (dynamic_input_name != j.end()) { desc.dynamic_input_name = dynamic_input_name->get(); } + json_op_valid_result_.emplace_back(tensor_verify_result); } void from_json(const Json &j, SingleOpAttr &attr) { @@ -350,6 +358,20 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { return true; } +Status ValidateSingleOpJson() { + for (const auto &r : json_op_valid_result_) { + if (!r.is_format_valid) { + GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.format); + return PARAM_INVALID; + } + if (!r.is_dtyep_valid) { + GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.dtype); + return PARAM_INVALID; + } + } + return SUCCESS; +} + std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); } @@ -556,6 +578,13 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list); private: + struct JsonTensorVeriry { + bool is_format_valid = true; + bool is_dtype_valid = true; + std::string format; + std::string dtype; + } + static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); static bool Validate(const SingleOpDesc &op_desc); + static Status ValidateSingleOpJson(); static std::unique_ptr CreateOpDesc(const std::string &op_type); static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); static Status UpdateDynamicTensorName(std::vector &desc); @@ -78,6 +86,8 @@ class SingleOpParser { static Status SetShapeRange(const std::string &op_name, const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); + + static std::vector json_op_valid_result_; }; } // namespace ge