You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

file_managers.py 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936
""" Utility functions to manage graph files.
"""
  3. from os.path import dirname, splitext
  4. class DataLoader():
  5. def __init__(self, filename, filename_targets=None, gformat=None, **kwargs):
  6. """Read graph data from filename and load them as NetworkX graphs.
  7. Parameters
  8. ----------
  9. filename : string
  10. The name of the file from where the dataset is read.
  11. filename_targets : string
  12. The name of file of the targets corresponding to graphs.
  13. Notes
  14. -----
  15. This function supports following graph dataset formats:
  16. 'ds': load data from .ds file. See comments of function loadFromDS for a example.
  17. 'cxl': load data from Graph eXchange Language file (.cxl file). See
  18. `here <http://www.gupro.de/GXL/Introduction/background.html>`__ for detail.
  19. 'sdf': load data from structured data file (.sdf file). See
  20. `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
  21. for details.
  22. 'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See
  23. README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
  24. for details.
  25. 'txt': Load graph data from the TUDataset. See
  26. `here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
  27. for details. Note here filename is the name of either .txt file in
  28. the dataset directory.
  29. """
  30. if isinstance(filename, str):
  31. extension = splitext(filename)[1][1:]
  32. else: # filename is a list of files.
  33. extension = splitext(filename[0])[1][1:]
  34. if extension == "ds":
  35. self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
  36. elif extension == "cxl":
  37. dir_dataset = kwargs.get('dirname_dataset', None)
  38. self._graphs, self._targets, self._label_names = self.load_from_xml(filename, dir_dataset)
  39. elif extension == 'xml':
  40. dir_dataset = kwargs.get('dirname_dataset', None)
  41. self._graphs, self._targets, self._label_names = self.load_from_xml(filename, dir_dataset)
  42. elif extension == "mat":
  43. order = kwargs.get('order')
  44. self._graphs, self._targets, self._label_names = self.load_mat(filename, order)
  45. elif extension == 'txt':
  46. if gformat is None:
  47. self._graphs, self._targets, self._label_names = self.load_tud(filename)
  48. elif gformat == 'cml':
  49. self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
  50. else:
  51. raise ValueError('The input file with the extension ".', extension, '" is not supported. The supported extensions includes: ".ds", ".cxl", ".xml", ".mat", ".txt".')
  52. def load_from_ds(self, filename, filename_targets):
  53. """Load data from .ds file.
  54. Possible graph formats include:
  55. '.ct': see function load_ct for detail.
  56. '.gxl': see dunction load_gxl for detail.
  57. Note these graph formats are checked automatically by the extensions of
  58. graph files.
  59. """
  60. if isinstance(filename, str):
  61. dirname_dataset = dirname(filename)
  62. with open(filename) as f:
  63. content = f.read().splitlines()
  64. else: # filename is a list of files.
  65. dirname_dataset = dirname(filename[0])
  66. content = []
  67. for fn in filename:
  68. with open(fn) as f:
  69. content += f.read().splitlines()
  70. # to remove duplicate file names.
  71. data = []
  72. y = []
  73. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  74. content = [line for line in content if not line.endswith('.ds')] # Alkane
  75. content = [line for line in content if not line.startswith('#')] # Acyclic
  76. extension = splitext(content[0].split(' ')[0])[1][1:]
  77. if extension == 'ct':
  78. load_file_fun = self.load_ct
  79. elif extension == 'gxl' or extension == 'sdf': # @todo: .sdf not tested yet.
  80. load_file_fun = self.load_gxl
  81. elif extension == 'cml': # dataset "Chiral"
  82. load_file_fun = self.load_cml
  83. if filename_targets is None or filename_targets == '':
  84. for i in range(0, len(content)):
  85. tmp = content[i].split(' ')
  86. # remove the '#'s in file names
  87. g, l_names = load_file_fun(dirname_dataset + '/' + tmp[0].replace('#', '', 1))
  88. data.append(g)
  89. self._append_label_names(label_names, l_names) # @todo: this is so redundant.
  90. y.append(float(tmp[1]))
  91. else: # targets in a seperate file
  92. for i in range(0, len(content)):
  93. tmp = content[i]
  94. # remove the '#'s in file names
  95. g, l_names = load_file_fun(dirname_dataset + '/' + tmp.replace('#', '', 1))
  96. data.append(g)
  97. self._append_label_names(label_names, l_names)
  98. with open(filename_targets) as fnt:
  99. content_y = fnt.read().splitlines()
  100. # assume entries in filename and filename_targets have the same order.
  101. for item in content_y:
  102. tmp = item.split(' ')
  103. # assume the 3rd entry in a line is y (for Alkane dataset)
  104. y.append(float(tmp[2]))
  105. return data, y, label_names
  106. def load_from_xml(self, filename, dir_dataset=None):
  107. import xml.etree.ElementTree as ET
  108. def load_one_file(filename, data, y, label_names):
  109. tree = ET.parse(filename)
  110. root = tree.getroot()
  111. for graph in root.iter('graph') if root.find('graph') is not None else root.iter('print'): # "graph" for ... I forgot; "print" for datasets GREC and Web.
  112. mol_filename = graph.attrib['file']
  113. mol_class = graph.attrib['class']
  114. g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename)
  115. data.append(g)
  116. self._append_label_names(label_names, l_names)
  117. y.append(mol_class)
  118. data = []
  119. y = []
  120. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  121. if isinstance(filename, str):
  122. if dir_dataset is not None:
  123. dir_dataset = dir_dataset
  124. else:
  125. dir_dataset = dirname(filename)
  126. load_one_file(filename, data, y, label_names)
  127. else: # filename is a list of files.
  128. if dir_dataset is not None:
  129. dir_dataset = dir_dataset
  130. else:
  131. dir_dataset = dirname(filename[0])
  132. for fn in filename:
  133. load_one_file(fn, data, y, label_names)
  134. return data, y, label_names
    def load_mat(self, filename, order):  # @todo: need to be updated (auto order) or deprecated.
        """Load graph data from a MATLAB (up to version 7.1) .mat file.

        Notes
        ------
        A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph.

        Check README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__ for detailed structure.
        """
        from scipy.io import loadmat
        import numpy as np
        import networkx as nx
        data = []
        content = loadmat(filename)
        # `order` gives positions of fields inside the MATLAB struct:
        # order[0] -> adjacency matrix, order[1] -> compression flag
        # (0 means edge labels are present), order[3] -> node labels,
        # order[4] -> edge labels.
        # NOTE(review): inferred from the indexing below — confirm against the
        # dataset README before relying on it.
        for key, value in content.items():
            if key[0] == 'l':  # class label
                y = np.transpose(value)[0].tolist()
            elif key[0] != '_':
                # Skip MATLAB metadata keys ('__header__', etc.); anything else
                # is treated as the struct array of graphs.
                # if adjacency matrix is not compressed / edge label exists
                if order[1] == 0:
                    for i, item in enumerate(value[0]):
                        g = nx.Graph(name=i)  # set name of the graph
                        nl = np.transpose(item[order[3]][0][0][0])  # node label
                        for index, label in enumerate(nl[0]):
                            g.add_node(index, label_1=str(label))
                        el = item[order[4]][0][0][0]  # edge label
                        for edge in el:
                            # MATLAB indices are 1-based; shift to 0-based.
                            g.add_edge(edge[0] - 1, edge[1] - 1, label_1=str(edge[2]))
                        data.append(g)
                else:
                    for i, item in enumerate(value[0]):
                        g = nx.Graph(name=i)  # set name of the graph
                        nl = np.transpose(item[order[3]][0][0][0])  # node label
                        for index, label in enumerate(nl[0]):
                            g.add_node(index, label_1=str(label))
                        sam = item[order[0]]  # sparse adjacency matrix
                        index_no0 = sam.nonzero()
                        for col, row in zip(index_no0[0], index_no0[1]):
                            g.add_edge(col, row)
                        data.append(g)
        # NOTE(review): `y` is only bound if a key starting with 'l' existed in
        # the MAT file; otherwise the return below raises NameError — confirm
        # the file format guarantees such a key.
        label_names = {'node_labels': ['label_1'], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
        if order[1] == 0:
            # Edge labels were read above only in the uncompressed case.
            label_names['edge_labels'].append('label_1')
        return data, y, label_names
    def load_tud(self, filename):
        """Load graph data from TUD dataset files.

        Notes
        ------
        The graph data is loaded from separate files.

        Check README in `downloadable file <http://tiny.cc/PK_MLJ_data>`__, 2018 for detailed structure.
        """
        import networkx as nx
        from os import listdir
        from os.path import dirname, basename

        def get_infos_from_readme(frm):  # @todo: add README (cuniform), maybe node/edge label maps.
            """Get information from DS_label_readme.txt file.
            """

            def get_label_names_from_line(line):
                """Get names of labels/attributes from a line.
                """
                # Names appear between square brackets, comma-separated.
                str_names = line.split('[')[1].split(']')[0]
                names = str_names.split(',')
                names = [attr.strip() for attr in names]
                return names

            def get_class_label_map(label_map_strings):
                # Each line maps an integer class id to a label, tab-separated.
                label_map = {}
                for string in label_map_strings:
                    integer, label = string.split('\t')
                    label_map[int(integer.strip())] = label.strip()
                return label_map

            label_names = {'node_labels': [], 'node_attrs': [],
                           'edge_labels': [], 'edge_attrs': []}
            class_label_map = None
            class_label_map_strings = []
            with open(frm) as rm:
                content_rm = rm.read().splitlines()
            i = 0
            while i < len(content_rm):
                line = content_rm[i].strip()
                # get node/edge labels and attributes.
                if line.startswith('Node labels:'):
                    label_names['node_labels'] = get_label_names_from_line(line)
                elif line.startswith('Node attributes:'):
                    label_names['node_attrs'] = get_label_names_from_line(line)
                elif line.startswith('Edge labels:'):
                    label_names['edge_labels'] = get_label_names_from_line(line)
                elif line.startswith('Edge attributes:'):
                    label_names['edge_attrs'] = get_label_names_from_line(line)
                # get class label map.
                elif line.startswith('Class labels were converted to integer values using this map:'):
                    # Skip the header line and the blank line that follows it.
                    i += 2
                    line = content_rm[i].strip()
                    # NOTE(review): content_rm[i] is indexed before the bound
                    # test, so a map running to the file's last line raises
                    # IndexError — confirm README files always end with a
                    # blank line.
                    while line != '' and i < len(content_rm):
                        class_label_map_strings.append(line)
                        i += 1
                        line = content_rm[i].strip()
                    class_label_map = get_class_label_map(class_label_map_strings)
                i += 1
            return label_names, class_label_map

        # get dataset name.
        dirname_dataset = dirname(filename)
        filename = basename(filename)
        fn_split = filename.split('_A')
        ds_name = fn_split[0].strip()

        # load data file names
        # Each matching file in the dataset directory binds a local variable;
        # presence is later tested via `'name' in locals()`.
        for name in listdir(dirname_dataset):
            if ds_name + '_A' in name:
                fam = dirname_dataset + '/' + name
            elif ds_name + '_graph_indicator' in name:
                fgi = dirname_dataset + '/' + name
            elif ds_name + '_graph_labels' in name:
                fgl = dirname_dataset + '/' + name
            elif ds_name + '_node_labels' in name:
                fnl = dirname_dataset + '/' + name
            elif ds_name + '_edge_labels' in name:
                fel = dirname_dataset + '/' + name
            elif ds_name + '_edge_attributes' in name:
                fea = dirname_dataset + '/' + name
            elif ds_name + '_node_attributes' in name:
                fna = dirname_dataset + '/' + name
            elif ds_name + '_graph_attributes' in name:
                fga = dirname_dataset + '/' + name
            elif ds_name + '_label_readme' in name:
                frm = dirname_dataset + '/' + name
            # this is supposed to be the node attrs, make sure to put this as the last 'elif'
            elif ds_name + '_attributes' in name:
                fna = dirname_dataset + '/' + name

        # get labels and attributes names.
        if 'frm' in locals():
            label_names, class_label_map = get_infos_from_readme(frm)
        else:
            label_names = {'node_labels': [], 'node_attrs': [],
                           'edge_labels': [], 'edge_attrs': []}
            class_label_map = None

        with open(fgi) as gi:
            content_gi = gi.read().splitlines()  # graph indicator
        with open(fam) as am:
            content_am = am.read().splitlines()  # adjacency matrix

        # load targets.
        # NOTE(review): float() for classification labels and int() for
        # regression targets look swapped (int would truncate regression
        # values) — confirm against the upstream dataset format.
        if 'fgl' in locals():
            with open(fgl) as gl:
                content_targets = gl.read().splitlines()  # targets (classification)
            targets = [float(i) for i in content_targets]
        elif 'fga' in locals():
            with open(fga) as ga:
                content_targets = ga.read().splitlines()  # targets (regression)
            targets = [int(i) for i in content_targets]
        else:
            # NOTE(review): the commas make exp_msg a tuple, so the raised
            # message prints oddly — consider joining into one string.
            exp_msg = 'Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.'
            raise Exception(exp_msg)
        if class_label_map is not None:
            targets = [class_label_map[t] for t in targets]

        # create graphs and add nodes
        data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))]
        if 'fnl' in locals():
            with open(fnl) as nl:
                content_nl = nl.read().splitlines()  # node labels
            for idx, line in enumerate(content_gi):
                # transfer to int first in case of unexpected blanks
                data[int(line) - 1].add_node(idx)
                labels = [l.strip() for l in content_nl[idx].split(',')]
                if label_names['node_labels'] == []:  # @todo: need fix bug.
                    # No names known from the README: invent label_0, label_1, ...
                    for i, label in enumerate(labels):
                        l_name = 'label_' + str(i)
                        data[int(line) - 1].nodes[idx][l_name] = label
                        label_names['node_labels'].append(l_name)
                else:
                    for i, l_name in enumerate(label_names['node_labels']):
                        data[int(line) - 1].nodes[idx][l_name] = labels[i]
        else:
            for i, line in enumerate(content_gi):
                data[int(line) - 1].add_node(i)

        # add edges
        for line in content_am:
            tmp = line.split(',')
            n1 = int(tmp[0]) - 1
            n2 = int(tmp[1]) - 1
            # ignore edge weight here.
            g = int(content_gi[n1]) - 1
            data[g].add_edge(n1, n2)

        # add edge labels
        if 'fel' in locals():
            with open(fel) as el:
                content_el = el.read().splitlines()
            for idx, line in enumerate(content_el):
                labels = [l.strip() for l in line.split(',')]
                n = [int(i) - 1 for i in content_am[idx].split(',')]
                g = int(content_gi[n[0]]) - 1
                if label_names['edge_labels'] == []:
                    for i, label in enumerate(labels):
                        l_name = 'label_' + str(i)
                        data[g].edges[n[0], n[1]][l_name] = label
                        label_names['edge_labels'].append(l_name)
                else:
                    for i, l_name in enumerate(label_names['edge_labels']):
                        data[g].edges[n[0], n[1]][l_name] = labels[i]

        # add node attributes
        if 'fna' in locals():
            with open(fna) as na:
                content_na = na.read().splitlines()
            for idx, line in enumerate(content_na):
                attrs = [a.strip() for a in line.split(',')]
                g = int(content_gi[idx]) - 1
                if label_names['node_attrs'] == []:
                    for i, attr in enumerate(attrs):
                        a_name = 'attr_' + str(i)
                        data[g].nodes[idx][a_name] = attr
                        label_names['node_attrs'].append(a_name)
                else:
                    for i, a_name in enumerate(label_names['node_attrs']):
                        data[g].nodes[idx][a_name] = attrs[i]

        # add edge attributes
        if 'fea' in locals():
            with open(fea) as ea:
                content_ea = ea.read().splitlines()
            for idx, line in enumerate(content_ea):
                attrs = [a.strip() for a in line.split(',')]
                n = [int(i) - 1 for i in content_am[idx].split(',')]
                g = int(content_gi[n[0]]) - 1
                if label_names['edge_attrs'] == []:
                    for i, attr in enumerate(attrs):
                        a_name = 'attr_' + str(i)
                        data[g].edges[n[0], n[1]][a_name] = attr
                        label_names['edge_attrs'].append(a_name)
                else:
                    for i, a_name in enumerate(label_names['edge_attrs']):
                        data[g].edges[n[0], n[1]][a_name] = attrs[i]
        return data, targets, label_names
    def load_ct(self, filename):  # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds are considered.)
        """load data from a Chemical Table (.ct) file.

        Notes
        ------
        a typical example of data in .ct is like this:

        3 2  <- number of nodes and edges

        0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)

        0.0000 0.0000 0.0000 C

        0.0000 0.0000 0.0000 O

        1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo

        2 3 1 1

        Check `CTFile Formats file <https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=10&ved=2ahUKEwivhaSdjsTlAhVhx4UKHczHA8gQFjAJegQIARAC&url=https%3A%2F%2Fwww.daylight.com%2Fmeetings%2Fmug05%2FKappler%2Fctfile.pdf&usg=AOvVaw1cDNrrmMClkFPqodlF2inS>`__
        for detailed format description.
        """
        import networkx as nx
        from os.path import basename
        g = nx.Graph()
        with open(filename) as f:
            content = f.read().splitlines()
            g = nx.Graph(name=str(content[0]), filename=basename(filename))  # set name of the graph
            # read the counts line.
            tmp = content[1].split(' ')
            tmp = [x for x in tmp if x != '']
            nb_atoms = int(tmp[0].strip())  # number of atoms
            nb_bonds = int(tmp[1].strip())  # number of bonds
            # Positional field names for the counts line; '' marks obsolete
            # fields that are skipped.
            count_line_tags = ['number_of_atoms', 'number_of_bonds', 'number_of_atom_lists', '', 'chiral_flag', 'number_of_stext_entries', '', '', '', '', 'number_of_properties', 'CT_version']
            i = 0
            while i < len(tmp):
                if count_line_tags[i] != '':  # if not obsoleted
                    g.graph[count_line_tags[i]] = tmp[i].strip()
                i += 1

            # read the atom block.
            atom_tags = ['x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', 'atom_stereo_parity', 'hydrogen_count_plus_1', 'stereo_care_box', 'valence', 'h0_designator', '', '', 'atom_atom_mapping_number', 'inversion_retention_flag', 'exact_change_flag']
            for i in range(0, nb_atoms):
                tmp = content[i + 2].split(' ')  # atom lines start after counts line.
                tmp = [x for x in tmp if x != '']
                g.add_node(i)
                j = 0
                while j < len(tmp):
                    if atom_tags[j] != '':
                        g.nodes[i][atom_tags[j]] = tmp[j].strip()
                    j += 1

            # read the bond block.
            bond_tags = ['first_atom_number', 'second_atom_number', 'bond_type', 'bond_stereo', '', 'bond_topology', 'reacting_center_status']
            for i in range(0, nb_bonds):
                tmp = content[i + g.number_of_nodes() + 2].split(' ')  # bonds follow the atom block.
                tmp = [x for x in tmp if x != '']
                # Atom numbers are 1-based in the file; shift to 0-based nodes.
                n1, n2 = int(tmp[0].strip()) - 1, int(tmp[1].strip()) - 1
                g.add_edge(n1, n2)
                j = 2
                while j < len(tmp):
                    if bond_tags[j] != '':
                        g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip()
                    j += 1

        # get label names.
        label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
        # 1 marks fields treated as symbolic labels, 0 as numeric attributes,
        # None for obsolete positions that never appear.
        atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1]
        for nd in g.nodes():
            for key in g.nodes[nd]:
                if atom_symbolic[atom_tags.index(key)] == 1:
                    label_names['node_labels'].append(key)
                else:
                    label_names['node_attrs'].append(key)
            break  # only the first node is inspected; labels assumed homogeneous.
        bond_symbolic = [None, None, 1, 1, None, 1, 1]
        for ed in g.edges():
            for key in g.edges[ed]:
                if bond_symbolic[bond_tags.index(key)] == 1:
                    label_names['edge_labels'].append(key)
                else:
                    label_names['edge_attrs'].append(key)
            break  # only the first edge is inspected.
        return g, label_names
    def load_gxl(self, filename):  # @todo: directed graphs.
        """Load a graph from a GXL (.gxl) file as an undirected NetworkX graph.

        Returns the graph and a dict of node/edge label and attribute names,
        decided by inspecting only the first node and first edge.
        """
        from os.path import basename
        import networkx as nx
        import xml.etree.ElementTree as ET
        tree = ET.parse(filename)
        root = tree.getroot()
        index = 0
        g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
        dic = {}  # used to retrieve incident nodes of edges
        for node in root.iter('node'):
            dic[node.attrib['id']] = index
            labels = {}
            for attr in node.iter('attr'):  # for datasets "GREC" and "Monoterpens".
                labels[attr.attrib['name']] = attr[0].text
            for attr in node.iter('attribute'):  # for dataset "Web".
                labels[attr.attrib['name']] = attr.attrib['value']
            g.add_node(index, **labels)
            index += 1
        for edge in root.iter('edge'):
            labels = {}
            for attr in edge.iter('attr'):  # for datasets "GREC" and "Monoterpens".
                labels[attr.attrib['name']] = attr[0].text
            for attr in edge.iter('attribute'):  # for dataset "Web".
                labels[attr.attrib['name']] = attr.attrib['value']
            g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)

        # get label names.
        label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
        # @todo: possible loss of label names if some nodes miss some labels.
        for node in root.iter('node'):
            for attr in node.iter('attr'):  # for datasets "GREC" and "Monoterpens".
                if attr[0].tag == 'int' or attr.attrib['name'] == 'type':  # @todo: this maybe wrong, and slow. "type" is for dataset GREC; "int" is for dataset "Monoterpens".
                    label_names['node_labels'].append(attr.attrib['name'])
                else:
                    label_names['node_attrs'].append(attr.attrib['name'])
            for attr in node.iter('attribute'):  # for dataset "Web".
                label_names['node_attrs'].append(attr.attrib['name'])
                # @todo: is id useful in dataset "Web"? is "FREQUENCY" symbolic or not?
            break  # only the first node is inspected; labels assumed homogeneous.
        for edge in root.iter('edge'):
            for attr in edge.iter('attr'):  # for datasets "GREC" and "Monoterpens".
                if attr[0].tag == 'int' or attr.attrib['name'] == 'frequency' or 'type' in attr.attrib['name']:  # @todo: this maybe wrong, and slow. "frequency" and "type" are for dataset GREC; "int" is for dataset "Monoterpens".
                    label_names['edge_labels'].append(attr.attrib['name'])
                else:
                    label_names['edge_attrs'].append(attr.attrib['name'])
            for attr in edge.iter('attribute'):  # for dataset "Web".
                label_names['edge_attrs'].append(attr.attrib['name'])
            break  # only the first edge is inspected.
        return g, label_names
  482. def load_cml(self, filename): # @todo: directed graphs.
  483. # @todo: what is "atomParity" and "bondStereo" in the data file?
  484. from os.path import basename
  485. import networkx as nx
  486. import xml.etree.ElementTree as ET
  487. xmlns = '{http://www.xml-cml.org/schema}' # @todo: why this has to be added?
  488. tree = ET.parse(filename)
  489. root = tree.getroot()
  490. index = 0
  491. if root.tag == xmlns + 'molecule':
  492. g_id = root.attrib['id']
  493. else:
  494. g_id = root.find(xmlns + 'molecule').attrib['id']
  495. g = nx.Graph(filename=basename(filename), name=g_id)
  496. dic = {} # used to retrieve incident nodes of edges
  497. for atom in root.iter(xmlns + 'atom'):
  498. dic[atom.attrib['id']] = index
  499. labels = {}
  500. for key, val in atom.attrib.items():
  501. if key != 'id':
  502. labels[key] = val
  503. g.add_node(index, **labels)
  504. index += 1
  505. for bond in root.iter(xmlns + 'bond'):
  506. labels = {}
  507. for key, val in bond.attrib.items():
  508. if key != 'atomRefs2' and key != 'id': # "id" is in dataset "ACE".
  509. labels[key] = val
  510. n1, n2 = bond.attrib['atomRefs2'].strip().split(' ')
  511. g.add_edge(dic[n1], dic[n2], **labels)
  512. # get label names.
  513. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
  514. # @todo: possible loss of label names if some nodes miss some labels.
  515. for key, val in g.nodes[0].items():
  516. try:
  517. float(val)
  518. except:
  519. label_names['node_labels'].append(key)
  520. else:
  521. if val.isdigit():
  522. label_names['node_labels'].append(key)
  523. else:
  524. label_names['node_attrs'].append(key)
  525. for _, _, attrs in g.edges(data=True):
  526. for key, val in attrs.items():
  527. try:
  528. float(val)
  529. except:
  530. label_names['edge_labels'].append(key)
  531. else:
  532. if val.isdigit():
  533. label_names['edge_labels'].append(key)
  534. else:
  535. label_names['edge_attrs'].append(key)
  536. break
  537. return g, label_names
  538. def _append_label_names(self, label_names, new_names):
  539. for key, val in label_names.items():
  540. label_names[key] += [name for name in new_names[key] if name not in val]
    @property
    def data(self):
        """Return the loaded (graphs, targets, label_names) triple."""
        return self._graphs, self._targets, self._label_names

    @property
    def graphs(self):
        """Return the list of loaded NetworkX graphs."""
        return self._graphs

    @property
    def targets(self):
        """Return the targets loaded with the dataset, one per graph."""
        return self._targets

    @property
    def label_names(self):
        """Return the dict of node/edge label and attribute names."""
        return self._label_names
