Browse Source

bugfix

pull/451/head
wxl 4 years ago
parent
commit
6907537a21
3 changed files with 44 additions and 0 deletions
  1. +5
    -0
      ge/ir_build/atc_ir_common.cc
  2. +29
    -0
      ge/offline/single_op_parser.cc
  3. +10
    -0
      ge/offline/single_op_parser.h

+ 5
- 0
ge/ir_build/atc_ir_common.cc View File

@@ -19,6 +19,7 @@
#include "framework/common/string_util.h"
#include "framework/common/types.h"
#include "framework/common/util.h"
#include "graph/utils/type_utils.h"

using std::pair;
using std::string;
@@ -106,6 +107,10 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map<string, vector<int64_t>>

bool CheckDynamicImagesizeInputShapeValid(unordered_map<string, vector<int64_t>> shape_map,
const std::string input_format, std::string &dynamic_image_size) {
if (!input_format.empty() && !ge::TypeUtils::IsFormatValid(const_cast<string &>(input_format)) {
GELOGE(ge::PARAM_INVALID, "user input format [%s] is nod found!", input_format.c_str());
return false;
}
int32_t size = 0;
for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) {
vector<int64_t> shape = iter->second;


+ 29
- 0
ge/offline/single_op_parser.cc View File

@@ -28,6 +28,7 @@
#include "framework/common/util.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/operator_factory_impl.h"

using Json = nlohmann::json;
@@ -176,6 +177,7 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
JsonTensorVeriry tensor_verify_result;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
auto it = j.find(kKeyShapeRange);
if (it != j.end()) {
@@ -187,6 +189,11 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
}
string format_str = j.at(kKeyFormat).get<string>();
string type_str = j.at(kKeyType).get<string>();
tensor_verify_result.is_format_valid = TypeUtils::IsFormatValid(format_str);
tensor_verify_result.format = format_str;
tensor_verify_result.is_dtype_valid = TypeUtils::IsFormatValid(type_str);
tensor_verify_result.dtye = type_str;

desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
it = j.find(kKeyOriginFormat);
@@ -202,6 +209,7 @@ void from_json(const Json &j, SingleOpTensorDesc &desc) {
if (dynamic_input_name != j.end()) {
desc.dynamic_input_name = dynamic_input_name->get<string>();
}
json_op_valid_result_.emplace_back(tensor_verify_result);
}

void from_json(const Json &j, SingleOpAttr &attr) {
@@ -350,6 +358,20 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
return true;
}

Status ValidateSingleOpJson() {
for (const auto &r : json_op_valid_result_) {
if (!r.is_format_valid) {
GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.format);
return PARAM_INVALID;
}
if (!r.is_dtyep_valid) {
GELOGE(PARAM_INVALID, "input format:%s is not defined!", r.dtype);
return PARAM_INVALID;
}
}
return SUCCESS;
}

std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
}
@@ -556,6 +578,13 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<Si
SingleOpDesc single_op_desc;
GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
single_op_desc = single_op_json;

ret = ValidateSingleOpJson();
if (ret != SUCCESS) {
GELOGE(PARAM_INVALID, "user json file param is invalid!");
return ret;
}

if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
GELOGE(FAILED, "Update dynamic tensor name failed!");
return FAILED;


+ 10
- 0
ge/offline/single_op_parser.h View File

@@ -69,8 +69,16 @@ class SingleOpParser {
static Status ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list);

private:
struct JsonTensorVeriry {
bool is_format_valid = true;
bool is_dtype_valid = true;
std::string format;
std::string dtype;
}

static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj);
static bool Validate(const SingleOpDesc &op_desc);
static Status ValidateSingleOpJson();
static std::unique_ptr<OpDesc> CreateOpDesc(const std::string &op_type);
static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param);
static Status UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc);
@@ -78,6 +86,8 @@ class SingleOpParser {
static Status SetShapeRange(const std::string &op_name,
const SingleOpTensorDesc &tensor_desc,
GeTensorDesc &ge_tensor_desc);

static std::vector<JsonTensorVeriry> json_op_valid_result_;
};
} // namespace ge



Loading…
Cancel
Save