
graphfiles.py 19 kB

  1. """ Utilities function to manage graph files
  2. """


def loadCT(filename):
    """Load data from a .ct file. (A usage sketch follows this function.)

    Notes
    ------
    A typical block of data in a .ct file looks like this (the graph name on the
    first line of the file is omitted here):

    3 2                      <- number of nodes and edges
    0.0000 0.0000 0.0000 C   <- each line describes a node (x, y, z + label)
    0.0000 0.0000 0.0000 C
    0.0000 0.0000 0.0000 O
    1 3 1 1                  <- each line describes an edge: to, from, ?, label
    2 3 1 1
    """
    import networkx as nx
    from os.path import basename
    with open(filename) as f:
        content = f.read().splitlines()
        g = nx.Graph(
            name=str(content[0]),
            filename=basename(filename))  # set name of the graph
        tmp = content[1].split(" ")
        if tmp[0] == '':
            nb_nodes = int(tmp[1])  # number of nodes
            nb_edges = int(tmp[2])  # number of edges
        else:
            nb_nodes = int(tmp[0])
            nb_edges = int(tmp[1])
        # patch for compatibility: the 'label' attribute will be removed later
        for i in range(0, nb_nodes):
            tmp = content[i + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_node(i, atom=tmp[3], label=tmp[3])
        for i in range(0, nb_edges):
            tmp = content[i + g.number_of_nodes() + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_edge(
                int(tmp[0]) - 1,
                int(tmp[1]) - 1,
                bond_type=tmp[3].strip(),
                label=tmp[3].strip())
        # for i in range(0, nb_edges):
        #     tmp = content[i + g.number_of_nodes() + 2]
        #     tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
        #     g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
        #                bond_type=tmp[3].strip(), label=tmp[3].strip())
    return g
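

# The following usage sketch is not part of the original module; it illustrates
# how loadCT might be called. The path 'example.ct' is a placeholder.
def _example_load_ct(ct_path='example.ct'):
    """Load a single .ct file and print its node and edge attributes."""
    g = loadCT(ct_path)
    print(g.graph['name'], g.number_of_nodes(), g.number_of_edges())
    print(g.nodes(data=True))  # each node carries 'atom' and 'label'
    print(g.edges(data=True))  # each edge carries 'bond_type' and 'label'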


def loadGXL(filename):
    """Load a graph from a GXL file."""
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET
    tree = ET.parse(filename)
    root = tree.getroot()
    index = 0
    g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
    dic = {}  # maps GXL node ids to integer node indices, to resolve edge endpoints
    for node in root.iter('node'):
        dic[node.attrib['id']] = index
        labels = {}
        for attr in node.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'chem' in labels:
            labels['label'] = labels['chem']
        g.add_node(index, **labels)
        index += 1
    for edge in root.iter('edge'):
        labels = {}
        for attr in edge.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'valence' in labels:
            labels['label'] = labels['valence']
        g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
    return g


def saveGXL(graph, filename, method='gedlib-letter'):
    """Save a graph to a GXL file.

    The 'benoit' method builds the document with xml.etree.ElementTree and writes
    every node and edge attribute; the 'gedlib' method writes molecule-style files
    ('chem' node attributes taken from the node 'atom' field and a constant
    'valence' edge attribute); the 'gedlib-letter' method writes Letter-style
    files ('x'/'y' node attributes taken from the node 'attributes' field and
    unlabeled edges). (A usage sketch follows this function.)
    """
    if method == 'benoit':
        import xml.etree.ElementTree as ET
        root_node = ET.Element('gxl')
        attr = dict()
        attr['id'] = str(graph.graph['name'])
        attr['edgeids'] = 'true'
        attr['edgemode'] = 'undirected'
        graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
        for v in graph:
            current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
            for attr in graph.nodes[v].keys():
                cur_attr = ET.SubElement(
                    current_node, 'attr', attrib={'name': attr})
                cur_value = ET.SubElement(cur_attr,
                                          graph.nodes[v][attr].__class__.__name__)
                cur_value.text = graph.nodes[v][attr]
        for v1 in graph:
            for v2 in graph[v1]:
                if (v1 < v2):  # undirected graph: write each edge only once
                    cur_edge = ET.SubElement(
                        graph_node,
                        'edge',
                        attrib={
                            'from': str(v1),
                            'to': str(v2)
                        })
                    for attr in graph[v1][v2].keys():
                        cur_attr = ET.SubElement(
                            cur_edge, 'attr', attrib={'name': attr})
                        cur_value = ET.SubElement(
                            cur_attr, graph[v1][v2][attr].__class__.__name__)
                        cur_value.text = str(graph[v1][v2][attr])
        tree = ET.ElementTree(root_node)
        tree.write(filename)
    elif method == 'gedlib':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl>\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">\n")
            gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['atom']) + "</int></attr>\n")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">\n")
            # gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['bond_type']) + "</int></attr>\n")
            gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>\n")
            gxl_file.write("</edge>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>\n")
        gxl_file.close()
    elif method == 'gedlib-letter':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
            gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
            gxl_file.write("</node>")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>")
        gxl_file.write("</graph>")
        gxl_file.write("</gxl>")
        gxl_file.close()
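

# The following round-trip sketch is not part of the original module; it shows how
# saveGXL and loadGXL might be combined. The path 'example.gxl' is a placeholder,
# and the toy graph is built by hand so that the 'gedlib' method finds the 'atom'
# attribute it expects on every node.
def _example_gxl_roundtrip(gxl_path='example.gxl'):
    """Write a small labelled graph with saveGXL, then read it back with loadGXL."""
    import networkx as nx
    g = nx.Graph(name='toy')
    g.add_node(0, atom='C')
    g.add_node(1, atom='O')
    g.add_edge(0, 1, bond_type='1')
    saveGXL(g, gxl_path, method='gedlib')
    h = loadGXL(gxl_path)
    print(h.nodes(data=True))  # 'chem' values are copied into the 'label' attribute
    print(h.edges(data=True))  # 'valence' values are copied into the 'label' attribute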


def loadSDF(filename):
    """Load data from a structured data file (.sdf file). (A usage sketch
    follows this function.)

    Notes
    ------
    An SDF file contains a group of molecules, each represented in a similar way
    to the MOL format.
    Check http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx, 2018, for the detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys
    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip())  # set name of the graph
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of nodes
            nb_edges = int(tmp[3:6])  # number of edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # molecule separator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
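

# The following usage sketch is not part of the original module; it illustrates
# how loadSDF might be called. The path 'compounds.sdf' is a placeholder.
def _example_load_sdf(sdf_path='compounds.sdf'):
    """Load every molecule from an SDF file and report its size."""
    graphs = loadSDF(sdf_path)
    for g in graphs:
        print(g.name, g.number_of_nodes(), g.number_of_edges())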


def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file. (A usage
    sketch follows this function.)

    Notes
    ------
    A MAT file contains a struct array of graphs and a column vector lx holding a
    class label for each graph.
    Check the README in the downloadable file at
    http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/, 2018, for the detailed
    structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx
    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l':  # class labels
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':
            # if the adjacency matrix is not compressed / edge labels exist
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge labels
                    for edge in el:
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    index_no0 = sam.nonzero()
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    return data, y
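

# The following usage sketch is not part of the original module; it illustrates
# how loadMAT might be called. The path is a placeholder; the 'am_sp_al_nl_el'
# order is the one used for MUTAG in the __main__ block at the end of this file.
def _example_load_mat(mat_path='MUTAG.mat'):
    """Load graphs and class labels from a MATLAB .mat dataset."""
    Gn, y = loadMAT(mat_path, {'am_sp_al_nl_el': [0, 0, 3, 1, 2]})
    print(len(Gn), 'graphs,', len(y), 'labels')
    print(Gn[0].nodes(data=True))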


def loadTXT(dirname_dataset):
    """Load graph data from .txt files in a dataset directory. (A usage sketch
    follows this function.)

    Notes
    ------
    The graph data is loaded from separate files (adjacency list, graph
    indicator, graph labels, and optional node/edge labels and attributes).
    Check the README in the downloadable file http://tiny.cc/PK_MLJ_data, 2018,
    for the detailed structure.
    """
    import numpy as np
    import networkx as nx
    from os import listdir
    from os.path import dirname
    # locate the data files by their suffixes
    for name in listdir(dirname_dataset):
        if '_A' in name:
            fam = dirname_dataset + '/' + name
        elif '_graph_indicator' in name:
            fgi = dirname_dataset + '/' + name
        elif '_graph_labels' in name:
            fgl = dirname_dataset + '/' + name
        elif '_node_labels' in name:
            fnl = dirname_dataset + '/' + name
        elif '_edge_labels' in name:
            fel = dirname_dataset + '/' + name
        elif '_edge_attributes' in name:
            fea = dirname_dataset + '/' + name
        elif '_node_attributes' in name:
            fna = dirname_dataset + '/' + name
        elif '_graph_attributes' in name:
            fga = dirname_dataset + '/' + name
        # this is supposed to be the node attrs, make sure to put this as the last 'elif'
        elif '_attributes' in name:
            fna = dirname_dataset + '/' + name
    content_gi = open(fgi).read().splitlines()  # graph indicator
    content_am = open(fam).read().splitlines()  # adjacency matrix
    content_gl = open(fgl).read().splitlines()  # class labels
    # create graphs and add nodes
    data = [nx.Graph(name=i) for i in range(0, len(content_gl))]
    if 'fnl' in locals():
        content_nl = open(fnl).read().splitlines()  # node labels
        for i, line in enumerate(content_gi):
            # transfer to int first in case of unexpected blanks
            data[int(line) - 1].add_node(i, atom=str(int(content_nl[i])))
    else:
        for i, line in enumerate(content_gi):
            data[int(line) - 1].add_node(i)
    # add edges
    for line in content_am:
        tmp = line.split(',')
        n1 = int(tmp[0]) - 1
        n2 = int(tmp[1]) - 1
        # ignore edge weight here.
        g = int(content_gi[n1]) - 1
        data[g].add_edge(n1, n2)
    # add edge labels
    if 'fel' in locals():
        content_el = open(fel).read().splitlines()
        for index, line in enumerate(content_el):
            label = line.strip()
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['bond_type'] = label
    # add node attributes
    if 'fna' in locals():
        content_na = open(fna).read().splitlines()
        for i, line in enumerate(content_na):
            attrs = [i.strip() for i in line.split(',')]
            g = int(content_gi[i]) - 1
            data[g].nodes[i]['attributes'] = attrs
    # add edge attributes
    if 'fea' in locals():
        content_ea = open(fea).read().splitlines()
        for index, line in enumerate(content_ea):
            attrs = [i.strip() for i in line.split(',')]
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['attributes'] = attrs
    # load y
    y = [int(i) for i in content_gl]
    return data, y
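

# The following usage sketch is not part of the original module; it illustrates
# how loadTXT might be called on a directory containing the *_A / *_graph_indicator
# / *_graph_labels files described above. The directory name is a placeholder.
def _example_load_txt(dataset_dir='datasets/MUTAG'):
    """Load graphs and class labels from a directory of dataset .txt files."""
    Gn, y = loadTXT(dataset_dir)
    print(len(Gn), 'graphs loaded; first graph:')
    print(Gn[0].nodes(data=True))
    print(Gn[0].edges(data=True))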


def loadDataset(filename, filename_y=None, extra_params=None):
    """Load a dataset from its file list, dispatching on the file extension.
    (A usage sketch follows this function.)
    """
    from os.path import dirname, splitext
    dirname_dataset = dirname(filename)
    extension = splitext(filename)[1][1:]
    data = []
    y = []
    if extension == "ds":
        content = open(filename).read().splitlines()
        if filename_y is None or filename_y == '':
            for i in range(0, len(content)):
                tmp = content[i].split(' ')
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
                y.append(float(tmp[1]))
        else:  # y is in a separate file
            for i in range(0, len(content)):
                tmp = content[i]
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
            content_y = open(filename_y).read().splitlines()
            # assume entries in filename and filename_y have the same order.
            for item in content_y:
                tmp = item.split(' ')
                # assume the 3rd entry in a line is y (for the Alkane dataset)
                y.append(float(tmp[2]))
    elif extension == "cxl":
        import xml.etree.ElementTree as ET
        tree = ET.parse(filename)
        root = tree.getroot()
        data = []
        y = []
        for graph in root.iter('print'):
            mol_filename = graph.attrib['file']
            mol_class = graph.attrib['class']
            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
            y.append(mol_class)
    elif extension == "sdf":
        import numpy as np
        from tqdm import tqdm
        import sys
        data = loadSDF(filename)
        y_raw = open(filename_y).read().splitlines()
        y_raw.pop(0)
        tmp0 = []
        tmp1 = []
        for i in range(0, len(y_raw)):
            tmp = y_raw[i].split(',')
            tmp0.append(tmp[0])
            tmp1.append(tmp[1].strip())
        y = []
        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
            try:
                y.append(tmp1[tmp0.index(data[i].name)].strip())
            except ValueError:  # if data[i].name is not in tmp0
                data[i] = []
        data = list(filter(lambda a: a != [], data))
    elif extension == "mat":
        data, y = loadMAT(filename, extra_params)
    elif extension == 'txt':
        data, y = loadTXT(dirname_dataset)
    return data, y
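

# The following usage sketch is not part of the original module; it shows the two
# ways loadDataset resolves targets for a ".ds" file list. All paths are
# placeholders and must point to existing dataset files.
def _example_load_dataset():
    """Load a .ds dataset, with y either inside the .ds file or in a separate file."""
    # y stored in the second column of the .ds file itself:
    Gn, y = loadDataset('datasets/Acyclic/dataset.ds')
    # y stored in a separate file (third entry per line, as for the Alkane dataset):
    Gn2, y2 = loadDataset('datasets/Alkane/dataset.ds',
                          filename_y='datasets/Alkane/targets.txt')
    print(len(Gn), len(Gn2))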


def saveDataset(Gn, y, gformat='gxl', group=None, filename='gfile'):
    """Save a list of graphs, optionally with an XML group file listing them.
    """
    import os
    dirname_ds = os.path.dirname(filename)
    if dirname_ds != '':
        dirname_ds += '/'
        if not os.path.exists(dirname_ds):
            os.makedirs(dirname_ds)
    if group == 'xml' and gformat == 'gxl':
        with open(filename + '.xml', 'w') as fgroup:
            fgroup.write("<?xml version=\"1.0\"?>")
            fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"https://dbblumenthal.github.io/gedlib/GraphCollection_8dtd_source.html\">")
            fgroup.write("\n<GraphCollection>")
            for idx, g in enumerate(Gn):
                fname_tmp = "graph" + str(idx) + ".gxl"
                saveGXL(g, dirname_ds + fname_tmp)
                fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
            fgroup.write("\n</GraphCollection>")
            fgroup.close()


if __name__ == '__main__':
    # small self-test: load the MUTAG dataset from a .mat file and re-save it as
    # a collection of GXL files plus an XML group file.
    ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
    Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    saveDataset(Gn, y, group='xml', filename='temp/temp')

A Python package for graph kernels, graph edit distances and the graph pre-image problem.