You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_files.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
"""Utilities function to manage graph files
"""
import warnings

# 'always' keeps this deprecation notice from being silenced after the first
# occurrence; it fires every time the module is imported.
warnings.simplefilter('always', DeprecationWarning)
warnings.warn('The functions in the module "gklearn.utils.graph_files" will be deprecated and removed since version 0.4.0. Use the corresponding functions in the module "gklearn.dataset" instead.', DeprecationWarning)

# Path helpers shared by the loaders below.
from os.path import dirname, splitext
  7. def load_dataset(filename, filename_targets=None, gformat=None, **kwargs):
  8. """Read graph data from filename and load them as NetworkX graphs.
  9. Parameters
  10. ----------
  11. filename : string
  12. The name of the file from where the dataset is read.
  13. filename_y : string
  14. The name of file of the targets corresponding to graphs.
  15. extra_params : dict
  16. Extra parameters only designated to '.mat' format.
  17. Return
  18. ------
  19. data : List of NetworkX graph.
  20. y : List
  21. Targets corresponding to graphs.
  22. Notes
  23. -----
  24. This function supports following graph dataset formats:
  25. 'ds': load data from .ds file. See comments of function loadFromDS for a example.
  26. 'cxl': load data from Graph eXchange Language file (.cxl file). See
  27. `here <http://www.gupro.de/GXL/Introduction/background.html>`__ for detail.
  28. 'sdf': load data from structured data file (.sdf file). See
  29. `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
  30. for details.
  31. 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See
  32. README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
  33. for details.
  34. 'txt': Load graph data from a special .txt file. See
  35. `here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
  36. for details. Note here filename is the name of either .txt file in
  37. the dataset directory.
  38. """
  39. import warnings
  40. warnings.simplefilter('always', DeprecationWarning)
  41. warnings.warn('The function "gklearn.utils.load_dataset" will be deprecated and removed since version 0.4.0. Use the class "gklearn.dataset.DataLoader" instead.', DeprecationWarning)
  42. extension = splitext(filename)[1][1:]
  43. if extension == "ds":
  44. data, y, label_names = load_from_ds(filename, filename_targets)
  45. elif extension == "cxl":
  46. dir_dataset = kwargs.get('dirname_dataset', None)
  47. data, y, label_names = load_from_xml(filename, dir_dataset)
  48. elif extension == 'xml':
  49. dir_dataset = kwargs.get('dirname_dataset', None)
  50. data, y, label_names = load_from_xml(filename, dir_dataset)
  51. elif extension == "mat":
  52. order = kwargs.get('order')
  53. data, y, label_names = load_mat(filename, order)
  54. elif extension == 'txt':
  55. data, y, label_names = load_tud(filename)
  56. return data, y, label_names
  57. def save_dataset(Gn, y, gformat='gxl', group=None, filename='gfile', **kwargs):
  58. """Save list of graphs.
  59. """
  60. import warnings
  61. warnings.simplefilter('always', DeprecationWarning)
  62. warnings.warn('The function "gklearn.utils.save_dataset" will be deprecated and removed since version 0.4.0. Use the class "gklearn.dataset.DataSaver" instead.', DeprecationWarning)
  63. import os
  64. dirname_ds = os.path.dirname(filename)
  65. if dirname_ds != '':
  66. dirname_ds += '/'
  67. os.makedirs(dirname_ds, exist_ok=True)
  68. if 'graph_dir' in kwargs:
  69. graph_dir = kwargs['graph_dir'] + '/'
  70. os.makedirs(graph_dir, exist_ok=True)
  71. del kwargs['graph_dir']
  72. else:
  73. graph_dir = dirname_ds
  74. if group == 'xml' and gformat == 'gxl':
  75. with open(filename + '.xml', 'w') as fgroup:
  76. fgroup.write("<?xml version=\"1.0\"?>")
  77. fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
  78. fgroup.write("\n<GraphCollection>")
  79. for idx, g in enumerate(Gn):
  80. fname_tmp = "graph" + str(idx) + ".gxl"
  81. save_gxl(g, graph_dir + fname_tmp, **kwargs)
  82. fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
  83. fgroup.write("\n</GraphCollection>")
  84. fgroup.close()
