Browse Source

!1759 add complex64 support, add data_type support and fix release error in aclgrphGenerateForOp

From: @lichun30
Reviewed-by: @xchu42
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
7e0c912bc3
2 changed files with 25 additions and 4 deletions
  1. +3
    -1
      ge/generator/ge_generator.cc
  2. +22
    -3
      ge/offline/single_op_parser.cc

+ 3
- 1
ge/generator/ge_generator.cc View File

@@ -452,7 +452,9 @@ Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &o

Status GeGenerator::Finalize() {
ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
if (impl_ == nullptr) {
return SUCCESS;
}
Status ret = impl_->graph_manager_.Finalize();
if (ret != SUCCESS) {
GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed.");


+ 22
- 3
ge/offline/single_op_parser.cc View File

@@ -89,6 +89,7 @@ map<string, DataType> kDataTypeDict = {
{"float", DT_FLOAT},
{"float32", DT_FLOAT},
{"double", DT_DOUBLE},
{"complex64", DT_COMPLEX64}
};

map<string, Format> kFormatDict = {
@@ -161,9 +162,13 @@ std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
}
} // namespace

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
bool AttrValueIsString(const Json &j, const string &key) {
try {
string tmp_str = j.at(key).get<string>();
return true;
} catch (Json::type_error &e) {
return false;
}
}

template<typename T>
@@ -177,6 +182,20 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
return it->second;
}

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
// when attr type is "data_type", we support two kinds of attr value.
// 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ...
// 2. value: 1, 3 ...
if (j.at(kKeyType).get<string>() == "data_type" && AttrValueIsString(j, kKeyValue)) {
string type_str = j.at(kKeyValue).get<string>();
DataType dtype = TypeUtils::SerialStringToDataType(type_str);
attr.value.SetValue<DataType>(dtype);
return;
}
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();


Loading…
Cancel
Save