You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_files.py 28 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798
  1. """ Utilities function to manage graph files
  2. """
  3. from os.path import dirname, splitext
  4. def load_dataset(filename, filename_targets=None, gformat=None, **kwargs):
  5. """Read graph data from filename and load them as NetworkX graphs.
  6. Parameters
  7. ----------
  8. filename : string
  9. The name of the file from where the dataset is read.
  10. filename_y : string
  11. The name of file of the targets corresponding to graphs.
  12. extra_params : dict
  13. Extra parameters only designated to '.mat' format.
  14. Return
  15. ------
  16. data : List of NetworkX graph.
  17. y : List
  18. Targets corresponding to graphs.
  19. Notes
  20. -----
  21. This function supports following graph dataset formats:
  22. 'ds': load data from .ds file. See comments of function loadFromDS for a example.
  23. 'cxl': load data from Graph eXchange Language file (.cxl file). See
  24. `here <http://www.gupro.de/GXL/Introduction/background.html>`__ for detail.
  25. 'sdf': load data from structured data file (.sdf file). See
  26. `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
  27. for details.
  28. 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See
  29. README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
  30. for details.
  31. 'txt': Load graph data from a special .txt file. See
  32. `here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
  33. for details. Note here filename is the name of either .txt file in
  34. the dataset directory.
  35. """
  36. extension = splitext(filename)[1][1:]
  37. if extension == "ds":
  38. data, y, label_names = load_from_ds(filename, filename_targets)
  39. elif extension == "cxl":
  40. import xml.etree.ElementTree as ET
  41. dirname_dataset = dirname(filename)
  42. tree = ET.parse(filename)
  43. root = tree.getroot()
  44. data = []
  45. y = []
  46. for graph in root.iter('graph'):
  47. mol_filename = graph.attrib['file']
  48. mol_class = graph.attrib['class']
  49. data.append(load_gxl(dirname_dataset + '/' + mol_filename))
  50. y.append(mol_class)
  51. elif extension == 'xml':
  52. dir_dataset = kwargs.get('dirname_dataset', None)
  53. data, y = loadFromXML(filename, dir_dataset)
  54. elif extension == "sdf":
  55. # import numpy as np
  56. from tqdm import tqdm
  57. import sys
  58. data = loadSDF(filename)
  59. y_raw = open(filename_targets).read().splitlines()
  60. y_raw.pop(0)
  61. tmp0 = []
  62. tmp1 = []
  63. for i in range(0, len(y_raw)):
  64. tmp = y_raw[i].split(',')
  65. tmp0.append(tmp[0])
  66. tmp1.append(tmp[1].strip())
  67. y = []
  68. for i in tqdm(range(0, len(data)), desc='ajust data', file=sys.stdout):
  69. try:
  70. y.append(tmp1[tmp0.index(data[i].name)].strip())
  71. except ValueError: # if data[i].name not in tmp0
  72. data[i] = []
  73. data = list(filter(lambda a: a != [], data))
  74. elif extension == "mat":
  75. order = kwargs.get('order')
  76. data, y = loadMAT(filename, order)
  77. elif extension == 'txt':
  78. data, y, label_names = load_tud(filename)
  79. return data, y, label_names
  80. def save_dataset(Gn, y, gformat='gxl', group=None, filename='gfile', xparams=None):
  81. """Save list of graphs.
  82. """
  83. import os
  84. dirname_ds = os.path.dirname(filename)
  85. if dirname_ds != '':
  86. dirname_ds += '/'
  87. if not os.path.exists(dirname_ds) :
  88. os.makedirs(dirname_ds)
  89. if xparams is not None and 'graph_dir' in xparams:
  90. graph_dir = xparams['graph_dir'] + '/'
  91. if not os.path.exists(graph_dir):
  92. os.makedirs(graph_dir)
  93. else:
  94. graph_dir = dirname_ds
  95. if group == 'xml' and gformat == 'gxl':
  96. kwargs = {'method': xparams['method']} if xparams is not None else {}
  97. with open(filename + '.xml', 'w') as fgroup:
  98. fgroup.write("<?xml version=\"1.0\"?>")
  99. fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
  100. fgroup.write("\n<GraphCollection>")
  101. for idx, g in enumerate(Gn):
  102. fname_tmp = "graph" + str(idx) + ".gxl"
  103. saveGXL(g, graph_dir + fname_tmp, **kwargs)
  104. fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
  105. fgroup.write("\n</GraphCollection>")
  106. fgroup.close()
  107. def load_ct(filename):
  108. """load data from a Chemical Table (.ct) file.
  109. Notes
  110. ------
  111. a typical example of data in .ct is like this:
  112. 3 2 <- number of nodes and edges
  113. 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
  114. 0.0000 0.0000 0.0000 C
  115. 0.0000 0.0000 0.0000 O
  116. 1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo
  117. 2 3 1 1
  118. Check `CTFile Formats file <https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=10&ved=2ahUKEwivhaSdjsTlAhVhx4UKHczHA8gQFjAJegQIARAC&url=https%3A%2F%2Fwww.daylight.com%2Fmeetings%2Fmug05%2FKappler%2Fctfile.pdf&usg=AOvVaw1cDNrrmMClkFPqodlF2inS>`__
  119. for detailed format discription.
  120. """
  121. import networkx as nx
  122. from os.path import basename
  123. g = nx.Graph()
  124. with open(filename) as f:
  125. content = f.read().splitlines()
  126. g = nx.Graph(
  127. name = str(content[0]),
  128. filename = basename(filename)) # set name of the graph
  129. tmp = content[1].split(" ")
  130. if tmp[0] == '':
  131. nb_nodes = int(tmp[1]) # number of the nodes
  132. nb_edges = int(tmp[2]) # number of the edges
  133. else:
  134. nb_nodes = int(tmp[0])
  135. nb_edges = int(tmp[1])
  136. # patch for compatibility : label will be removed later
  137. for i in range(0, nb_nodes):
  138. tmp = content[i + 2].split(" ")
  139. tmp = [x for x in tmp if x != '']
  140. g.add_node(i, atom=tmp[3].strip(),
  141. label=[item.strip() for item in tmp[3:]],
  142. attributes=[item.strip() for item in tmp[0:3]])
  143. for i in range(0, nb_edges):
  144. tmp = content[i + g.number_of_nodes() + 2].split(" ")
  145. tmp = [x for x in tmp if x != '']
  146. g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  147. bond_type=tmp[2].strip(),
  148. label=[item.strip() for item in tmp[2:]])
  149. return g
  150. def load_gxl(filename): # @todo: directed graphs.
  151. from os.path import basename
  152. import networkx as nx
  153. import xml.etree.ElementTree as ET
  154. tree = ET.parse(filename)
  155. root = tree.getroot()
  156. index = 0
  157. g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
  158. dic = {} # used to retrieve incident nodes of edges
  159. for node in root.iter('node'):
  160. dic[node.attrib['id']] = index
  161. labels = {}
  162. for attr in node.iter('attr'):
  163. labels[attr.attrib['name']] = attr[0].text
  164. g.add_node(index, **labels)
  165. index += 1
  166. for edge in root.iter('edge'):
  167. labels = {}
  168. for attr in edge.iter('attr'):
  169. labels[attr.attrib['name']] = attr[0].text
  170. g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
  171. # get label names.
  172. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  173. for node in root.iter('node'):
  174. for attr in node.iter('attr'):
  175. if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
  176. label_names['node_labels'].append(attr.attrib['name'])
  177. else:
  178. label_names['node_attrs'].append(attr.attrib['name'])
  179. break
  180. for edge in root.iter('edge'):
  181. for attr in edge.iter('attr'):
  182. if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
  183. label_names['edge_labels'].append(attr.attrib['name'])
  184. else:
  185. label_names['edge_attrs'].append(attr.attrib['name'])
  186. break
  187. return g, label_names