def load_ct(filename): # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds are considered.)
	"""load data from a Chemical Table (.ct) file.

	Notes
	------
	a typical example of data in .ct is like this:

	3 2  <- number of nodes and edges

	0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
	0.0000 0.0000 0.0000 C
	0.0000 0.0000 0.0000 O

	1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo

	2 3 1 1

	Check `CTFile Formats file <https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=10&ved=2ahUKEwivhaSdjsTlAhVhx4UKHczHA8gQFjAJegQIARAC&url=https%3A%2F%2Fwww.daylight.com%2Fmeetings%2Fmug05%2FKappler%2Fctfile.pdf&usg=AOvVaw1cDNrrmMClkFPqodlF2inS>`__
	for detailed format discription.
	"""
	import networkx as nx
	from os.path import basename
	g = nx.Graph()
	with open(filename) as f:
		content = f.read().splitlines()
		g = nx.Graph(name=str(content[0]), filename=basename(filename)) # set name of the graph

		# read the counts line.
		tmp = content[1].split(' ')
		tmp = [x for x in tmp if x != '']  # drop empty fields from repeated spaces
		nb_atoms = int(tmp[0].strip()) # number of atoms
		nb_bonds = int(tmp[1].strip()) # number of bonds
		# Field names of the counts line; '' marks obsolete fields that are skipped.
		count_line_tags = ['number_of_atoms', 'number_of_bonds', 'number_of_atom_lists', '', 'chiral_flag', 'number_of_stext_entries', '', '', '', '', 'number_of_properties', 'CT_version']
		i = 0
		while i < len(tmp):
			if count_line_tags[i] != '': # if not obsoleted
				g.graph[count_line_tags[i]] = tmp[i].strip()
			i += 1

		# read the atom block.
		# Field names of each atom line, in positional order; '' = obsolete.
		atom_tags = ['x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', 'atom_stereo_parity', 'hydrogen_count_plus_1', 'stereo_care_box', 'valence', 'h0_designator', '', '', 'atom_atom_mapping_number', 'inversion_retention_flag', 'exact_change_flag']
		for i in range(0, nb_atoms):
			tmp = content[i + 2].split(' ')
			tmp = [x for x in tmp if x != '']
			g.add_node(i)
			j = 0
			while j < len(tmp):
				if atom_tags[j] != '':
					g.nodes[i][atom_tags[j]] = tmp[j].strip()
				j += 1

		# read the bond block.
		bond_tags = ['first_atom_number', 'second_atom_number', 'bond_type', 'bond_stereo', '', 'bond_topology', 'reacting_center_status']
		for i in range(0, nb_bonds):
			tmp = content[i + g.number_of_nodes() + 2].split(' ')
			tmp = [x for x in tmp if x != '']
			# Atom numbers in the file are 1-based; nodes are 0-based.
			n1, n2 = int(tmp[0].strip()) - 1, int(tmp[1].strip()) - 1
			g.add_edge(n1, n2)
			j = 2  # first two fields are the endpoints, already consumed
			while j < len(tmp):
				if bond_tags[j] != '':
					g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip()
				j += 1

	# get label names.
	label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
	# 1 = symbolic (label), 0 = numeric (attribute), None = obsolete field.
	atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1]
	for nd in g.nodes():
		for key in g.nodes[nd]:
			if atom_symbolic[atom_tags.index(key)] == 1:
				label_names['node_labels'].append(key)
			else:
				label_names['node_attrs'].append(key)
		break  # only the first node is inspected; assumes all nodes carry the same keys
	bond_symbolic = [None, None, 1, 1, None, 1, 1]
	for ed in g.edges():
		for key in g.edges[ed]:
			if bond_symbolic[bond_tags.index(key)] == 1:
				label_names['edge_labels'].append(key)
			else:
				label_names['edge_attrs'].append(key)
		break  # only the first edge is inspected — same assumption as above
	return g, label_names