class DataSaver():

    def __init__(self, graphs, targets=None, filename='gfile', gformat='gxl', group=None, **kwargs):
        """Save list of graphs.

        Writes one graph file per graph (via save_gxl) and, when
        group == 'xml' and gformat == 'gxl', an XML group file listing each
        graph file with its target class.

        NOTE(review): for any other (group, gformat) combination this body
        writes nothing at all — confirm that is intended. When the group file
        is written, `targets` must be indexable (targets[idx]); the default
        None would raise TypeError there.
        """
        import os
        dirname_ds = os.path.dirname(filename)
        if dirname_ds != '':
            dirname_ds += '/'
            os.makedirs(dirname_ds, exist_ok=True)

        # An explicit 'graph_dir' kwarg overrides where individual graph
        # files go; it is removed so it is not forwarded to save_gxl.
        if 'graph_dir' in kwargs:
            graph_dir = kwargs['graph_dir'] + '/'
            os.makedirs(graph_dir, exist_ok=True)
            del kwargs['graph_dir']
        else:
            graph_dir = dirname_ds

        if group == 'xml' and gformat == 'gxl':
            with open(filename + '.xml', 'w') as fgroup:
                fgroup.write("<?xml version=\"1.0\"?>")
                fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
                fgroup.write("\n<GraphCollection>")
                for idx, g in enumerate(graphs):
                    fname_tmp = "graph" + str(idx) + ".gxl"
                    self.save_gxl(g, graph_dir + fname_tmp, **kwargs)
                    fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(targets[idx]) + "\"/>")
                fgroup.write("\n</GraphCollection>")
                # NOTE(review): redundant — the `with` block already closes
                # the file on exit.
                fgroup.close()
  579. def save_gxl(self, graph, filename, method='default', node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
  580. if method == 'default':
  581. gxl_file = open(filename, 'w')
  582. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  583. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  584. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  585. if 'name' in graph.graph:
  586. name = str(graph.graph['name'])
  587. else:
  588. name = 'dummy'
  589. gxl_file.write("<graph id=\"" + name + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  590. for v, attrs in graph.nodes(data=True):
  591. gxl_file.write("<node id=\"_" + str(v) + "\">")
  592. for l_name in node_labels:
  593. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  594. str(attrs[l_name]) + "</int></attr>")
  595. for a_name in node_attrs:
  596. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  597. str(attrs[a_name]) + "</float></attr>")
  598. gxl_file.write("</node>\n")
  599. for v1, v2, attrs in graph.edges(data=True):
  600. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  601. for l_name in edge_labels:
  602. gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
  603. str(attrs[l_name]) + "</int></attr>")
  604. for a_name in edge_attrs:
  605. gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
  606. str(attrs[a_name]) + "</float></attr>")
  607. gxl_file.write("</edge>\n")
  608. gxl_file.write("</graph>\n")
  609. gxl_file.write("</gxl>")
  610. gxl_file.close()
  611. elif method == 'benoit':
  612. import xml.etree.ElementTree as ET
  613. root_node = ET.Element('gxl')
  614. attr = dict()
  615. attr['id'] = str(graph.graph['name'])
  616. attr['edgeids'] = 'true'
  617. attr['edgemode'] = 'undirected'
  618. graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
  619. for v in graph:
  620. current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
  621. for attr in graph.nodes[v].keys():
  622. cur_attr = ET.SubElement(
  623. current_node, 'attr', attrib={'name': attr})
  624. cur_value = ET.SubElement(cur_attr,
  625. graph.nodes[v][attr].__class__.__name__)
  626. cur_value.text = graph.nodes[v][attr]
  627. for v1 in graph:
  628. for v2 in graph[v1]:
  629. if (v1 < v2): # Non oriented graphs
  630. cur_edge = ET.SubElement(
  631. graph_node,
  632. 'edge',
  633. attrib={
  634. 'from': str(v1),
  635. 'to': str(v2)
  636. })
  637. for attr in graph[v1][v2].keys():
  638. cur_attr = ET.SubElement(
  639. cur_edge, 'attr', attrib={'name': attr})
  640. cur_value = ET.SubElement(
  641. cur_attr, graph[v1][v2][attr].__class__.__name__)
  642. cur_value.text = str(graph[v1][v2][attr])
  643. tree = ET.ElementTree(root_node)
  644. tree.write(filename)
  645. elif method == 'gedlib':
  646. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  647. # pass
  648. gxl_file = open(filename, 'w')
  649. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  650. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  651. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  652. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
  653. for v, attrs in graph.nodes(data=True):
  654. gxl_file.write("<node id=\"_" + str(v) + "\">")
  655. gxl_file.write("<attr name=\"" + "chem" + "\"><int>" + str(attrs['chem']) + "</int></attr>")
  656. gxl_file.write("</node>\n")
  657. for v1, v2, attrs in graph.edges(data=True):
  658. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
  659. gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
  660. # gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>")
  661. gxl_file.write("</edge>\n")
  662. gxl_file.write("</graph>\n")
  663. gxl_file.write("</gxl>")
  664. gxl_file.close()
  665. elif method == 'gedlib-letter':
  666. # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
  667. # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
  668. gxl_file = open(filename, 'w')
  669. gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
  670. gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
  671. gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
  672. gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
  673. for v, attrs in graph.nodes(data=True):
  674. gxl_file.write("<node id=\"_" + str(v) + "\">")
  675. gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
  676. gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
  677. gxl_file.write("</node>\n")
  678. for v1, v2, attrs in graph.edges(data=True):
  679. gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
  680. gxl_file.write("</graph>\n")
  681. gxl_file.write("</gxl>")
  682. gxl_file.close()
  683. # def loadSDF(filename):
  684. # """load data from structured data file (.sdf file).
  685. # Notes
  686. # ------
  687. # A SDF file contains a group of molecules, represented in the similar way as in MOL format.
  688. # Check `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__ for detailed structure.
  689. # """
  690. # import networkx as nx
  691. # from os.path import basename
  692. # from tqdm import tqdm
  693. # import sys
  694. # data = []
  695. # with open(filename) as f:
  696. # content = f.read().splitlines()
  697. # index = 0
  698. # pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
  699. # while index < len(content):
  700. # index_old = index
  701. # g = nx.Graph(name=content[index].strip()) # set name of the graph
  702. # tmp = content[index + 3]
  703. # nb_nodes = int(tmp[:3]) # number of the nodes
  704. # nb_edges = int(tmp[3:6]) # number of the edges
  705. # for i in range(0, nb_nodes):
  706. # tmp = content[i + index + 4]
  707. # g.add_node(i, atom=tmp[31:34].strip())
  708. # for i in range(0, nb_edges):
  709. # tmp = content[i + index + g.number_of_nodes() + 4]
  710. # tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
  711. # g.add_edge(
  712. # int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
  713. # data.append(g)
  714. # index += 4 + g.number_of_nodes() + g.number_of_edges()
