From 3f3e0384c8c31fa67c54754341d9799b8c7de230 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 8 Jun 2021 10:46:19 +0800 Subject: [PATCH] add complex64 support, add data_type support and fix release error in aclgrphGenerateForOp --- ge/generator/ge_generator.cc | 4 +++- ge/offline/single_op_parser.cc | 46 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 8a94aa9b..1796d424 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -452,7 +452,9 @@ Status GeGenerator::Initialize(const map &options, OmgContext &o Status GeGenerator::Finalize() { ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize); - GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); + if (impl_ == nullptr) { + return SUCCESS; + } Status ret = impl_->graph_manager_.Finalize(); if (ret != SUCCESS) { GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed."); diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index ce9448d5..94d4d579 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -89,6 +89,7 @@ map kDataTypeDict = { {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE}, + {"complex64", DT_COMPLEX64} }; map kFormatDict = { @@ -133,6 +134,22 @@ map kFormatDict = { {"fractal_z_g", FORMAT_FRACTAL_Z_G} }; +map kDataTypeStringToEnum = { + {"DT_BOOL", DT_BOOL}, + {"DT_INT8", DT_INT8}, + {"DT_UINT8", DT_UINT8}, + {"DT_INT16", DT_INT16}, + {"DT_UINT16", DT_UINT16}, + {"DT_INT32", DT_INT32}, + {"DT_UINT32", DT_UINT32}, + {"DT_INT64", DT_INT64}, + {"DT_UINT64", DT_UINT64}, + {"DT_FLOAT16", DT_FLOAT16}, + {"DT_FLOAT", DT_FLOAT}, + {"DT_DOUBLE", DT_DOUBLE}, + {"DT_COMPLEX64", DT_COMPLEX64} +}; + std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { std::stringstream file_name_ss; file_name_ss << index; @@ -161,9 +178,13 @@ std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { } } // namespace -template -void SetAttrValue(const Json &j, SingleOpAttr &attr) { - attr.value.SetValue(j.at(kKeyValue).get()); +bool AttrValueIsString(const Json &j, const string &key) { + try { + string tmp_str = j.at(key).get(); + return true; + } catch (Json::type_error &e) { + return false; + } } template @@ -177,6 +198,25 @@ T GetValue(const map &dict, string &key, T default_val) { return it->second; } +template +void SetAttrValue(const Json &j, SingleOpAttr &attr) { + // when attr type is "data_type", we support two kinds of attr value. + // 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ... + // 2. value: 1, 3 ... + if (j.at(kKeyType).get() == "data_type") { + if (AttrValueIsString(j, kKeyValue)) { + string type_str = j.at(kKeyValue).get(); + DataType dtype = GetValue(kDataTypeStringToEnum, type_str, DT_UNDEFINED); + attr.value.SetValue(dtype); + return; + } else { + attr.value.SetValue(j.at(kKeyValue).get()); + return; + } + } + attr.value.SetValue(j.at(kKeyValue).get()); +} + void from_json(const Json &j, SingleOpTensorDesc &desc) { bool is_tensor_valid = true; desc.dims = j.at(kKeyShape).get>();