def load_gxl(filename): # @todo: directed graphs.
	"""Load a single undirected graph from a GXL (Graph eXchange Language) file.

	Returns the graph together with a dict of the label/attribute names found,
	classified as symbolic labels (<int> values) or numeric attributes.
	"""
	from os.path import basename
	import networkx as nx
	import xml.etree.ElementTree as ET

	tree = ET.parse(filename)
	root = tree.getroot()
	index = 0
	# root[0] is the <graph> element; its 'id' becomes the graph name.
	g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
	dic = {} # used to retrieve incident nodes of edges
	for node in root.iter('node'):
		dic[node.attrib['id']] = index
		labels = {}
		for attr in node.iter('attr'):
			labels[attr.attrib['name']] = attr[0].text
		g.add_node(index, **labels)
		index += 1

	for edge in root.iter('edge'):
		labels = {}
		for attr in edge.iter('attr'):
			labels[attr.attrib['name']] = attr[0].text
		g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)

	# get label names.
	label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
	for node in root.iter('node'):
		for attr in node.iter('attr'):
			if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
				label_names['node_labels'].append(attr.attrib['name'])
			else:
				label_names['node_attrs'].append(attr.attrib['name'])
		break  # only the first node is inspected; assumes uniform keys across nodes
	for edge in root.iter('edge'):
		for attr in edge.iter('attr'):
			if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
				label_names['edge_labels'].append(attr.attrib['name'])
			else:
				label_names['edge_attrs'].append(attr.attrib['name'])
		break  # only the first edge is inspected — same assumption
	return g, label_names
  196. def save_gxl(graph, filename, method='default', node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
  197. if method == 'default':
  198. gxl_file = open(filename, 'w')
  199. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  200. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  201. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  202. if 'name' in graph.graph:
  203. name = str(graph.graph['name'])
  204. else:
  205. name = 'dummy'
  206. gxl_file.write("<graph id=\"" + name + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  207. for v, attrs in graph.nodes(data=True):
  208. gxl_file.write("<node id=\"_" + str(v) + "\">")
  209. for l_name in node_labels:
  210. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  211. str(attrs[l_name]) + "</int></attr>")
  212. for a_name in node_attrs:
  213. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  214. str(attrs[a_name]) + "</float></attr>")
  215. gxl_file.write("</node>\n")
  216. for v1, v2, attrs in graph.edges(data=True):
  217. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  218. for l_name in edge_labels:
  219. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  220. str(attrs[l_name]) + "</int></attr>")
  221. for a_name in edge_attrs:
  222. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  223. str(attrs[a_name]) + "</float></attr>")
  224. gxl_file.write("</edge>\n")
  225. gxl_file.write("</graph>\n")
  226. gxl_file.write("</gxl>")
  227. gxl_file.close()
  228. elif method == 'benoit':
  229. import xml.etree.ElementTree as ET
  230. root_node = ET.Element('gxl')
  231. attr = dict()
  232. attr['id'] = str(graph.graph['name'])
  233. attr['edgeids'] = 'true'
  234. attr['edgemode'] = 'undirected'
  235. graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
  236. for v in graph:
  237. current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
  238. for attr in graph.nodes[v].keys():
  239. cur_attr = ET.SubElement(
  240. current_node, 'attr', attrib={'name': attr})
  241. cur_value = ET.SubElement(cur_attr,
  242. graph.nodes[v][attr].__class__.__name__)
  243. cur_value.text = graph.nodes[v][attr]
  244. for v1 in graph:
  245. for v2 in graph[v1]:
  246. if (v1 < v2): # Non oriented graphs
  247. cur_edge = ET.SubElement(
  248. graph_node,
  249. 'edge',
  250. attrib={
  251. 'from': str(v1),
  252. 'to': str(v2)
  253. })
  254. for attr in graph[v1][v2].keys():
  255. cur_attr = ET.SubElement(
  256. cur_edge, 'attr', attrib={'name': attr})
  257. cur_value = ET.SubElement(
  258. cur_attr, graph[v1][v2][attr].__class__.__name__)
  259. cur_value.text = str(graph[v1][v2][attr])
  260. tree = ET.ElementTree(root_node)
  261. tree.write(filename)
  262. elif method == 'gedlib':
  263. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  264. # pass
  265. gxl_file = open(filename, 'w')
  266. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  267. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  268. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  269. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
  270. for v, attrs in graph.nodes(data=True):
  271. gxl_file.write("<node id=\"_" + str(v) + "\">")
  272. gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['chem']) + "</int></attr>")
  273. gxl_file.write("</node>\n")
  274. for v1, v2, attrs in graph.edges(data=True):
  275. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  276. gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
  277. # gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>")
  278. gxl_file.write("</edge>\n")
  279. gxl_file.write("</graph>\n")
  280. gxl_file.write("</gxl>")
  281. gxl_file.close()
  282. elif method == 'gedlib-letter':
  283. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  284. # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
  285. gxl_file = open(filename, 'w')
  286. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  287. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  288. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  289. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  290. for v, attrs in graph.nodes(data=True):
  291. gxl_file.write("<node id=\"_" + str(v) + "\">")
  292. gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
  293. gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
  294. gxl_file.write("</node>\n")
  295. for v1, v2, attrs in graph.edges(data=True):
  296. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
  297. gxl_file.write("</graph>\n")
  298. gxl_file.write("</gxl>")
  299. gxl_file.close()
  300. # def loadSDF(filename):
  301. # """load data from structured data file (.sdf file).
  302. # Notes
  303. # ------
  304. # A SDF file contains a group of molecules, represented in the similar way as in MOL format.
  305. # Check `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__ for detailed structure.
  306. # """
  307. # import networkx as nx
  308. # from os.path import basename
  309. # from tqdm import tqdm
  310. # import sys
  311. # data = []
  312. # with open(filename) as f:
  313. # content = f.read().splitlines()
  314. # index = 0
  315. # pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
  316. # while index < len(content):
  317. # index_old = index
  318. # g = nx.Graph(name=content[index].strip()) # set name of the graph
  319. # tmp = content[index + 3]
  320. # nb_nodes = int(tmp[:3]) # number of the nodes
  321. # nb_edges = int(tmp[3:6]) # number of the edges
  322. # for i in range(0, nb_nodes):
  323. # tmp = content[i + index + 4]
  324. # g.add_node(i, atom=tmp[31:34].strip())
  325. # for i in range(0, nb_edges):
  326. # tmp = content[i + index + g.number_of_nodes() + 4]
  327. # tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
  328. # g.add_edge(
  329. # int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
  330. # data.append(g)
  331. # index += 4 + g.number_of_nodes() + g.number_of_edges()
  332. # while content[index].strip() != '$$$$': # seperator
  333. # index += 1
  334. # index += 1
  335. # pbar.update(index - index_old)
  336. # pbar.update(1)
  337. # pbar.close()
  338. # return data
