|
|
@@ -28,6 +28,7 @@ |
|
|
|
#include "framework/common/util.h" |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
#include "graph/utils/op_desc_utils.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "graph/operator_factory_impl.h" |
|
|
|
|
|
|
|
using Json = nlohmann::json; |
|
|
@@ -176,6 +177,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) { |
|
|
|
} |
|
|
|
|
|
|
|
void from_json(const Json &j, SingleOpTensorDesc &desc) { |
|
|
|
JsonTensorVeriry tensor_verify_result; |
|
|
|
desc.dims = j.at(kKeyShape).get<vector<int64_t>>(); |
|
|
|
auto it = j.find(kKeyShapeRange); |
|
|
|
if (it != j.end()) { |
|
|
@@ -187,6 +189,11 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { |
|
|
|
} |
|
|
|
string format_str = j.at(kKeyFormat).get<string>(); |
|
|
|
string type_str = j.at(kKeyType).get<string>(); |
|
|
|
tensor_verify_result.is_format_valid = TypeUtils::IsFormatValid(format_str); |
|
|
|
tensor_verify_result.format = format_str; |
|
|
|
tensor_verify_result.is_dtype_valid = TypeUtils::IsFormatValid(type_str); |
|
|
|
tensor_verify_result.dtye = type_str; |
|
|
|
|
|
|
|
desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); |
|
|
|
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED); |
|
|
|
it = j.find(kKeyOriginFormat); |
|
|
@@ -202,6 +209,7 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) { |
|
|
|
if (dynamic_input_name != j.end()) { |
|
|
|
desc.dynamic_input_name = dynamic_input_name->get<string>(); |
|
|
|
} |
|
|
|
json_op_valid_result_.emplace_back(tensor_verify_result); |
|
|
|
} |
|
|
|
|
|
|
|
void from_json(const Json &j, SingleOpAttr &attr) { |
|
|
@@ -350,6 +358,20 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
Status ValidateSingleOpJson() { |
|
|
|
for (const auto &r : json_op_valid_result_) { |
|
|
|
if (!r.is_format_valid) { |
|
|
|
GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.format); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (!r.is_dtyep_valid) { |
|
|
|
GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.dtype); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) { |
|
|
|
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type)); |
|
|
|
} |
|
|
@@ -556,6 +578,13 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<Si |
|
|
|
SingleOpDesc single_op_desc; |
|
|
|
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str()); |
|
|
|
single_op_desc = single_op_json; |
|
|
|
|
|
|
|
ret = ValidateSingleOpJson(); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "user json file param is invalid!"); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Update dynamic tensor name failed!"); |
|
|
|
return FAILED; |
|
|
|