Browse Source

single op add check json file

pull/532/head
wxl 4 years ago
parent
commit
950558cab5
2 changed files with 43 additions and 0 deletions
  1. +32
    -0
      ge/offline/single_op_parser.cc
  2. +11
    -0
      ge/offline/single_op_parser.h

+ 32
- 0
ge/offline/single_op_parser.cc View File

@@ -28,6 +28,7 @@
#include "framework/common/util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/operator_factory_impl.h"

using Json = nlohmann::json;
@@ -176,6 +177,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
JsonTensorVeriry tensor_verify_result;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange);
if (it != j.end()) {
@@ -187,6 +189,11 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
}
string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>();
tensor_verify_result.is_format_valid = TypeUtils::IsFormatValid(format_str);
tensor_verify_result.format = format_str;
tensor_verify_result.is_dtype_valid = TypeUtils::IsFormatValid(type_str);
tensor_verify_result.dtye = type_str;

desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
it = j.find(kKeyOriginFormat);
@@ -198,10 +205,12 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (tensor_name != j.end()) {
desc.name = tensor_name->get<string>();
}
tensor_verify_result.tensor_name = desc.name;
auto dynamic_input_name = j.find(kKeyDynamicInput);
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
json_op_valid_result_.emplace_back(tensor_verify_result);
}

void from_json(const Json &j, SingleOpAttr &attr) {
@@ -350,6 +359,22 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
return true;
}

Status ValidateSingleOpJson() {
for (const auto &r : json_op_valid_result_) {
if (!r.is_format_valid) {
string err_str = "json tensor format invalid.Tensor name is [" + r.tensor_name + "], format is " + r.format;
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, err_str);
return PARAM_INVALID;
}
if (!r.is_dtyep_valid) {
string err_str = "json tensor datatype invalid.Tensor name is [" + r.tensor_name + "], datatype is " + r.dtype;
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, err_str);
return PARAM_INVALID;
}
}
return SUCCESS;
}

std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
}
@@ -556,6 +581,13 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<Si
SingleOpDesc single_op_desc;
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
single_op_desc = single_op_json;

ret = ValidateSingleOpJson();
if (ret != SUCCESS) {
GELOGE(PARAM_INVALID, "user json file param is invalid!");
return ret;
}

if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
GELOGE(FAILED, "Update dynamic tensor name failed!");
return FAILED;


+ 11
- 0
ge/offline/single_op_parser.h View File

@@ -69,8 +69,17 @@ class SingleOpParser {
static Status ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list);

private:
struct JsonTensorVeriry {
bool is_format_valid = true;
bool is_dtype_valid = true;
std::string tensor_name;
std::string format;
std::string dtype;
}

static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj);
static bool Validate(const SingleOpDesc &op_desc);
static Status ValidateSingleOpJson();
static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type);
static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param);
static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc);
@@ -78,6 +87,8 @@ class SingleOpParser {
static Status SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc);

static std::vector<JsonTensorVeriry> json_op_valid_result_;
};
} // namespace ge



Loading…
Cancel
Save