def load_mat(filename, order): # @todo: need to be updated (auto order) or deprecated.
	"""Load graph data from a MATLAB (up to version 7.1) .mat file.

	Notes
	------
	A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph.

	Check README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__ for detailed structure.

	`order` indexes the fields of each graph struct; order[0] is the adjacency
	matrix, order[1] flags the compressed/uncompressed variant, order[3] the
	node labels and order[4] the edge labels (e.g. [0, 0, 3, 1, 2] for MUTAG,
	see the __main__ example below). -- assumed from usage; verify against the
	dataset README.
	"""
	from scipy.io import loadmat
	import numpy as np
	import networkx as nx
	data = []
	content = loadmat(filename)
	# NOTE(review): `y` is only bound if some key starts with 'l'; a file
	# without a label vector would raise NameError at the return — confirm
	# all supported files carry one.
	for key, value in content.items():
		if key[0] == 'l': # class label
			y = np.transpose(value)[0].tolist()
		elif key[0] != '_':
			# if adjacency matrix is not compressed / edge label exists
			if order[1] == 0:
				for i, item in enumerate(value[0]):
					g = nx.Graph(name=i) # set name of the graph
					nl = np.transpose(item[order[3]][0][0][0]) # node label
					for index, label in enumerate(nl[0]):
						g.add_node(index, label_1=str(label))
					el = item[order[4]][0][0][0] # edge label
					for edge in el:
						# 1-based atom indices in the file -> 0-based nodes.
						g.add_edge(edge[0] - 1, edge[1] - 1, label_1=str(edge[2]))
					data.append(g)
			else:
				for i, item in enumerate(value[0]):
					g = nx.Graph(name=i) # set name of the graph
					nl = np.transpose(item[order[3]][0][0][0]) # node label
					for index, label in enumerate(nl[0]):
						g.add_node(index, label_1=str(label))
					sam = item[order[0]] # sparse adjacency matrix
					index_no0 = sam.nonzero()
					for col, row in zip(index_no0[0], index_no0[1]):
						g.add_edge(col, row)
					data.append(g)
	label_names = {'node_labels': ['label_1'], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
	if order[1] == 0:
		# Edge labels only exist in the uncompressed variant.
		label_names['edge_labels'].append('label_1')
	return data, y, label_names
  405. def load_tud(filename):
  406. """Load graph data from TUD dataset files.
  407. Notes
  408. ------
  409. The graph data is loaded from separate files.
  410. Check README in `downloadable file <http://tiny.cc/PK_MLJ_data>`__, 2018 for detailed structure.
  411. """
  412. import networkx as nx
  413. from os import listdir
  414. from os.path import dirname, basename
  415. def get_infos_from_readme(frm): # @todo: add README (cuniform), maybe node/edge label maps.
  416. """Get information from DS_label_readme.txt file.
  417. """
  418. def get_label_names_from_line(line):
  419. """Get names of labels/attributes from a line.
  420. """
  421. str_names = line.split('[')[1].split(']')[0]
  422. names = str_names.split(',')
  423. names = [attr.strip() for attr in names]
  424. return names
  425. def get_class_label_map(label_map_strings):
  426. label_map = {}
  427. for string in label_map_strings:
  428. integer, label = string.split('\t')
  429. label_map[int(integer.strip())] = label.strip()
  430. return label_map
  431. label_names = {'node_labels': [], 'node_attrs': [],
  432. 'edge_labels': [], 'edge_attrs': []}
  433. class_label_map = None
  434. class_label_map_strings = []
  435. with open(frm) as rm:
  436. content_rm = rm.read().splitlines()
  437. i = 0
  438. while i < len(content_rm):
  439. line = content_rm[i].strip()
  440. # get node/edge labels and attributes.
  441. if line.startswith('Node labels:'):
  442. label_names['node_labels'] = get_label_names_from_line(line)
  443. elif line.startswith('Node attributes:'):
  444. label_names['node_attrs'] = get_label_names_from_line(line)
  445. elif line.startswith('Edge labels:'):
  446. label_names['edge_labels'] = get_label_names_from_line(line)
  447. elif line.startswith('Edge attributes:'):
  448. label_names['edge_attrs'] = get_label_names_from_line(line)
  449. # get class label map.
  450. elif line.startswith('Class labels were converted to integer values using this map:'):
  451. i += 2
  452. line = content_rm[i].strip()
  453. while line != '' and i < len(content_rm):
  454. class_label_map_strings.append(line)
  455. i += 1
  456. line = content_rm[i].strip()
  457. class_label_map = get_class_label_map(class_label_map_strings)
  458. i += 1
  459. return label_names, class_label_map
  460. # get dataset name.
  461. dirname_dataset = dirname(filename)
  462. filename = basename(filename)
  463. fn_split = filename.split('_A')
  464. ds_name = fn_split[0].strip()
  465. # load data file names
  466. for name in listdir(dirname_dataset):
  467. if ds_name + '_A' in name:
  468. fam = dirname_dataset + '/' + name
  469. elif ds_name + '_graph_indicator' in name:
  470. fgi = dirname_dataset + '/' + name
  471. elif ds_name + '_graph_labels' in name:
  472. fgl = dirname_dataset + '/' + name
  473. elif ds_name + '_node_labels' in name:
  474. fnl = dirname_dataset + '/' + name
  475. elif ds_name + '_edge_labels' in name:
  476. fel = dirname_dataset + '/' + name
  477. elif ds_name + '_edge_attributes' in name:
  478. fea = dirname_dataset + '/' + name
  479. elif ds_name + '_node_attributes' in name:
  480. fna = dirname_dataset + '/' + name
  481. elif ds_name + '_graph_attributes' in name:
  482. fga = dirname_dataset + '/' + name
  483. elif ds_name + '_label_readme' in name:
  484. frm = dirname_dataset + '/' + name
  485. # this is supposed to be the node attrs, make sure to put this as the last 'elif'
  486. elif ds_name + '_attributes' in name:
  487. fna = dirname_dataset + '/' + name
  488. # get labels and attributes names.
  489. if 'frm' in locals():
  490. label_names, class_label_map = get_infos_from_readme(frm)
  491. else:
  492. label_names = {'node_labels': [], 'node_attrs': [],
  493. 'edge_labels': [], 'edge_attrs': []}
  494. class_label_map = None
  495. with open(fgi) as gi:
  496. content_gi = gi.read().splitlines() # graph indicator
  497. with open(fam) as am:
  498. content_am = am.read().splitlines() # adjacency matrix
  499. # load targets.
  500. if 'fgl' in locals():
  501. with open(fgl) as gl:
  502. content_targets = gl.read().splitlines() # targets (classification)
  503. targets = [float(i) for i in content_targets]
  504. elif 'fga' in locals():
  505. with open(fga) as ga:
  506. content_targets = ga.read().splitlines() # targets (regression)
  507. targets = [int(i) for i in content_targets]
  508. else:
  509. raise Exception('Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.')
  510. if class_label_map is not None:
  511. targets = [class_label_map[t] for t in targets]
  512. # create graphs and add nodes
  513. data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))]
  514. if 'fnl' in locals():
  515. with open(fnl) as nl:
  516. content_nl = nl.read().splitlines() # node labels
  517. for idx, line in enumerate(content_gi):
  518. # transfer to int first in case of unexpected blanks
  519. data[int(line) - 1].add_node(idx)
  520. labels = [l.strip() for l in content_nl[idx].split(',')]
  521. if label_names['node_labels'] == []: # @todo: need fix bug.
  522. for i, label in enumerate(labels):
  523. l_name = 'label_' + str(i)
  524. data[int(line) - 1].nodes[idx][l_name] = label
  525. label_names['node_labels'].append(l_name)
  526. else:
  527. for i, l_name in enumerate(label_names['node_labels']):
  528. data[int(line) - 1].nodes[idx][l_name] = labels[i]
  529. else:
  530. for i, line in enumerate(content_gi):
  531. data[int(line) - 1].add_node(i)
  532. # add edges
  533. for line in content_am:
  534. tmp = line.split(',')
  535. n1 = int(tmp[0]) - 1
  536. n2 = int(tmp[1]) - 1
  537. # ignore edge weight here.
  538. g = int(content_gi[n1]) - 1
  539. data[g].add_edge(n1, n2)
  540. # add edge labels
  541. if 'fel' in locals():
  542. with open(fel) as el:
  543. content_el = el.read().splitlines()
  544. for idx, line in enumerate(content_el):
  545. labels = [l.strip() for l in line.split(',')]
  546. n = [int(i) - 1 for i in content_am[idx].split(',')]
  547. g = int(content_gi[n[0]]) - 1
  548. if label_names['edge_labels'] == []:
  549. for i, label in enumerate(labels):
  550. l_name = 'label_' + str(i)
  551. data[g].edges[n[0], n[1]][l_name] = label
  552. label_names['edge_labels'].append(l_name)
  553. else:
  554. for i, l_name in enumerate(label_names['edge_labels']):
  555. data[g].edges[n[0], n[1]][l_name] = labels[i]
  556. # add node attributes
  557. if 'fna' in locals():
  558. with open(fna) as na:
  559. content_na = na.read().splitlines()
  560. for idx, line in enumerate(content_na):
  561. attrs = [a.strip() for a in line.split(',')]
  562. g = int(content_gi[idx]) - 1
  563. if label_names['node_attrs'] == []:
  564. for i, attr in enumerate(attrs):
  565. a_name = 'attr_' + str(i)
  566. data[g].nodes[idx][a_name] = attr
  567. label_names['node_attrs'].append(a_name)
  568. else:
  569. for i, a_name in enumerate(label_names['node_attrs']):
  570. data[g].nodes[idx][a_name] = attrs[i]
  571. # add edge attributes
  572. if 'fea' in locals():
  573. with open(fea) as ea:
  574. content_ea = ea.read().splitlines()
  575. for idx, line in enumerate(content_ea):
  576. attrs = [a.strip() for a in line.split(',')]
  577. n = [int(i) - 1 for i in content_am[idx].split(',')]
  578. g = int(content_gi[n[0]]) - 1
  579. if label_names['edge_attrs'] == []:
  580. for i, attr in enumerate(attrs):
  581. a_name = 'attr_' + str(i)
  582. data[g].edges[n[0], n[1]][a_name] = attr
  583. label_names['edge_attrs'].append(a_name)
  584. else:
  585. for i, a_name in enumerate(label_names['edge_attrs']):
  586. data[g].edges[n[0], n[1]][a_name] = attrs[i]
  587. return data, targets, label_names
  588. def load_from_ds(filename, filename_targets):
  589. """Load data from .ds file.
  590. Possible graph formats include:
  591. '.ct': see function load_ct for detail.
  592. '.gxl': see dunction load_gxl for detail.
  593. Note these graph formats are checked automatically by the extensions of
  594. graph files.
  595. """
  596. dirname_dataset = dirname(filename)
  597. data = []
  598. y = []
  599. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  600. with open(filename) as fn:
  601. content = fn.read().splitlines()
  602. extension = splitext(content[0].split(' ')[0])[1][1:]
  603. if extension == 'ct':
  604. load_file_fun = load_ct
  605. elif extension == 'gxl' or extension == 'sdf': # @todo: .sdf not tested yet.
  606. load_file_fun = load_gxl
  607. if filename_targets is None or filename_targets == '':
  608. for i in range(0, len(content)):
  609. tmp = content[i].split(' ')
  610. # remove the '#'s in file names
  611. g, l_names = load_file_fun(dirname_dataset + '/' + tmp[0].replace('#', '', 1))
  612. data.append(g)
  613. _append_label_names(label_names, l_names)
  614. y.append(float(tmp[1]))
  615. else: # targets in a seperate file
  616. for i in range(0, len(content)):
  617. tmp = content[i]
  618. # remove the '#'s in file names
  619. g, l_names = load_file_fun(dirname_dataset + '/' + tmp.replace('#', '', 1))
  620. data.append(g)
  621. _append_label_names(label_names, l_names)
  622. with open(filename_targets) as fnt:
  623. content_y = fnt.read().splitlines()
  624. # assume entries in filename and filename_targets have the same order.
  625. for item in content_y:
  626. tmp = item.split(' ')
  627. # assume the 3rd entry in a line is y (for Alkane dataset)
  628. y.append(float(tmp[2]))
  629. return data, y, label_names
  630. # def load_from_cxl(filename):
  631. # import xml.etree.ElementTree as ET
  632. #
  633. # dirname_dataset = dirname(filename)
  634. # tree = ET.parse(filename)
  635. # root = tree.getroot()
  636. # data = []
  637. # y = []
  638. # for graph in root.iter('graph'):
  639. # mol_filename = graph.attrib['file']
  640. # mol_class = graph.attrib['class']
  641. # data.append(load_gxl(dirname_dataset + '/' + mol_filename))
  642. # y.append(mol_class)
  643. def load_from_xml(filename, dir_dataset=None):
  644. import xml.etree.ElementTree as ET
  645. if dir_dataset is not None:
  646. dir_dataset = dir_dataset
  647. else:
  648. dir_dataset = dirname(filename)
  649. tree = ET.parse(filename)
  650. root = tree.getroot()
  651. data = []
  652. y = []
  653. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  654. for graph in root.iter('graph'):
  655. mol_filename = graph.attrib['file']
  656. mol_class = graph.attrib['class']
  657. g, l_names = load_gxl(dir_dataset + '/' + mol_filename)
  658. data.append(g)
  659. _append_label_names(label_names, l_names)
  660. y.append(mol_class)
  661. return data, y, label_names
  662. def _append_label_names(label_names, new_names):
  663. for key, val in label_names.items():
  664. label_names[key] += [name for name in new_names[key] if name not in val]
