diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 8a94aa9b..1796d424 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -452,7 +452,9 @@ Status GeGenerator::Initialize(const map &options, OmgContext &o Status GeGenerator::Finalize() { ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize); - GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); + if (impl_ == nullptr) { + return SUCCESS; + } Status ret = impl_->graph_manager_.Finalize(); if (ret != SUCCESS) { GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed."); diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index ce9448d5..5a8ca923 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -89,6 +89,7 @@ map kDataTypeDict = { {"float", DT_FLOAT}, {"float32", DT_FLOAT}, {"double", DT_DOUBLE}, + {"complex64", DT_COMPLEX64} }; map kFormatDict = { @@ -161,9 +162,13 @@ std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { } } // namespace -template -void SetAttrValue(const Json &j, SingleOpAttr &attr) { - attr.value.SetValue(j.at(kKeyValue).get()); +bool AttrValueIsString(const Json &j, const string &key) { + try { + string tmp_str = j.at(key).get(); + return true; + } catch (Json::type_error &e) { + return false; + } } template @@ -177,6 +182,20 @@ T GetValue(const map &dict, string &key, T default_val) { return it->second; } +template +void SetAttrValue(const Json &j, SingleOpAttr &attr) { + // when attr type is "data_type", we support two kinds of attr value. + // 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ... + // 2. value: 1, 3 ... + if (j.at(kKeyType).get() == "data_type" && AttrValueIsString(j, kKeyValue)) { + string type_str = j.at(kKeyValue).get(); + DataType dtype = TypeUtils::SerialStringToDataType(type_str); + attr.value.SetValue(dtype); + return; + } + attr.value.SetValue(j.at(kKeyValue).get()); +} + void from_json(const Json &j, SingleOpTensorDesc &desc) { bool is_tensor_valid = true; desc.dims = j.at(kKeyShape).get>();