# while content[index].strip() != '$$$$': # separator
  716. # index += 1
  717. # index += 1
  718. # pbar.update(index - index_old)
  719. # pbar.update(1)
  720. # pbar.close()
  721. # return data
  722. # def load_from_cxl(filename):
  723. # import xml.etree.ElementTree as ET
  724. #
  725. # dirname_dataset = dirname(filename)
  726. # tree = ET.parse(filename)
  727. # root = tree.getroot()
  728. # data = []
  729. # y = []
  730. # for graph in root.iter('graph'):
  731. # mol_filename = graph.attrib['file']
  732. # mol_class = graph.attrib['class']
  733. # data.append(load_gxl(dirname_dataset + '/' + mol_filename))
  734. # y.append(mol_class)
  735. if __name__ == '__main__':
  736. # ### Load dataset from .ds file.
  737. # # .ct files.
  738. # ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds',
  739. # 'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'}
  740. # Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y'])
  741. # ds_file = '../../datasets/Acyclic/dataset_bps.ds' # node symb
  742. # Gn, targets, label_names = load_dataset(ds_file)
  743. # ds_file = '../../datasets/MAO/dataset.ds' # node/edge symb
  744. # Gn, targets, label_names = load_dataset(ds_file)
  745. ## ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'} # unlabeled
  746. ## Gn, y = loadDataset(ds['dataset'])
  747. # print(Gn[1].graph)
  748. # print(Gn[1].nodes(data=True))
  749. # print(Gn[1].edges(data=True))
  750. # print(targets[1])
  751. # # .gxl file.
  752. # ds_file = '../../datasets/monoterpenoides/dataset_10+.ds' # node/edge symb
  753. # Gn, y, label_names = load_dataset(ds_file)
  754. # print(Gn[1].graph)
  755. # print(Gn[1].nodes(data=True))
  756. # print(Gn[1].edges(data=True))
  757. # print(y[1])
  758. # .mat file.
  759. ds_file = '../../datasets/MUTAG_mat/MUTAG.mat'
  760. order = [0, 0, 3, 1, 2]
  761. gloader = DataLoader(ds_file, order=order)
  762. Gn, targets, label_names = gloader.data
  763. print(Gn[1].graph)
  764. print(Gn[1].nodes(data=True))
  765. print(Gn[1].edges(data=True))
  766. print(targets[1])
  767. # ### Convert graph from one format to another.
  768. # # .gxl file.
  769. # import networkx as nx
  770. # ds = {'name': 'monoterpenoides',
  771. # 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
  772. # Gn, y = loadDataset(ds['dataset'])
  773. # y = [int(i) for i in y]
  774. # print(Gn[1].nodes(data=True))
  775. # print(Gn[1].edges(data=True))
  776. # print(y[1])
  777. # # Convert a graph to the proper NetworkX format that can be recognized by library gedlib.
  778. # Gn_new = []
  779. # for G in Gn:
  780. # G_new = nx.Graph()
  781. # for nd, attrs in G.nodes(data=True):
  782. # G_new.add_node(str(nd), chem=attrs['atom'])
  783. # for nd1, nd2, attrs in G.edges(data=True):
  784. # G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  785. ## G_new.add_edge(str(nd1), str(nd2))
  786. # Gn_new.append(G_new)
  787. # print(Gn_new[1].nodes(data=True))
  788. # print(Gn_new[1].edges(data=True))
  789. # print(Gn_new[1])
  790. # filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides'
  791. # xparams = {'method': 'gedlib'}
  792. # saveDataset(Gn, y, gformat='gxl', group='xml', filename=filename, xparams=xparams)
  793. # save dataset.
  794. # ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
  795. # 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
  796. # Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  797. # saveDataset(Gn, y, group='xml', filename='temp/temp')
  798. # test - new way to add labels and attributes.
  799. # dataset = '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
  800. # filename = '../../datasets/Fingerprint/Fingerprint_A.txt'
  801. # dataset = '../../datasets/Letter-med/Letter-med_A.txt'
  802. # dataset = '../../datasets/AIDS/AIDS_A.txt'
  803. # dataset = '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
  804. # Gn, targets, label_names = load_dataset(filename)
  805. pass

A Python package for graph kernels, graph edit distances and the graph pre-image problem.