You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graphfiles.py 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. """ Utilities function to manage graph files
  2. """
  3. from os.path import dirname, splitext
def loadCT(filename):
    """Load a single graph from a Chemical Table (.ct) file.

    Notes
    -----
    A typical example of data in .ct is like this:

    3 2  <- number of nodes and edges

    0.0000 0.0000 0.0000 C  <- each line describes a node (x, y, z + label)

    0.0000 0.0000 0.0000 C

    0.0000 0.0000 0.0000 O

    1 3 1 1  <- each line describes an edge: to, from, bond type, bond stereo

    2 3 1 1

    Check `CTFile Formats file <https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=10&ved=2ahUKEwivhaSdjsTlAhVhx4UKHczHA8gQFjAJegQIARAC&url=https%3A%2F%2Fwww.daylight.com%2Fmeetings%2Fmug05%2FKappler%2Fctfile.pdf&usg=AOvVaw1cDNrrmMClkFPqodlF2inS>`__
    for a detailed format description.
    """
    import networkx as nx
    from os.path import basename
    g = nx.Graph()
    with open(filename) as f:
        content = f.read().splitlines()
        # line 0 is the molecule name; also record the source file name.
        g = nx.Graph(
            name = str(content[0]),
            filename = basename(filename)) # set name of the graph
        # line 1 holds the node and edge counts; a leading blank produces an
        # empty first token, hence the shift below.
        tmp = content[1].split(" ")
        if tmp[0] == '':
            nb_nodes = int(tmp[1]) # number of the nodes
            nb_edges = int(tmp[2]) # number of the edges
        else:
            nb_nodes = int(tmp[0])
            nb_edges = int(tmp[1])
        # patch for compatibility : label will be removed later
        for i in range(0, nb_nodes):
            # node lines start at content[2]; tokens are x, y, z, label...
            tmp = content[i + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_node(i, atom=tmp[3].strip(),
                       label=[item.strip() for item in tmp[3:]],
                       attributes=[item.strip() for item in tmp[0:3]])
        for i in range(0, nb_edges):
            # edge lines follow the node lines; endpoints are 1-based in .ct,
            # converted to 0-based node indices here.
            tmp = content[i + g.number_of_nodes() + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
                       bond_type=tmp[2].strip(),
                       label=[item.strip() for item in tmp[2:]])
    return g
  47. def loadGXL(filename):
  48. from os.path import basename
  49. import networkx as nx
  50. import xml.etree.ElementTree as ET
  51. tree = ET.parse(filename)
  52. root = tree.getroot()
  53. index = 0
  54. g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
  55. dic = {} # used to retrieve incident nodes of edges
  56. for node in root.iter('node'):
  57. dic[node.attrib['id']] = index
  58. labels = {}
  59. for attr in node.iter('attr'):
  60. labels[attr.attrib['name']] = attr[0].text
  61. if 'chem' in labels:
  62. labels['label'] = labels['chem']
  63. labels['atom'] = labels['chem']
  64. g.add_node(index, **labels)
  65. index += 1
  66. for edge in root.iter('edge'):
  67. labels = {}
  68. for attr in edge.iter('attr'):
  69. labels[attr.attrib['name']] = attr[0].text
  70. if 'valence' in labels:
  71. labels['label'] = labels['valence']
  72. labels['bond_type'] = labels['valence']
  73. g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
  74. return g
  75. def saveGXL(graph, filename, method='default'):
  76. if method == 'default':
  77. gxl_file = open(filename, 'w')
  78. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  79. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  80. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  81. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  82. for v, attrs in graph.nodes(data=True):
  83. if graph.graph['node_labels'] == [] and graph.graph['node_attrs'] == []:
  84. gxl_file.write("<node id=\"_" + str(v) + "\"/>\n")
  85. else:
  86. gxl_file.write("<node id=\"_" + str(v) + "\">")
  87. for l_name in graph.graph['node_labels']:
  88. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  89. str(attrs[l_name]) + "</int></attr>")
  90. for a_name in graph.graph['node_attrs']:
  91. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  92. str(attrs[a_name]) + "</float></attr>")
  93. gxl_file.write("</node>\n")
  94. for v1, v2, attrs in graph.edges(data=True):
  95. if graph.graph['edge_labels'] == [] and graph.graph['edge_attrs'] == []:
  96. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
  97. else:
  98. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  99. for l_name in graph.graph['edge_labels']:
  100. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  101. str(attrs[l_name]) + "</int></attr>")
  102. for a_name in graph.graph['edge_attrs']:
  103. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  104. str(attrs[a_name]) + "</float></attr>")
  105. gxl_file.write("</edge>\n")
  106. gxl_file.write("</graph>\n")
  107. gxl_file.write("</gxl>")
  108. gxl_file.close()
  109. elif method == 'benoit':
  110. import xml.etree.ElementTree as ET
  111. root_node = ET.Element('gxl')
  112. attr = dict()
  113. attr['id'] = str(graph.graph['name'])
  114. attr['edgeids'] = 'true'
  115. attr['edgemode'] = 'undirected'
  116. graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
  117. for v in graph:
  118. current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
  119. for attr in graph.nodes[v].keys():
  120. cur_attr = ET.SubElement(
  121. current_node, 'attr', attrib={'name': attr})
  122. cur_value = ET.SubElement(cur_attr,
  123. graph.nodes[v][attr].__class__.__name__)
  124. cur_value.text = graph.nodes[v][attr]
  125. for v1 in graph:
  126. for v2 in graph[v1]:
  127. if (v1 < v2): # Non oriented graphs
  128. cur_edge = ET.SubElement(
  129. graph_node,
  130. 'edge',
  131. attrib={
  132. 'from': str(v1),
  133. 'to': str(v2)
  134. })
  135. for attr in graph[v1][v2].keys():
  136. cur_attr = ET.SubElement(
  137. cur_edge, 'attr', attrib={'name': attr})
  138. cur_value = ET.SubElement(
  139. cur_attr, graph[v1][v2][attr].__class__.__name__)
  140. cur_value.text = str(graph[v1][v2][attr])
  141. tree = ET.ElementTree(root_node)
  142. tree.write(filename)
  143. elif method == 'gedlib':
  144. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  145. # pass
  146. gxl_file = open(filename, 'w')
  147. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  148. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  149. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  150. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
  151. for v, attrs in graph.nodes(data=True):
  152. gxl_file.write("<node id=\"_" + str(v) + "\">")
  153. gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['chem']) + "</int></attr>")
  154. gxl_file.write("</node>\n")
  155. for v1, v2, attrs in graph.edges(data=True):
  156. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  157. gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
  158. # gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>")
  159. gxl_file.write("</edge>\n")
  160. gxl_file.write("</graph>\n")
  161. gxl_file.write("</gxl>")
  162. gxl_file.close()
  163. elif method == 'gedlib-letter':
  164. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  165. # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
  166. gxl_file = open(filename, 'w')
  167. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  168. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  169. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  170. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  171. for v, attrs in graph.nodes(data=True):
  172. gxl_file.write("<node id=\"_" + str(v) + "\">")
  173. gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
  174. gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
  175. gxl_file.write("</node>\n")
  176. for v1, v2, attrs in graph.edges(data=True):
  177. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
  178. gxl_file.write("</graph>\n")
  179. gxl_file.write("</gxl>")
  180. gxl_file.close()
