You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graphfiles.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
""" Utility functions to manage graph files.
"""
  3. def loadCT(filename):
  4. """load data from .ct file.
  5. Notes
  6. ------
  7. a typical example of data in .ct is like this:
  8. 3 2 <- number of nodes and edges
  9. 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
  10. 0.0000 0.0000 0.0000 C
  11. 0.0000 0.0000 0.0000 O
  12. 1 3 1 1 <- each line describes an edge : to, from,?, label
  13. 2 3 1 1
  14. """
  15. import networkx as nx
  16. from os.path import basename
  17. g = nx.Graph()
  18. with open(filename) as f:
  19. content = f.read().splitlines()
  20. g = nx.Graph(
  21. name = str(content[0]),
  22. filename = basename(filename)) # set name of the graph
  23. tmp = content[1].split(" ")
  24. if tmp[0] == '':
  25. nb_nodes = int(tmp[1]) # number of the nodes
  26. nb_edges = int(tmp[2]) # number of the edges
  27. else:
  28. nb_nodes = int(tmp[0])
  29. nb_edges = int(tmp[1])
  30. # patch for compatibility : label will be removed later
  31. for i in range(0, nb_nodes):
  32. tmp = content[i + 2].split(" ")
  33. tmp = [x for x in tmp if x != '']
  34. g.add_node(i, atom=tmp[3], label=tmp[3])
  35. for i in range(0, nb_edges):
  36. tmp = content[i + g.number_of_nodes() + 2].split(" ")
  37. tmp = [x for x in tmp if x != '']
  38. g.add_edge(
  39. int(tmp[0]) - 1,
  40. int(tmp[1]) - 1,
  41. bond_type=tmp[3].strip(),
  42. label=tmp[3].strip())
  43. # for i in range(0, nb_edges):
  44. # tmp = content[i + g.number_of_nodes() + 2]
  45. # tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
  46. # g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  47. # bond_type=tmp[3].strip(), label=tmp[3].strip())
  48. return g
  49. def loadGXL(filename):
  50. from os.path import basename
  51. import networkx as nx
  52. import xml.etree.ElementTree as ET
  53. tree = ET.parse(filename)
  54. root = tree.getroot()
  55. index = 0
  56. g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
  57. dic = {} # used to retrieve incident nodes of edges
  58. for node in root.iter('node'):
  59. dic[node.attrib['id']] = index
  60. labels = {}
  61. for attr in node.iter('attr'):
  62. labels[attr.attrib['name']] = attr[0].text
  63. if 'chem' in labels:
  64. labels['label'] = labels['chem']
  65. g.add_node(index, **labels)
  66. index += 1
  67. for edge in root.iter('edge'):
  68. labels = {}
  69. for attr in edge.iter('attr'):
  70. labels[attr.attrib['name']] = attr[0].text
  71. if 'valence' in labels:
  72. labels['label'] = labels['valence']
  73. g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
  74. return g
  75. def saveGXL(graph, filename, method='gedlib'):
  76. if method == 'benoit':
  77. import xml.etree.ElementTree as ET
  78. root_node = ET.Element('gxl')
  79. attr = dict()
  80. attr['id'] = str(graph.graph['name'])
  81. attr['edgeids'] = 'true'
  82. attr['edgemode'] = 'undirected'
  83. graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
  84. for v in graph:
  85. current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
  86. for attr in graph.nodes[v].keys():
  87. cur_attr = ET.SubElement(
  88. current_node, 'attr', attrib={'name': attr})
  89. cur_value = ET.SubElement(cur_attr,
  90. graph.nodes[v][attr].__class__.__name__)
  91. cur_value.text = graph.nodes[v][attr]
  92. for v1 in graph:
  93. for v2 in graph[v1]:
  94. if (v1 < v2): # Non oriented graphs
  95. cur_edge = ET.SubElement(
  96. graph_node,
  97. 'edge',
  98. attrib={
  99. 'from': str(v1),
  100. 'to': str(v2)
  101. })
  102. for attr in graph[v1][v2].keys():
  103. cur_attr = ET.SubElement(
  104. cur_edge, 'attr', attrib={'name': attr})
  105. cur_value = ET.SubElement(
  106. cur_attr, graph[v1][v2][attr].__class__.__name__)
  107. cur_value.text = str(graph[v1][v2][attr])
  108. tree = ET.ElementTree(root_node)
  109. tree.write(filename)
  110. elif method == 'gedlib':
  111. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  112. # pass
  113. gxl_file = open(filename, 'w')
  114. gxl_file.write("<?xml version=\"1.0\"?>\n")
  115. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  116. gxl_file.write("<gxl>\n")
  117. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
  118. for v, attrs in graph.nodes(data=True):
  119. gxl_file.write("<node id=\"_" + str(v) + "\">\n")
  120. gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['atom']) + "</int></attr>\n")
  121. gxl_file.write("</node>\n")
  122. for v1, v2, attrs in graph.edges(data=True):
  123. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">\n")
  124. # gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['bond_type']) + "</int></attr>\n")
  125. gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>\n")
  126. gxl_file.write("</edge>\n")
  127. gxl_file.write("</graph>\n")
  128. gxl_file.write("</gxl>\n")
  129. gxl_file.close()
  130. def loadSDF(filename):
  131. """load data from structured data file (.sdf file).
  132. Notes
  133. ------
  134. A SDF file contains a group of molecules, represented in the similar way as in MOL format.
  135. Check http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx, 2018 for detailed structure.
  136. """
  137. import networkx as nx
  138. from os.path import basename
  139. from tqdm import tqdm
  140. import sys
  141. data = []
  142. with open(filename) as f:
  143. content = f.read().splitlines()
  144. index = 0
  145. pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
  146. while index < len(content):
  147. index_old = index
  148. g = nx.Graph(name=content[index].strip()) # set name of the graph
  149. tmp = content[index + 3]
  150. nb_nodes = int(tmp[:3]) # number of the nodes
  151. nb_edges = int(tmp[3:6]) # number of the edges
  152. for i in range(0, nb_nodes):
  153. tmp = content[i + index + 4]
  154. g.add_node(i, atom=tmp[31:34].strip())
  155. for i in range(0, nb_edges):
  156. tmp = content[i + index + g.number_of_nodes() + 4]
  157. tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
  158. g.add_edge(
  159. int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
  160. data.append(g)
  161. index += 4 + g.number_of_nodes() + g.number_of_edges()
  162. while content[index].strip() != '$$$$': # seperator
  163. index += 1
  164. index += 1
  165. pbar.update(index - index_old)
  166. pbar.update(1)
  167. pbar.close()
  168. return data
def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Parameters
    ----------
    filename : str
        Path to the .mat file.
    extra_params : dict
        Must contain 'am_sp_al_nl_el', a list of indices locating fields
        inside each MATLAB struct entry. Positions 0, 3 and 4 are used
        here as adjacency matrix, node labels and edge labels; position 1
        selects the branch below. The remaining positions are not read
        by this function -- TODO confirm their meaning against the data.

    Returns
    -------
    (list of networkx.Graph, list)
        The graphs and their class labels.

    Notes
    -----
    A MAT file contains a struct array containing graphs, and a column
    vector lx containing a class label for each graph. Check README in
    downloadable file in http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/,
    2018 for detailed structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx
    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l':  # keys starting with 'l' hold the class-label vector
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':  # skip MATLAB metadata keys like '__header__'
            # if adjacency matrix is not compressed / edge label exists
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    # item[order[3]] is the node-label field; the [0][0][0]
                    # chain unwraps scipy's nested MATLAB cell arrays.
                    nl = np.transpose(item[order[3]][0][0][0])  # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge label
                    for edge in el:
                        # Edge rows look like (from, to, label), 1-based ids.
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                # Import kept from the original; csc_matrix itself is not
                # referenced, but scipy builds the sparse matrices read below.
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    # nonzero() yields the (row, col) indices of every edge;
                    # both directions appear but Graph stores each edge once.
                    index_no0 = sam.nonzero()
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    # NOTE(review): y is unbound if no key starts with 'l' -- the file is
    # assumed to always contain a label vector; verify against the dataset.
    return data, y
  234. def loadTXT(dirname_dataset):
  235. """Load graph data from a .txt file.
  236. Notes
  237. ------
  238. The graph data is loaded from separate files.
  239. Check README in downloadable file http://tiny.cc/PK_MLJ_data, 2018 for detailed structure.
  240. """
  241. import numpy as np
  242. import networkx as nx
  243. from os import listdir
  244. from os.path import dirname
  245. # load data file names
  246. for name in listdir(dirname_dataset):
  247. if '_A' in name:
  248. fam = dirname_dataset + '/' + name
  249. elif '_graph_indicator' in name:
  250. fgi = dirname_dataset + '/' + name
  251. elif '_graph_labels' in name:
  252. fgl = dirname_dataset + '/' + name
  253. elif '_node_labels' in name:
  254. fnl = dirname_dataset + '/' + name
  255. elif '_edge_labels' in name:
  256. fel = dirname_dataset + '/' + name
  257. elif '_edge_attributes' in name:
  258. fea = dirname_dataset + '/' + name
  259. elif '_node_attributes' in name:
  260. fna = dirname_dataset + '/' + name
  261. elif '_graph_attributes' in name:
  262. fga = dirname_dataset + '/' + name
  263. # this is supposed to be the node attrs, make sure to put this as the last 'elif'
  264. elif '_attributes' in name:
  265. fna = dirname_dataset + '/' + name
  266. content_gi = open(fgi).read().splitlines() # graph indicator
  267. content_am = open(fam).read().splitlines() # adjacency matrix
  268. content_gl = open(fgl).read().splitlines() # lass labels
  269. # create graphs and add nodes
  270. data = [nx.Graph(name=i) for i in range(0, len(content_gl))]
  271. if 'fnl' in locals():
  272. content_nl = open(fnl).read().splitlines() # node labels
  273. for i, line in enumerate(content_gi):
  274. # transfer to int first in case of unexpected blanks
  275. data[int(line) - 1].add_node(i, atom=str(int(content_nl[i])))
  276. else:
  277. for i, line in enumerate(content_gi):
  278. data[int(line) - 1].add_node(i)
  279. # add edges
  280. for line in content_am:
  281. tmp = line.split(',')
  282. n1 = int(tmp[0]) - 1
  283. n2 = int(tmp[1]) - 1
  284. # ignore edge weight here.
  285. g = int(content_gi[n1]) - 1
  286. data[g].add_edge(n1, n2)
  287. # add edge labels
  288. if 'fel' in locals():
  289. content_el = open(fel).read().splitlines()
  290. for index, line in enumerate(content_el):
  291. label = line.strip()
  292. n = [int(i) - 1 for i in content_am[index].split(',')]
  293. g = int(content_gi[n[0]]) - 1
  294. data[g].edges[n[0], n[1]]['bond_type'] = label
  295. # add node attributes
  296. if 'fna' in locals():
  297. content_na = open(fna).read().splitlines()
  298. for i, line in enumerate(content_na):
  299. attrs = [i.strip() for i in line.split(',')]
  300. g = int(content_gi[i]) - 1
  301. data[g].nodes[i]['attributes'] = attrs
  302. # add edge attributes
  303. if 'fea' in locals():
  304. content_ea = open(fea).read().splitlines()
  305. for index, line in enumerate(content_ea):
  306. attrs = [i.strip() for i in line.split(',')]
  307. n = [int(i) - 1 for i in content_am[index].split(',')]
  308. g = int(content_gi[n[0]]) - 1
  309. data[g].edges[n[0], n[1]]['attributes'] = attrs
  310. # load y
  311. y = [int(i) for i in content_gl]
  312. return data, y
  313. def loadDataset(filename, filename_y=None, extra_params=None):
  314. """load file list of the dataset.
  315. """
  316. from os.path import dirname, splitext
  317. dirname_dataset = dirname(filename)
  318. extension = splitext(filename)[1][1:]
  319. data = []
  320. y = []
  321. if extension == "ds":
  322. content = open(filename).read().splitlines()
  323. if filename_y is None or filename_y == '':
  324. for i in range(0, len(content)):
  325. tmp = content[i].split(' ')
  326. # remove the '#'s in file names
  327. data.append(
  328. loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
  329. y.append(float(tmp[1]))
  330. else: # y in a seperate file
  331. for i in range(0, len(content)):
  332. tmp = content[i]
  333. # remove the '#'s in file names
  334. data.append(
  335. loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
  336. content_y = open(filename_y).read().splitlines()
  337. # assume entries in filename and filename_y have the same order.
  338. for item in content_y:
  339. tmp = item.split(' ')
  340. # assume the 3rd entry in a line is y (for Alkane dataset)
  341. y.append(float(tmp[2]))
  342. elif extension == "cxl":
  343. import xml.etree.ElementTree as ET
  344. tree = ET.parse(filename)
  345. root = tree.getroot()
  346. data = []
  347. y = []
  348. for graph in root.iter('print'):
  349. mol_filename = graph.attrib['file']
  350. mol_class = graph.attrib['class']
  351. data.append(loadGXL(dirname_dataset + '/' + mol_filename))
  352. y.append(mol_class)
  353. elif extension == "sdf":
  354. import numpy as np
  355. from tqdm import tqdm
  356. import sys
  357. data = loadSDF(filename)
  358. y_raw = open(filename_y).read().splitlines()
  359. y_raw.pop(0)
  360. tmp0 = []
  361. tmp1 = []
  362. for i in range(0, len(y_raw)):
  363. tmp = y_raw[i].split(',')
  364. tmp0.append(tmp[0])
  365. tmp1.append(tmp[1].strip())
  366. y = []
  367. for i in tqdm(range(0, len(data)), desc='ajust data', file=sys.stdout):
  368. try:
  369. y.append(tmp1[tmp0.index(data[i].name)].strip())
  370. except ValueError: # if data[i].name not in tmp0
  371. data[i] = []
  372. data = list(filter(lambda a: a != [], data))
  373. elif extension == "mat":
  374. data, y = loadMAT(filename, extra_params)
  375. elif extension == 'txt':
  376. data, y = loadTXT(dirname_dataset)
  377. # print(len(y))
  378. # print(y)
  379. # print(data[0].nodes(data=True))
  380. # print('----')
  381. # print(data[0].edges(data=True))
  382. # for g in data:
  383. # print(g.nodes(data=True))
  384. # print('----')
  385. # print(g.edges(data=True))
  386. return data, y
  387. def saveDataset(Gn, y, gformat='gxl', group=None, filename='gfile'):
  388. """Save list of graphs.
  389. """
  390. import os
  391. dirname_ds = os.path.dirname(filename)
  392. if dirname_ds != '':
  393. dirname_ds += '/'
  394. if not os.path.exists(dirname_ds) :
  395. os.makedirs(dirname_ds)
  396. if group == 'xml' and gformat == 'gxl':
  397. with open(filename + '.xml', 'w') as fgroup:
  398. fgroup.write("<?xml version=\"1.0\"?>")
  399. fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"https://dbblumenthal.github.io/gedlib/GraphCollection_8dtd_source.html\">")
  400. fgroup.write("\n<GraphCollection>")
  401. for idx, g in enumerate(Gn):
  402. fname_tmp = "graph" + str(idx) + ".gxl"
  403. saveGXL(g, dirname_ds + fname_tmp)
  404. fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
  405. fgroup.write("\n</GraphCollection>")
  406. fgroup.close()
  407. if __name__ == '__main__':
  408. ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
  409. 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  410. Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  411. saveDataset(Gn, y, group='xml', filename='temp/temp')

A Python package for graph kernels, graph edit distances and graph pre-image problem.