if __name__ == '__main__':
	# Manual smoke tests against local dataset copies. Only the .mat example
	# is live; the others are kept as commented usage references.
	#
	# ### Load dataset from .ds file (.ct graphs):
	# ds_file = '../../datasets/Acyclic/dataset_bps.ds' # node symb
	# Gn, targets, label_names = load_dataset(ds_file)
	# ds_file = '../../datasets/MAO/dataset.ds' # node/edge symb
	# Gn, targets, label_names = load_dataset(ds_file)
	#
	# ### .gxl file:
	# ds_file = '../../datasets/monoterpenoides/dataset_10+.ds' # node/edge symb
	# Gn, y, label_names = load_dataset(ds_file)
	#
	# .mat file.
	ds_file = '../../datasets/MUTAG_mat/MUTAG.mat'
	order = [0, 0, 3, 1, 2]  # field order passed to load_mat (am_sp_al_nl_el)
	Gn, targets, label_names = load_dataset(ds_file, order=order)
	print(Gn[1].graph)
	print(Gn[1].nodes(data=True))
	print(Gn[1].edges(data=True))
	print(targets[1])
	#
	# ### Convert graphs to the NetworkX form recognized by gedlib, then save:
	# Gn_new = []
	# for G in Gn:
	#	  G_new = nx.Graph()
	#	  for nd, attrs in G.nodes(data=True):
	#		  G_new.add_node(str(nd), chem=attrs['atom'])
	#	  for nd1, nd2, attrs in G.edges(data=True):
	#		  G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
	#	  Gn_new.append(G_new)
	# save_dataset(Gn, y, gformat='gxl', group='xml', filename='temp/temp')
	#
	# ### TUD .txt datasets (new-style labels/attributes):
	# dataset = '../../datasets/Letter-med/Letter-med_A.txt'
	# Gn, targets, label_names = load_dataset(dataset)
	pass

A Python package for graph kernels, graph edit distances and the graph pre-image problem.