|
|
@@ -89,6 +89,7 @@ map<string, DataType> kDataTypeDict = { |
|
|
|
{"float", DT_FLOAT}, |
|
|
|
{"float32", DT_FLOAT}, |
|
|
|
{"double", DT_DOUBLE}, |
|
|
|
{"complex64", DT_COMPLEX64} |
|
|
|
}; |
|
|
|
|
|
|
|
map<string, Format> kFormatDict = { |
|
|
@@ -133,6 +134,22 @@ map<string, Format> kFormatDict = { |
|
|
|
{"fractal_z_g", FORMAT_FRACTAL_Z_G} |
|
|
|
}; |
|
|
|
|
|
|
|
map<string, DataType> kDataTypeStringToEnum = { |
|
|
|
{"DT_BOOL", DT_BOOL}, |
|
|
|
{"DT_INT8", DT_INT8}, |
|
|
|
{"DT_UINT8", DT_UINT8}, |
|
|
|
{"DT_INT16", DT_INT16}, |
|
|
|
{"DT_UINT16", DT_UINT16}, |
|
|
|
{"DT_INT32", DT_INT32}, |
|
|
|
{"DT_UINT32", DT_UINT32}, |
|
|
|
{"DT_INT64", DT_INT64}, |
|
|
|
{"DT_UINT64", DT_UINT64}, |
|
|
|
{"DT_FLOAT16", DT_FLOAT16}, |
|
|
|
{"DT_FLOAT", DT_FLOAT}, |
|
|
|
{"DT_DOUBLE", DT_DOUBLE}, |
|
|
|
{"DT_COMPLEX64", DT_COMPLEX64} |
|
|
|
}; |
|
|
|
|
|
|
|
std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { |
|
|
|
std::stringstream file_name_ss; |
|
|
|
file_name_ss << index; |
|
|
@@ -161,9 +178,13 @@ std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) { |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
template<typename T> |
|
|
|
void SetAttrValue(const Json &j, SingleOpAttr &attr) { |
|
|
|
attr.value.SetValue<T>(j.at(kKeyValue).get<T>()); |
|
|
|
bool AttrValueIsString(const Json &j, const string &key) { |
|
|
|
try { |
|
|
|
string tmp_str = j.at(key).get<string>(); |
|
|
|
return true; |
|
|
|
} catch (Json::type_error &e) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template<typename T> |
|
|
@@ -177,6 +198,25 @@ T GetValue(const map<string, T> &dict, string &key, T default_val) { |
|
|
|
return it->second; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename T> |
|
|
|
void SetAttrValue(const Json &j, SingleOpAttr &attr) { |
|
|
|
// when attr type is "data_type", we support two kinds of attr value. |
|
|
|
// 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ... |
|
|
|
// 2. value: 1, 3 ... |
|
|
|
if (j.at(kKeyType).get<string>() == "data_type") { |
|
|
|
if (AttrValueIsString(j, kKeyValue)) { |
|
|
|
string type_str = j.at(kKeyValue).get<string>(); |
|
|
|
DataType dtype = GetValue(kDataTypeStringToEnum, type_str, DT_UNDEFINED); |
|
|
|
attr.value.SetValue<DataType>(dtype); |
|
|
|
return; |
|
|
|
} else { |
|
|
|
attr.value.SetValue<T>(j.at(kKeyValue).get<T>()); |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
attr.value.SetValue<T>(j.at(kKeyValue).get<T>()); |
|
|
|
} |
|
|
|
|
|
|
|
void from_json(const Json &j, SingleOpTensorDesc &desc) { |
|
|
|
bool is_tensor_valid = true; |
|
|
|
desc.dims = j.at(kKeyShape).get<vector<int64_t>>(); |
|
|
|