def saveGXL(graph, filename, method='default', node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
    """Save a NetworkX graph to a Graph eXchange Language (.gxl) file.

    Parameters
    ----------
    graph : NetworkX graph
        Graph to save; node/edge data dicts supply the values written.
    filename : string
        Output file path.
    method : string
        Output flavor: 'default', 'benoit', 'gedlib' or 'gedlib-letter'.
    node_labels, edge_labels : list
        Names of label keys written as <int> attrs ('default' method only).
    node_attrs, edge_attrs : list
        Names of attribute keys written as <float> attrs ('default' method only).

    NOTE(review): the list parameters are mutable default arguments; they
    are only read here, so this is harmless, but `None` defaults would be
    safer.
    """
    if method == 'default':
        # Hand-written GXL with <int> labels and <float> attributes.
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        if 'name' in graph.graph:
            name = str(graph.graph['name'])
        else:
            name = 'dummy'
        gxl_file.write("<graph id=\"" + name + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            for l_name in node_labels:
                gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
                               str(attrs[l_name]) + "</int></attr>")
            for a_name in node_attrs:
                gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
                               str(attrs[a_name]) + "</float></attr>")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
            for l_name in edge_labels:
                gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
                               str(attrs[l_name]) + "</int></attr>")
            for a_name in edge_attrs:
                gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
                               str(attrs[a_name]) + "</float></attr>")
            gxl_file.write("</edge>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>")
        gxl_file.close()
    elif method == 'benoit':
        # Build the document with ElementTree; attribute value tags are
        # named after the Python type of each value.
        import xml.etree.ElementTree as ET
        root_node = ET.Element('gxl')
        attr = dict()
        attr['id'] = str(graph.graph['name'])
        attr['edgeids'] = 'true'
        attr['edgemode'] = 'undirected'
        graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
        for v in graph:
            current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
            for attr in graph.nodes[v].keys():
                cur_attr = ET.SubElement(
                    current_node, 'attr', attrib={'name': attr})
                cur_value = ET.SubElement(cur_attr,
                                          graph.nodes[v][attr].__class__.__name__)
                cur_value.text = graph.nodes[v][attr]
        for v1 in graph:
            for v2 in graph[v1]:
                if (v1 < v2):  # Non oriented graphs
                    cur_edge = ET.SubElement(
                        graph_node,
                        'edge',
                        attrib={
                            'from': str(v1),
                            'to': str(v2)
                        })
                    for attr in graph[v1][v2].keys():
                        cur_attr = ET.SubElement(
                            cur_edge, 'attr', attrib={'name': attr})
                        cur_value = ET.SubElement(
                            cur_attr, graph[v1][v2][attr].__class__.__name__)
                        cur_value.text = str(graph[v1][v2][attr])
        tree = ET.ElementTree(root_node)
        tree.write(filename)
    elif method == 'gedlib':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        # Writes the fixed 'chem'/'valence' schema expected by gedlib.
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['chem']) + "</int></attr>")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
            gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
            gxl_file.write("</edge>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>")
        gxl_file.close()
    elif method == 'gedlib-letter':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
        # Letter datasets: x/y float node attributes, unlabeled edges.
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
            gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>")
        gxl_file.close()
