From 950558cab5b61913d94b94bdc8ef06370345ecc1 Mon Sep 17 00:00:00 2001 From: wxl Date: Sat, 5 Dec 2020 11:58:11 +0800 Subject: [PATCH] single op add check json file --- ge/offline/single_op_parser.cc | 32 ++++++++++++++++++++++++++++++++ ge/offline/single_op_parser.h | 11 +++++++++++ 2 files changed, 43 insertions(+) diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index d4b9c1c9..767634b5 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -28,6 +28,7 @@ #include "framework/common/util.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/op_desc_utils.h" +#include "graph/utils/type_utils.h" #include "graph/operator_factory_impl.h" using Json = nlohmann::json; @@ -176,6 +177,7 @@ T GetValue(const map &dict, string &key, T default_val) { } void from_json(const Json &j, SingleOpTensorDesc &desc) { + JsonTensorVeriry tensor_verify_result; desc.dims = j.at(kKeyShape).get>(); auto it = j.find(kKeyShapeRange); if (it != j.end()) { @@ -187,6 +189,11 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { } string format_str = j.at(kKeyFormat).get(); string type_str = j.at(kKeyType).get(); + tensor_verify_result.is_format_valid = TypeUtils::IsFormatValid(format_str); + tensor_verify_result.format = format_str; + tensor_verify_result.is_dtype_valid = TypeUtils::IsFormatValid(type_str); + tensor_verify_result.dtye = type_str; + desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); it = j.find(kKeyOriginFormat); @@ -198,10 +205,12 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { if (tensor_name != j.end()) { desc.name = tensor_name->get(); } + tensor_verify_result.tensor_name = desc.name; auto dynamic_input_name = j.find(kKeyDynamicInput); if (dynamic_input_name != j.end()) { desc.dynamic_input_name = dynamic_input_name->get(); } + json_op_valid_result_.emplace_back(tensor_verify_result); } void from_json(const Json &j, SingleOpAttr &attr) { @@ -350,6 +359,22 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { return true; } +Status ValidateSingleOpJson() { + for (const auto &r : json_op_valid_result_) { + if (!r.is_format_valid) { + string err_str = "json tensor format invalid.Tensor name is [" + r.tensor_name + "], format is " + r.format; + GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, err_str); + return PARAM_INVALID; + } + if (!r.is_dtyep_valid) { + string err_str = "json tensor datatype invalid.Tensor name is [" + r.tensor_name + "], datatype is " + r.dtype; + GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, err_str); + return PARAM_INVALID; + } + } + return SUCCESS; +} + std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { return std::unique_ptr(new(std::nothrow) OpDesc(op_type, op_type)); } @@ -556,6 +581,13 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &op_list); private: + struct JsonTensorVeriry { + bool is_format_valid = true; + bool is_dtype_valid = true; + std::string tensor_name; + std::string format; + std::string dtype; + } + static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); static bool Validate(const SingleOpDesc &op_desc); + static Status ValidateSingleOpJson(); static std::unique_ptr CreateOpDesc(const std::string &op_type); static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); static Status UpdateDynamicTensorName(std::vector &desc); @@ -78,6 +87,8 @@ class SingleOpParser { static Status SetShapeRange(const std::string &op_name, const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); + + static std::vector json_op_valid_result_; }; } // namespace ge