def loadSDF(filename):
    """Load a list of graphs from a structured data file (.sdf file).

    Notes
    -----
    A SDF file contains a group of molecules, represented in the similar way
    as in MOL format.
    Check `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
    for detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys
    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip()) # set name of the graph
            # the counts line: columns 0-2 are the node count, 3-5 the edge count
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3]) # number of the nodes
            nb_edges = int(tmp[3:6]) # number of the edges
            for i in range(0, nb_nodes):
                # atom block starts 4 lines into the record; the element
                # symbol occupies fixed columns 31-33
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                # bond block follows the atom block; split the fixed-width
                # line into 3-char fields (NOTE: the comprehension's `i`
                # shadows the loop variable on purpose here)
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            # skip past the property block to the '$$$$' record separator
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$': # separator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Parameters
    ----------
    filename : string
        Path of the .mat file.
    extra_params : dict
        Must contain key 'am_sp_al_nl_el': a list of indices locating the
        adjacency matrix, sparsity flag, and node/edge label fields inside
        each struct entry.

    Notes
    -----
    A MAT file contains a struct array containing graphs, and a column
    vector lx containing a class label for each graph.
    Check README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
    for detailed structure.

    NOTE(review): `y` is only bound when a variable whose name starts with
    'l' exists in the file — otherwise the final return raises
    UnboundLocalError; confirm all expected datasets include it.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx
    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l': # class label
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':
            # variables not starting with '_' hold the graph struct array.
            # if adjacency matrix is not compressed / edge label exists
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i) # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0]) # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0] # edge label
                    for edge in el:
                        # endpoints are 1-based in the .mat data
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i) # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0]) # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]] # sparse adjacency matrix
                    # add one edge per nonzero entry (unlabeled edges)
                    index_no0 = sam.nonzero()
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    return data, y
def loadTXT(filename):
    """Load graph data from the TU Dortmund .txt multi-file format.

    Parameters
    ----------
    filename : string
        Path of the ``<DS>_A.txt`` adjacency file; the sibling files
        (``_graph_indicator``, ``_graph_labels``, optional label/attribute
        files) are discovered in the same directory by name.

    Notes
    -----
    The graph data is loaded from separate files.
    Check README in `downloadable file <http://tiny.cc/PK_MLJ_data>`__,
    2018 for detailed structure.
    """
    import networkx as nx
    from os import listdir
    from os.path import dirname, basename

    def get_label_names(frm):
        """Get label names from DS_label_readme.txt file.
        """
        def get_names_from_line(line):
            """Get names of labels/attributes from a line.
            """
            # names are listed between '[' and ']', comma-separated
            str_names = line.split('[')[1].split(']')[0]
            names = str_names.split(',')
            names = [attr.strip() for attr in names]
            return names

        label_names = {'node_labels': [], 'node_attrs': [],
                       'edge_labels': [], 'edge_attrs': []}
        content_rm = open(frm).read().splitlines()
        for line in content_rm:
            line = line.strip()
            if line.startswith('Node labels:'):
                label_names['node_labels'] = get_names_from_line(line)
            elif line.startswith('Node attributes:'):
                label_names['node_attrs'] = get_names_from_line(line)
            elif line.startswith('Edge labels:'):
                label_names['edge_labels'] = get_names_from_line(line)
            elif line.startswith('Edge attributes:'):
                label_names['edge_attrs'] = get_names_from_line(line)
        return label_names

    # get dataset name.
    dirname_dataset = dirname(filename)
    filename = basename(filename)
    fn_split = filename.split('_A')
    ds_name = fn_split[0].strip()

    # load data file names; optional files simply leave their variable
    # unbound, which the `'...' in locals()` checks below rely on.
    for name in listdir(dirname_dataset):
        if ds_name + '_A' in name:
            fam = dirname_dataset + '/' + name
        elif ds_name + '_graph_indicator' in name:
            fgi = dirname_dataset + '/' + name
        elif ds_name + '_graph_labels' in name:
            fgl = dirname_dataset + '/' + name
        elif ds_name + '_node_labels' in name:
            fnl = dirname_dataset + '/' + name
        elif ds_name + '_edge_labels' in name:
            fel = dirname_dataset + '/' + name
        elif ds_name + '_edge_attributes' in name:
            fea = dirname_dataset + '/' + name
        elif ds_name + '_node_attributes' in name:
            fna = dirname_dataset + '/' + name
        elif ds_name + '_graph_attributes' in name:
            fga = dirname_dataset + '/' + name
        elif ds_name + '_label_readme' in name:
            frm = dirname_dataset + '/' + name
        # this is supposed to be the node attrs, make sure to put this as the last 'elif'
        elif ds_name + '_attributes' in name:
            fna = dirname_dataset + '/' + name

    # get labels and attributes names.
    if 'frm' in locals():
        label_names = get_label_names(frm)
    else:
        label_names = {'node_labels': [], 'node_attrs': [],
                       'edge_labels': [], 'edge_attrs': []}

    content_gi = open(fgi).read().splitlines() # graph indicator
    content_am = open(fam).read().splitlines() # adjacency matrix
    content_gl = open(fgl).read().splitlines() # graph labels

    # create graphs and add nodes
    data = [nx.Graph(name=str(i),
                     node_labels=label_names['node_labels'],
                     node_attrs=label_names['node_attrs'],
                     edge_labels=label_names['edge_labels'],
                     edge_attrs=label_names['edge_attrs']) for i in range(0, len(content_gl))]
    if 'fnl' in locals():
        content_nl = open(fnl).read().splitlines() # node labels
        for idx, line in enumerate(content_gi):
            # transfer to int first in case of unexpected blanks
            data[int(line) - 1].add_node(idx)
            labels = [l.strip() for l in content_nl[idx].split(',')]
            data[int(line) - 1].nodes[idx]['atom'] = str(int(labels[0])) # @todo: this should be removed after.
            # on the first labeled node, generate the label names as well
            if data[int(line) - 1].graph['node_labels'] == []:
                for i, label in enumerate(labels):
                    l_name = 'label_' + str(i)
                    data[int(line) - 1].nodes[idx][l_name] = label
                    data[int(line) - 1].graph['node_labels'].append(l_name)
            else:
                for i, l_name in enumerate(data[int(line) - 1].graph['node_labels']):
                    data[int(line) - 1].nodes[idx][l_name] = labels[i]
    else:
        for i, line in enumerate(content_gi):
            data[int(line) - 1].add_node(i)

    # add edges; node ids are 1-based in the files, 0-based in the graphs
    for line in content_am:
        tmp = line.split(',')
        n1 = int(tmp[0]) - 1
        n2 = int(tmp[1]) - 1
        # ignore edge weight here.
        g = int(content_gi[n1]) - 1
        data[g].add_edge(n1, n2)

    # add edge labels
    if 'fel' in locals():
        content_el = open(fel).read().splitlines()
        for idx, line in enumerate(content_el):
            labels = [l.strip() for l in line.split(',')]
            n = [int(i) - 1 for i in content_am[idx].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['bond_type'] = labels[0] # @todo: this should be removed after.
            if data[g].graph['edge_labels'] == []:
                for i, label in enumerate(labels):
                    l_name = 'label_' + str(i)
                    data[g].edges[n[0], n[1]][l_name] = label
                    data[g].graph['edge_labels'].append(l_name)
            else:
                for i, l_name in enumerate(data[g].graph['edge_labels']):
                    data[g].edges[n[0], n[1]][l_name] = labels[i]

    # add node attributes
    if 'fna' in locals():
        content_na = open(fna).read().splitlines()
        for idx, line in enumerate(content_na):
            attrs = [a.strip() for a in line.split(',')]
            g = int(content_gi[idx]) - 1
            data[g].nodes[idx]['attributes'] = attrs # @todo: this should be removed after.
            if data[g].graph['node_attrs'] == []:
                for i, attr in enumerate(attrs):
                    a_name = 'attr_' + str(i)
                    data[g].nodes[idx][a_name] = attr
                    data[g].graph['node_attrs'].append(a_name)
            else:
                for i, a_name in enumerate(data[g].graph['node_attrs']):
                    data[g].nodes[idx][a_name] = attrs[i]

    # add edge attributes
    if 'fea' in locals():
        content_ea = open(fea).read().splitlines()
        for idx, line in enumerate(content_ea):
            attrs = [a.strip() for a in line.split(',')]
            n = [int(i) - 1 for i in content_am[idx].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['attributes'] = attrs # @todo: this should be removed after.
            if data[g].graph['edge_attrs'] == []:
                for i, attr in enumerate(attrs):
                    a_name = 'attr_' + str(i)
                    data[g].edges[n[0], n[1]][a_name] = attr
                    data[g].graph['edge_attrs'].append(a_name)
            else:
                for i, a_name in enumerate(data[g].graph['edge_attrs']):
                    data[g].edges[n[0], n[1]][a_name] = attrs[i]

    # load y
    y = [int(i) for i in content_gl]
    return data, y
  439. def loadDataset(filename, filename_y=None, extra_params=None):
  440. """Read graph data from filename and load them as NetworkX graphs.
  441. Parameters
  442. ----------
  443. filename : string
  444. The name of the file from where the dataset is read.
  445. filename_y : string
  446. The name of file of the targets corresponding to graphs.
  447. extra_params : dict
  448. Extra parameters only designated to '.mat' format.
  449. Return
  450. ------
  451. data : List of NetworkX graph.
  452. y : List
  453. Targets corresponding to graphs.
  454. Notes
  455. -----
  456. This function supports following graph dataset formats:
  457. 'ds': load data from .ds file. See comments of function loadFromDS for a example.
  458. 'cxl': load data from Graph eXchange Language file (.cxl file). See
  459. `here <http://www.gupro.de/GXL/Introduction/background.html>`__ for detail.
  460. 'sdf': load data from structured data file (.sdf file). See
  461. `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
  462. for details.
  463. 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See
  464. README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
  465. for details.
  466. 'txt': Load graph data from a special .txt file. See
  467. `here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
  468. for details. Note here filename is the name of either .txt file in
  469. the dataset directory.
  470. """
  471. extension = splitext(filename)[1][1:]
  472. if extension == "ds":
  473. data, y = loadFromDS(filename, filename_y)
  474. elif extension == "cxl":
  475. import xml.etree.ElementTree as ET
  476. dirname_dataset = dirname(filename)
  477. tree = ET.parse(filename)
  478. root = tree.getroot()
  479. data = []
  480. y = []
  481. for graph in root.iter('graph'):
  482. mol_filename = graph.attrib['file']
  483. mol_class = graph.attrib['class']
  484. data.append(loadGXL(dirname_dataset + '/' + mol_filename))
  485. y.append(mol_class)
  486. elif extension == 'xml':
  487. data, y = loadFromXML(filename, extra_params)
  488. elif extension == "sdf":
  489. # import numpy as np
  490. from tqdm import tqdm
  491. import sys
  492. data = loadSDF(filename)
  493. y_raw = open(filename_y).read().splitlines()
  494. y_raw.pop(0)
  495. tmp0 = []
  496. tmp1 = []
  497. for i in range(0, len(y_raw)):
  498. tmp = y_raw[i].split(',')
  499. tmp0.append(tmp[0])
  500. tmp1.append(tmp[1].strip())
  501. y = []
  502. for i in tqdm(range(0, len(data)), desc='ajust data', file=sys.stdout):
  503. try:
  504. y.append(tmp1[tmp0.index(data[i].name)].strip())
  505. except ValueError: # if data[i].name not in tmp0
  506. data[i] = []
  507. data = list(filter(lambda a: a != [], data))
  508. elif extension == "mat":
  509. data, y = loadMAT(filename, extra_params)
  510. elif extension == 'txt':
  511. data, y = loadTXT(filename)
  512. # print(len(y))
  513. # print(y)
  514. # print(data[0].nodes(data=True))
  515. # print('----')
  516. # print(data[0].edges(data=True))
  517. # for g in data:
  518. # print(g.nodes(data=True))
  519. # print('----')
  520. # print(g.edges(data=True))
  521. return data, y
  522. def loadFromXML(filename, extra_params):
  523. import xml.etree.ElementTree as ET
  524. if extra_params:
  525. dirname_dataset = extra_params
  526. else:
  527. dirname_dataset = dirname(filename)
  528. tree = ET.parse(filename)
  529. root = tree.getroot()
  530. data = []
  531. y = []
  532. for graph in root.iter('graph'):
  533. mol_filename = graph.attrib['file']
  534. mol_class = graph.attrib['class']
  535. data.append(loadGXL(dirname_dataset + '/' + mol_filename))
  536. y.append(mol_class)
  537. return data, y
  538. def loadFromDS(filename, filename_y):
  539. """Load data from .ds file.
  540. Possible graph formats include:
  541. '.ct': see function loadCT for detail.
  542. '.gxl': see dunction loadGXL for detail.
  543. Note these graph formats are checked automatically by the extensions of
  544. graph files.
  545. """
  546. dirname_dataset = dirname(filename)
  547. data = []
  548. y = []
  549. content = open(filename).read().splitlines()
  550. extension = splitext(content[0].split(' ')[0])[1][1:]
  551. if filename_y is None or filename_y == '':
  552. if extension == 'ct':
  553. for i in range(0, len(content)):
  554. tmp = content[i].split(' ')
  555. # remove the '#'s in file names
  556. data.append(
  557. loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
  558. y.append(float(tmp[1]))
  559. elif extension == 'gxl':
  560. for i in range(0, len(content)):
  561. tmp = content[i].split(' ')
  562. # remove the '#'s in file names
  563. data.append(
  564. loadGXL(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
  565. y.append(float(tmp[1]))
  566. else: # y in a seperate file
  567. if extension == 'ct':
  568. for i in range(0, len(content)):
  569. tmp = content[i]
  570. # remove the '#'s in file names
  571. data.append(
  572. loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
  573. elif extension == 'gxl':
  574. for i in range(0, len(content)):
  575. tmp = content[i]
  576. # remove the '#'s in file names
  577. data.append(
  578. loadGXL(dirname_dataset + '/' + tmp.replace('#', '', 1)))
  579. content_y = open(filename_y).read().splitlines()
  580. # assume entries in filename and filename_y have the same order.
  581. for item in content_y:
  582. tmp = item.split(' ')
  583. # assume the 3rd entry in a line is y (for Alkane dataset)
  584. y.append(float(tmp[2]))
  585. return data, y
  586. def saveDataset(Gn, y, gformat='gxl', group=None, filename='gfile', xparams=None):
  587. """Save list of graphs.
  588. """
  589. import os
  590. dirname_ds = os.path.dirname(filename)
  591. if dirname_ds != '':
  592. dirname_ds += '/'
  593. if not os.path.exists(dirname_ds) :
  594. os.makedirs(dirname_ds)
  595. if xparams is not None and 'graph_dir' in xparams:
  596. graph_dir = xparams['graph_dir'] + '/'
  597. if not os.path.exists(graph_dir):
  598. os.makedirs(graph_dir)
  599. else:
  600. graph_dir = dirname_ds
  601. if group == 'xml' and gformat == 'gxl':
  602. kwargs = {'method': xparams['method']} if xparams is not None else {}
  603. with open(filename + '.xml', 'w') as fgroup:
  604. fgroup.write("<?xml version=\"1.0\"?>")
  605. fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
  606. fgroup.write("\n<GraphCollection>")
  607. for idx, g in enumerate(Gn):
  608. fname_tmp = "graph" + str(idx) + ".gxl"
  609. saveGXL(g, graph_dir + fname_tmp, **kwargs)
  610. fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
  611. fgroup.write("\n</GraphCollection>")
  612. fgroup.close()
if __name__ == '__main__':
    # Ad-hoc usage examples kept for reference. Everything below is
    # commented out, so running this module as a script is a no-op.
    #
    # ### Load dataset from .ds file.
    # # .ct files:
    # ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds',
    #       'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'}
    # Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y'])
    ## ds = {'name': 'Acyclic', 'dataset': '../../datasets/acyclic/dataset_bps.ds'}  # node symb
    ## Gn, y = loadDataset(ds['dataset'])
    ## ds = {'name': 'MAO', 'dataset': '../../datasets/MAO/dataset.ds'}  # node/edge symb
    ## Gn, y = loadDataset(ds['dataset'])
    ## ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'}  # unlabeled
    ## Gn, y = loadDataset(ds['dataset'])
    # print(Gn[1].nodes(data=True))
    # print(Gn[1].edges(data=True))
    # print(y[1])
    #
    # # .gxl file:
    # ds = {'name': 'monoterpenoides',
    #       'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'}  # node/edge symb
    # Gn, y = loadDataset(ds['dataset'])
    # print(Gn[1].nodes(data=True))
    # print(Gn[1].edges(data=True))
    # print(y[1])
    #
    # ### Convert graph from one format to another (e.g. to the NetworkX
    # ### shape recognized by the gedlib library: 'chem'/'valence' attrs).
    # import networkx as nx
    # ds = {'name': 'monoterpenoides',
    #       'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'}  # node/edge symb
    # Gn, y = loadDataset(ds['dataset'])
    # y = [int(i) for i in y]
    # Gn_new = []
    # for G in Gn:
    #     G_new = nx.Graph()
    #     for nd, attrs in G.nodes(data=True):
    #         G_new.add_node(str(nd), chem=attrs['atom'])
    #     for nd1, nd2, attrs in G.edges(data=True):
    #         G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
    ##        G_new.add_edge(str(nd1), str(nd2))
    #     Gn_new.append(G_new)
    # filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides'
    # xparams = {'method': 'gedlib'}
    # saveDataset(Gn, y, gformat='gxl', group='xml', filename=filename, xparams=xparams)
    #
    # ### Save dataset.
    # ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
    #       'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
    # Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
    # saveDataset(Gn, y, group='xml', filename='temp/temp')
    #
    # ### Test - new way to add labels and attributes.
    # dataset = '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
    # dataset = '../../datasets/Fingerprint/Fingerprint_A.txt'
    # dataset = '../../datasets/Letter-med/Letter-med_A.txt'
    # dataset = '../../datasets/AIDS/AIDS_A.txt'
    # dataset = '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
    # Gn, y_all = loadDataset(dataset)
    pass

A Python package for graph kernels, graph edit distances and graph pre-image problem.