def loadSDF(filename):
    """load data from structured data file (.sdf file).

    Notes
    ------
    A SDF file contains a group of molecules, represented in the similar way as in MOL format.
    Check `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__ for detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys
    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        # total is len(content)+1 so the final pbar.update(1) after the
        # loop can complete the bar.
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip())  # set name of the graph
            # Counts line (4th line of each record): fixed-width columns,
            # chars 0-2 = atom count, chars 3-5 = bond count.
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of the nodes
            nb_edges = int(tmp[3:6])  # number of the edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                # chars 31-33 of an atom line hold the element symbol
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                # split the bond line into fixed 3-character columns:
                # from-atom, to-atom, bond type (1-based atom indices)
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            # Skip past this record's data block, then scan forward to the
            # '$$$$' record separator.
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # seperator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
def loadMAT(filename, order):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Parameters
    ----------
    filename : string
        Path of the .mat file.
    order : sequence of int
        Indices locating the fields inside the MATLAB struct; as used
        below: order[0] = sparse adjacency matrix, order[1] = 0 means
        edge labels exist / adjacency is not compressed, order[3] = node
        labels, order[4] = edge list.  (Presumably the 'am_sp_al_nl_el'
        convention seen in callers — TODO confirm.)

    Notes
    ------
    A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph.
    Check README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__ for detailed structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx
    data = []
    content = loadmat(filename)
    for key, value in content.items():
        if key[0] == 'l':  # class label vector (e.g. 'lx')
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':  # skip MATLAB metadata keys ('__header__', ...)
            # if adjacency matrix is not compressed / edge label exists
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge label
                    for edge in el:
                        # MATLAB indices are 1-based
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node label
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    index_no0 = sam.nonzero()  # indices of non-zero entries
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    # NOTE(review): `y` is unbound if the file contains no key starting
    # with 'l' — a NameError would be raised here; TODO confirm expected
    # file layout always provides one.
    return data, y
def load_tud(filename):
    """Load graph data from TUD dataset files.

    Parameters
    ----------
    filename : string
        Path of the ``<DS>_A.txt`` file; sibling files (``_graph_indicator``,
        ``_graph_labels``, ``_node_labels``, ...) are discovered in the
        same directory.

    Return
    ------
    data : list of NetworkX graph.
    targets : list
        One target per graph.
    label_names : dict
        Names of node/edge labels and attributes.

    Notes
    ------
    The graph data is loaded from separate files.
    Check README in `downloadable file <http://tiny.cc/PK_MLJ_data>`__, 2018 for detailed structure.
    """
    import networkx as nx
    from os import listdir
    from os.path import dirname, basename

    def get_infos_from_readme(frm):  # @todo: add README (cuniform), maybe node/edge label maps.
        """Get information from DS_label_readme.txt file.
        """

        def get_label_names_from_line(line):
            """Get names of labels/attributes from a line.
            """
            # names are listed between '[' and ']', comma-separated
            str_names = line.split('[')[1].split(']')[0]
            names = str_names.split(',')
            names = [attr.strip() for attr in names]
            return names

        def get_class_label_map(label_map_strings):
            # each string is "<integer>\t<label>"
            label_map = {}
            for string in label_map_strings:
                integer, label = string.split('\t')
                label_map[int(integer.strip())] = label.strip()
            return label_map

        label_names = {'node_labels': [], 'node_attrs': [],
                       'edge_labels': [], 'edge_attrs': []}
        class_label_map = None
        class_label_map_strings = []
        content_rm = open(frm).read().splitlines()
        i = 0
        while i < len(content_rm):
            line = content_rm[i].strip()
            # get node/edge labels and attributes.
            if line.startswith('Node labels:'):
                label_names['node_labels'] = get_label_names_from_line(line)
            elif line.startswith('Node attributes:'):
                label_names['node_attrs'] = get_label_names_from_line(line)
            elif line.startswith('Edge labels:'):
                label_names['edge_labels'] = get_label_names_from_line(line)
            elif line.startswith('Edge attributes:'):
                label_names['edge_attrs'] = get_label_names_from_line(line)
            # get class label map.
            elif line.startswith('Class labels were converted to integer values using this map:'):
                i += 2  # skip the header and the following blank line
                line = content_rm[i].strip()
                # NOTE(review): `content_rm[i]` is read before the bounds
                # check, so a map running to EOF without a trailing blank
                # line would raise IndexError — TODO confirm.
                while line != '' and i < len(content_rm):
                    class_label_map_strings.append(line)
                    i += 1
                    line = content_rm[i].strip()
                class_label_map = get_class_label_map(class_label_map_strings)
            i += 1
        return label_names, class_label_map

    # get dataset name.
    dirname_dataset = dirname(filename)
    filename = basename(filename)
    fn_split = filename.split('_A')
    ds_name = fn_split[0].strip()
    # load data file names
    for name in listdir(dirname_dataset):
        if ds_name + '_A' in name:
            fam = dirname_dataset + '/' + name
        elif ds_name + '_graph_indicator' in name:
            fgi = dirname_dataset + '/' + name
        elif ds_name + '_graph_labels' in name:
            fgl = dirname_dataset + '/' + name
        elif ds_name + '_node_labels' in name:
            fnl = dirname_dataset + '/' + name
        elif ds_name + '_edge_labels' in name:
            fel = dirname_dataset + '/' + name
        elif ds_name + '_edge_attributes' in name:
            fea = dirname_dataset + '/' + name
        elif ds_name + '_node_attributes' in name:
            fna = dirname_dataset + '/' + name
        elif ds_name + '_graph_attributes' in name:
            fga = dirname_dataset + '/' + name
        elif ds_name + '_label_readme' in name:
            frm = dirname_dataset + '/' + name
        # this is supposed to be the node attrs, make sure to put this as the last 'elif'
        elif ds_name + '_attributes' in name:
            fna = dirname_dataset + '/' + name
    # get labels and attributes names.
    # (the `'x' in locals()` checks test whether the optional files above
    # were found)
    if 'frm' in locals():
        label_names, class_label_map = get_infos_from_readme(frm)
    else:
        label_names = {'node_labels': [], 'node_attrs': [],
                       'edge_labels': [], 'edge_attrs': []}
        class_label_map = None
    content_gi = open(fgi).read().splitlines()  # graph indicator
    content_am = open(fam).read().splitlines()  # adjacency matrix
    # load targets.
    if 'fgl' in locals():
        content_targets = open(fgl).read().splitlines()  # targets (classification)
        targets = [float(i) for i in content_targets]
        # NOTE(review): classification targets converted with float() and
        # regression targets below with int() looks swapped, and the class
        # label map is applied to the regression branch — TODO confirm
        # against upstream before changing behavior.
    elif 'fga' in locals():
        content_targets = open(fga).read().splitlines()  # targets (regression)
        targets = [int(i) for i in content_targets]
        if class_label_map is not None:
            targets = [class_label_map[t] for t in targets]
    else:
        raise Exception('Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.')
    # create graphs and add nodes
    data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))]
    if 'fnl' in locals():
        content_nl = open(fnl).read().splitlines()  # node labels
        for idx, line in enumerate(content_gi):
            # transfer to int first in case of unexpected blanks
            data[int(line) - 1].add_node(idx)
            labels = [l.strip() for l in content_nl[idx].split(',')]
            if label_names['node_labels'] == []:  # @todo: need fix bug.
                for i, label in enumerate(labels):
                    l_name = 'label_' + str(i)
                    data[int(line) - 1].nodes[idx][l_name] = label
                    label_names['node_labels'].append(l_name)
            else:
                for i, l_name in enumerate(label_names['node_labels']):
                    data[int(line) - 1].nodes[idx][l_name] = labels[i]
    else:
        for i, line in enumerate(content_gi):
            data[int(line) - 1].add_node(i)
    # add edges
    for line in content_am:
        tmp = line.split(',')
        n1 = int(tmp[0]) - 1
        n2 = int(tmp[1]) - 1
        # ignore edge weight here.
        g = int(content_gi[n1]) - 1  # graph this edge belongs to
        data[g].add_edge(n1, n2)
    # add edge labels
    if 'fel' in locals():
        content_el = open(fel).read().splitlines()
        for idx, line in enumerate(content_el):
            labels = [l.strip() for l in line.split(',')]
            # endpoints of the idx-th edge, from the adjacency file
            n = [int(i) - 1 for i in content_am[idx].split(',')]
            g = int(content_gi[n[0]]) - 1
            if label_names['edge_labels'] == []:
                for i, label in enumerate(labels):
                    l_name = 'label_' + str(i)
                    data[g].edges[n[0], n[1]][l_name] = label
                    label_names['edge_labels'].append(l_name)
            else:
                for i, l_name in enumerate(label_names['edge_labels']):
                    data[g].edges[n[0], n[1]][l_name] = labels[i]
    # add node attributes
    if 'fna' in locals():
        content_na = open(fna).read().splitlines()
        for idx, line in enumerate(content_na):
            attrs = [a.strip() for a in line.split(',')]
            g = int(content_gi[idx]) - 1
            if label_names['node_attrs'] == []:
                for i, attr in enumerate(attrs):
                    a_name = 'attr_' + str(i)
                    data[g].nodes[idx][a_name] = attr
                    label_names['node_attrs'].append(a_name)
            else:
                for i, a_name in enumerate(label_names['node_attrs']):
                    data[g].nodes[idx][a_name] = attrs[i]
    # add edge attributes
    if 'fea' in locals():
        content_ea = open(fea).read().splitlines()
        for idx, line in enumerate(content_ea):
            attrs = [a.strip() for a in line.split(',')]
            n = [int(i) - 1 for i in content_am[idx].split(',')]
            g = int(content_gi[n[0]]) - 1
            if label_names['edge_attrs'] == []:
                for i, attr in enumerate(attrs):
                    a_name = 'attr_' + str(i)
                    data[g].edges[n[0], n[1]][a_name] = attr
                    label_names['edge_attrs'].append(a_name)
            else:
                for i, a_name in enumerate(label_names['edge_attrs']):
                    data[g].edges[n[0], n[1]][a_name] = attrs[i]
    return data, targets, label_names
  569. def loadFromXML(filename, dir_dataset=None):
  570. import xml.etree.ElementTree as ET
  571. if dir_dataset is not None:
  572. dir_dataset = dir_dataset
  573. else:
  574. dir_dataset = dirname(filename)
  575. tree = ET.parse(filename)
  576. root = tree.getroot()
  577. data = []
  578. y = []
  579. for graph in root.iter('graph'):
  580. mol_filename = graph.attrib['file']
  581. mol_class = graph.attrib['class']
  582. data.append(load_gxl(dir_dataset + '/' + mol_filename))
  583. y.append(mol_class)
  584. return data, y
  585. def load_from_ds(filename, filename_targets):
  586. """Load data from .ds file.
  587. Possible graph formats include:
  588. '.ct': see function load_ct for detail.
  589. '.gxl': see dunction load_gxl for detail.
  590. Note these graph formats are checked automatically by the extensions of
  591. graph files.
  592. """
  593. def append_label_names(label_names, new_names):
  594. for key, val in label_names.items():
  595. label_names[key] += [name for name in new_names[key] if name not in val]
  596. dirname_dataset = dirname(filename)
  597. data = []
  598. y = []
  599. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  600. content = open(filename).read().splitlines()
  601. extension = splitext(content[0].split(' ')[0])[1][1:]
  602. if filename_targets is None or filename_targets == '':
  603. if extension == 'ct':
  604. for i in range(0, len(content)):
  605. tmp = content[i].split(' ')
  606. # remove the '#'s in file names
  607. data.append(
  608. load_ct(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
  609. y.append(float(tmp[1]))
  610. elif extension == 'gxl':
  611. for i in range(0, len(content)):
  612. tmp = content[i].split(' ')
  613. # remove the '#'s in file names
  614. g, l_names = load_gxl(dirname_dataset + '/' + tmp[0].replace('#', '', 1))
  615. data.append(g)
  616. append_label_names(label_names, l_names)
  617. y.append(float(tmp[1]))
  618. else: # y in a seperate file
  619. if extension == 'ct':
  620. for i in range(0, len(content)):
  621. tmp = content[i]
  622. # remove the '#'s in file names
  623. data.append(
  624. load_ct(dirname_dataset + '/' + tmp.replace('#', '', 1)))
  625. elif extension == 'gxl':
  626. for i in range(0, len(content)):
  627. tmp = content[i]
  628. # remove the '#'s in file names
  629. g, l_names = load_gxl(dirname_dataset + '/' + tmp[0].replace('#', '', 1))
  630. data.append(g)
  631. append_label_names(label_names, l_names)
  632. content_y = open(filename_targets).read().splitlines()
  633. # assume entries in filename and filename_targets have the same order.
  634. for item in content_y:
  635. tmp = item.split(' ')
  636. # assume the 3rd entry in a line is y (for Alkane dataset)
  637. y.append(float(tmp[2]))
  638. return data, y, label_names
  639. if __name__ == '__main__':
  640. # ### Load dataset from .ds file.
  641. # # .ct files.
  642. # ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds',
  643. # 'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'}
  644. # Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y'])
  645. ## ds = {'name': 'Acyclic', 'dataset': '../../datasets/acyclic/dataset_bps.ds'} # node symb
  646. ## Gn, y = loadDataset(ds['dataset'])
  647. ## ds = {'name': 'MAO', 'dataset': '../../datasets/MAO/dataset.ds'} # node/edge symb
  648. ## Gn, y = loadDataset(ds['dataset'])
  649. ## ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'} # unlabeled
  650. ## Gn, y = loadDataset(ds['dataset'])
  651. # print(Gn[1].nodes(data=True))
  652. # print(Gn[1].edges(data=True))
  653. # print(y[1])
  654. # .gxl file.
  655. ds = {'name': 'monoterpenoides',
  656. 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  657. Gn, y, label_names = load_dataset(ds['dataset'])
  658. print(Gn[1].graph)
  659. print(Gn[1].nodes(data=True))
  660. print(Gn[1].edges(data=True))
  661. print(y[1])
  662. # ### Convert graph from one format to another.
  663. # # .gxl file.
  664. # import networkx as nx
  665. # ds = {'name': 'monoterpenoides',
  666. # 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  667. # Gn, y = loadDataset(ds['dataset'])
  668. # y = [int(i) for i in y]
  669. # print(Gn[1].nodes(data=True))
  670. # print(Gn[1].edges(data=True))
  671. # print(y[1])
  672. # # Convert a graph to the proper NetworkX format that can be recognized by library gedlib.
  673. # Gn_new = []
  674. # for G in Gn:
  675. # G_new = nx.Graph()
  676. # for nd, attrs in G.nodes(data=True):
  677. # G_new.add_node(str(nd), chem=attrs['atom'])
  678. # for nd1, nd2, attrs in G.edges(data=True):
  679. # G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  680. ## G_new.add_edge(str(nd1), str(nd2))
  681. # Gn_new.append(G_new)
  682. # print(Gn_new[1].nodes(data=True))
  683. # print(Gn_new[1].edges(data=True))
  684. # print(Gn_new[1])
  685. # filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides'
  686. # xparams = {'method': 'gedlib'}
  687. # saveDataset(Gn, y, gformat='gxl', group='xml', filename=filename, xparams=xparams)
  688. # save dataset.
  689. # ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
  690. # 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  691. # Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  692. # saveDataset(Gn, y, group='xml', filename='temp/temp')
  693. # test - new way to add labels and attributes.
  694. # dataset = '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
  695. # filename = '../../datasets/Fingerprint/Fingerprint_A.txt'
  696. # dataset = '../../datasets/Letter-med/Letter-med_A.txt'
  697. # dataset = '../../datasets/AIDS/AIDS_A.txt'
  698. # dataset = '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
  699. # Gn, targets, label_names = load_dataset(filename)
  700. pass

A Python package for graph kernels, graph edit distances and the graph pre-image problem.