Browse Source

add complex64 support, add data_type support and fix release error in aclgrphGenerateForOp

tags/v1.3.0
lichun 4 years ago
parent
commit
3f3e0384c8
2 changed files with 46 additions and 4 deletions
  1. +3
    -1
      ge/generator/ge_generator.cc
  2. +43
    -3
      ge/offline/single_op_parser.cc

+ 3
- 1
ge/generator/ge_generator.cc View File

@@ -452,7 +452,9 @@ Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &o

Status GeGenerator::Finalize() {
ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
if (impl_ == nullptr) {
return SUCCESS;
}
Status ret = impl_->graph_manager_.Finalize();
if (ret != SUCCESS) {
GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed.");


+ 43
- 3
ge/offline/single_op_parser.cc View File

@@ -89,6 +89,7 @@ map<string, DataType> kDataTypeDict = {
{"float", DT_FLOAT},
{"float32", DT_FLOAT},
{"double", DT_DOUBLE},
{"complex64", DT_COMPLEX64}
};

map<string, Format> kFormatDict = {
@@ -133,6 +134,22 @@ map<string, Format> kFormatDict = {
{"fractal_z_g", FORMAT_FRACTAL_Z_G}
};

map<string, DataType> kDataTypeStringToEnum = {
{"DT_BOOL", DT_BOOL},
{"DT_INT8", DT_INT8},
{"DT_UINT8", DT_UINT8},
{"DT_INT16", DT_INT16},
{"DT_UINT16", DT_UINT16},
{"DT_INT32", DT_INT32},
{"DT_UINT32", DT_UINT32},
{"DT_INT64", DT_INT64},
{"DT_UINT64", DT_UINT64},
{"DT_FLOAT16", DT_FLOAT16},
{"DT_FLOAT", DT_FLOAT},
{"DT_DOUBLE", DT_DOUBLE},
{"DT_COMPLEX64", DT_COMPLEX64}
};

std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
std::stringstream file_name_ss;
file_name_ss << index;
@@ -161,9 +178,13 @@ std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
}
} // namespace

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
bool AttrValueIsString(const Json &j, const string &key) {
try {
string tmp_str = j.at(key).get<string>();
return true;
} catch (Json::type_error &e) {
return false;
}
}

template<typename T>
@@ -177,6 +198,25 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) {
return it->second;
}

template<typename T>
void SetAttrValue(const Json &j, SingleOpAttr &attr) {
// when attr type is "data_type", we support two kinds of attr value.
// 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ...
// 2. value: 1, 3 ...
if (j.at(kKeyType).get<string>() == "data_type") {
if (AttrValueIsString(j, kKeyValue)) {
string type_str = j.at(kKeyValue).get<string>();
DataType dtype = GetValue(kDataTypeStringToEnum, type_str, DT_UNDEFINED);
attr.value.SetValue<DataType>(dtype);
return;
} else {
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
return;
}
}
attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
}

void from_json(const Json &j, SingleOpTensorDesc &desc) {
bool is_tensor_valid = true;
desc.dims = j.at(kKeyShape).get<vector<int64_t>>();


Loading…
Cancel
Save