
graphfiles.py 25 kB

  1. """ Utilities function to manage graph files
  2. """
  3. from os.path import dirname, splitext
def loadCT(filename):
    """Load data from a Chemical Table (.ct) file.

    Notes
    -----
    A typical example of data in a .ct file looks like this:

    3 2                     <- number of nodes and edges
    0.0000 0.0000 0.0000 C  <- each line describes a node (x, y, z + label)
    0.0000 0.0000 0.0000 C
    0.0000 0.0000 0.0000 O
    1 3 1 1                 <- each line describes an edge: to, from, bond type, bond stereo
    2 3 1 1

    Check https://www.daylight.com/meetings/mug05/Kappler/ctfile.pdf
    for a detailed format description.
    """
    import networkx as nx
    from os.path import basename
    g = nx.Graph()
    with open(filename) as f:
        content = f.read().splitlines()
        g = nx.Graph(
            name=str(content[0]),
            filename=basename(filename))  # set name of the graph
        tmp = content[1].split(" ")
        if tmp[0] == '':
            nb_nodes = int(tmp[1])  # number of the nodes
            nb_edges = int(tmp[2])  # number of the edges
        else:
            nb_nodes = int(tmp[0])
            nb_edges = int(tmp[1])
        # patch for compatibility: label will be removed later
        for i in range(0, nb_nodes):
            tmp = content[i + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_node(i, atom=tmp[3].strip(),
                       label=[item.strip() for item in tmp[3:]],
                       attributes=[item.strip() for item in tmp[0:3]])
        for i in range(0, nb_edges):
            tmp = content[i + g.number_of_nodes() + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
                       bond_type=tmp[2].strip(),
                       label=[item.strip() for item in tmp[2:]])
    return g
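
# A minimal usage sketch for loadCT (illustrative only): it writes the tiny .ct
# sample from the docstring above to a temporary file and loads it back. The helper
# name _example_loadCT and the temporary path are not part of the package API.
def _example_loadCT():
    import os
    import tempfile
    ct_text = ("example molecule\n"
               "3 2\n"
               "0.0000 0.0000 0.0000 C\n"
               "0.0000 0.0000 0.0000 C\n"
               "0.0000 0.0000 0.0000 O\n"
               "1 3 1 1\n"
               "2 3 1 1\n")
    fd, path = tempfile.mkstemp(suffix='.ct')
    with os.fdopen(fd, 'w') as f:
        f.write(ct_text)
    g = loadCT(path)
    print(g.nodes(data=True))  # atoms C, C, O with coordinates in 'attributes'
    print(g.edges(data=True))  # two edges with bond_type '1'
    os.remove(path)
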
def loadGXL(filename):
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET

    tree = ET.parse(filename)
    root = tree.getroot()
    index = 0
    g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
    dic = {}  # used to retrieve incident nodes of edges
    for node in root.iter('node'):
        dic[node.attrib['id']] = index
        labels = {}
        for attr in node.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'chem' in labels:
            labels['label'] = labels['chem']
            labels['atom'] = labels['chem']
        g.add_node(index, **labels)
        index += 1

    for edge in root.iter('edge'):
        labels = {}
        for attr in edge.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'valence' in labels:
            labels['label'] = labels['valence']
            labels['bond_type'] = labels['valence']
        g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
    return g
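
# A minimal usage sketch for loadGXL (illustrative only): a tiny GXL document is
# written to a temporary file and parsed. The attribute names 'chem' and 'valence'
# are the ones the function maps to 'atom'/'bond_type'; everything else is made up.
def _example_loadGXL():
    import os
    import tempfile
    gxl_text = ('<gxl><graph id="example" edgeids="true" edgemode="undirected">'
                '<node id="_0"><attr name="chem"><int>6</int></attr></node>'
                '<node id="_1"><attr name="chem"><int>8</int></attr></node>'
                '<edge from="_0" to="_1"><attr name="valence"><int>1</int></attr></edge>'
                '</graph></gxl>')
    fd, path = tempfile.mkstemp(suffix='.gxl')
    with os.fdopen(fd, 'w') as f:
        f.write(gxl_text)
    g = loadGXL(path)
    print(g.nodes(data=True))  # {'chem': '6', 'label': '6', 'atom': '6'}, ...
    print(g.edges(data=True))  # {'valence': '1', 'label': '1', 'bond_type': '1'}
    os.remove(path)
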
def saveGXL(graph, filename, method='benoit'):
    if method == 'benoit':
        import xml.etree.ElementTree as ET
        root_node = ET.Element('gxl')
        attr = dict()
        attr['id'] = str(graph.graph['name'])
        attr['edgeids'] = 'true'
        attr['edgemode'] = 'undirected'
        graph_node = ET.SubElement(root_node, 'graph', attrib=attr)

        for v in graph:
            current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
            for attr in graph.nodes[v].keys():
                cur_attr = ET.SubElement(
                    current_node, 'attr', attrib={'name': attr})
                cur_value = ET.SubElement(cur_attr,
                                          graph.nodes[v][attr].__class__.__name__)
                cur_value.text = graph.nodes[v][attr]

        for v1 in graph:
            for v2 in graph[v1]:
                if (v1 < v2):  # non-oriented graphs
                    cur_edge = ET.SubElement(
                        graph_node,
                        'edge',
                        attrib={
                            'from': str(v1),
                            'to': str(v2)
                        })
                    for attr in graph[v1][v2].keys():
                        cur_attr = ET.SubElement(
                            cur_edge, 'attr', attrib={'name': attr})
                        cur_value = ET.SubElement(
                            cur_attr, graph[v1][v2][attr].__class__.__name__)
                        cur_value.text = str(graph[v1][v2][attr])

        tree = ET.ElementTree(root_node)
        tree.write(filename)
    elif method == 'gedlib':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            gxl_file.write("<attr name=\"chem\"><int>" + str(attrs['chem']) + "</int></attr>")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
            gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
            gxl_file.write("</edge>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>")
        gxl_file.close()
    elif method == 'gedlib-letter':
        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
        # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
        gxl_file = open(filename, 'w')
        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
        for v, attrs in graph.nodes(data=True):
            gxl_file.write("<node id=\"_" + str(v) + "\">")
            gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
            gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
            gxl_file.write("</node>\n")
        for v1, v2, attrs in graph.edges(data=True):
            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
        gxl_file.write("</graph>\n")
        gxl_file.write("</gxl>")
        gxl_file.close()
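
# A minimal round-trip sketch for saveGXL (illustrative only): the 'gedlib' writer
# above expects graph.graph['name'], a node attribute 'chem' and an edge attribute
# 'valence', which is the format produced by the conversion in the __main__ block below.
def _example_saveGXL():
    import os
    import tempfile
    import networkx as nx
    g = nx.Graph(name='example')
    g.add_node('0', chem=6)
    g.add_node('1', chem=8)
    g.add_edge('0', '1', valence=1)
    path = os.path.join(tempfile.mkdtemp(), 'example.gxl')
    saveGXL(g, path, method='gedlib')
    print(open(path).read())               # the generated GXL document
    print(loadGXL(path).edges(data=True))  # parsed back by loadGXL
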
def loadSDF(filename):
    """Load data from a structured data file (.sdf file).

    Notes
    -----
    An SDF file contains a group of molecules, each represented in a way similar
    to the MOL format.
    Check http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx,
    2018 for detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys
    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip())  # set name of the graph
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of the nodes
            nb_edges = int(tmp[3:6])  # number of the edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # separator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
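
# A minimal usage sketch for loadSDF (illustrative only): one MOL V2000 block in the
# fixed-column layout the parser above relies on (atom symbol in columns 32-34, bond
# fields in 3-character groups), terminated by the '$$$$' record separator.
def _example_loadSDF():
    import os
    import tempfile
    sdf_text = ("example molecule\n"
                "  illustrative header\n"
                "\n"
                "  3  2  0  0  0  0  0  0  0  0999 V2000\n"
                "    0.0000    0.0000    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n"
                "    0.0000    0.0000    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n"
                "    0.0000    0.0000    0.0000 O   0  0  0  0  0  0  0  0  0  0  0  0\n"
                "  1  2  1  0\n"
                "  2  3  1  0\n"
                "M  END\n"
                "$$$$\n")
    fd, path = tempfile.mkstemp(suffix='.sdf')
    with os.fdopen(fd, 'w') as f:
        f.write(sdf_text)
    graphs = loadSDF(path)
    print(graphs[0].nodes(data=True))  # atoms C, C, O
    print(graphs[0].edges(data=True))  # two single bonds
    os.remove(path)
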
def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Notes
    -----
    A MAT file contains a struct array containing graphs, and a column vector lx
    containing a class label for each graph.
    Check README in downloadable file in http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/,
    2018 for detailed structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx
    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l':  # class labels
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':
            # if adjacency matrix is not compressed / edge label exists
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge labels
                    for edge in el:
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                from scipy.sparse import csc_matrix
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    index_no0 = sam.nonzero()
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    return data, y
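
# Hypothetical usage (the path and order vector are taken from the commented MUTAG
# example in the __main__ block below; 'am_sp_al_nl_el' presumably lists the positions
# of the adjacency matrix, sparse flag, adjacency list, node labels and edge labels
# inside each struct entry, which is how the indexing above uses it):
#   Gn, y = loadMAT('../../datasets/MUTAG/MUTAG.mat',
#                   {'am_sp_al_nl_el': [0, 0, 3, 1, 2]})
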
def loadTXT(dirname_dataset):
    """Load graph data from a .txt file.

    Notes
    -----
    The graph data is loaded from separate files.
    Check README in downloadable file http://tiny.cc/PK_MLJ_data, 2018 for detailed structure.
    """
    import numpy as np
    import networkx as nx
    from os import listdir
    from os.path import dirname
    # load data file names
    for name in listdir(dirname_dataset):
        if '_A' in name:
            fam = dirname_dataset + '/' + name
        elif '_graph_indicator' in name:
            fgi = dirname_dataset + '/' + name
        elif '_graph_labels' in name:
            fgl = dirname_dataset + '/' + name
        elif '_node_labels' in name:
            fnl = dirname_dataset + '/' + name
        elif '_edge_labels' in name:
            fel = dirname_dataset + '/' + name
        elif '_edge_attributes' in name:
            fea = dirname_dataset + '/' + name
        elif '_node_attributes' in name:
            fna = dirname_dataset + '/' + name
        elif '_graph_attributes' in name:
            fga = dirname_dataset + '/' + name
        # this is supposed to be the node attrs, make sure to put this as the last 'elif'
        elif '_attributes' in name:
            fna = dirname_dataset + '/' + name

    content_gi = open(fgi).read().splitlines()  # graph indicator
    content_am = open(fam).read().splitlines()  # adjacency matrix
    content_gl = open(fgl).read().splitlines()  # class labels

    # create graphs and add nodes
    data = [nx.Graph(name=i) for i in range(0, len(content_gl))]
    if 'fnl' in locals():
        content_nl = open(fnl).read().splitlines()  # node labels
        for i, line in enumerate(content_gi):
            # transfer to int first in case of unexpected blanks
            data[int(line) - 1].add_node(i, atom=str(int(content_nl[i])))
    else:
        for i, line in enumerate(content_gi):
            data[int(line) - 1].add_node(i)

    # add edges
    for line in content_am:
        tmp = line.split(',')
        n1 = int(tmp[0]) - 1
        n2 = int(tmp[1]) - 1
        # ignore edge weight here.
        g = int(content_gi[n1]) - 1
        data[g].add_edge(n1, n2)

    # add edge labels
    if 'fel' in locals():
        content_el = open(fel).read().splitlines()
        for index, line in enumerate(content_el):
            label = line.strip()
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['bond_type'] = label

    # add node attributes
    if 'fna' in locals():
        content_na = open(fna).read().splitlines()
        for i, line in enumerate(content_na):
            attrs = [i.strip() for i in line.split(',')]
            g = int(content_gi[i]) - 1
            data[g].nodes[i]['attributes'] = attrs

    # add edge attributes
    if 'fea' in locals():
        content_ea = open(fea).read().splitlines()
        for index, line in enumerate(content_ea):
            attrs = [i.strip() for i in line.split(',')]
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['attributes'] = attrs

    # load y
    y = [int(i) for i in content_gl]
    return data, y
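
# A minimal usage sketch for loadTXT (illustrative only): a dataset in the
# TU Dortmund text format is written to a temporary directory (file names and
# contents are made up): two graphs, five labelled nodes, three edges.
def _example_loadTXT():
    import os
    import tempfile
    d = tempfile.mkdtemp()
    files = {'DS_A.txt': "1, 2\n2, 3\n4, 5\n",
             'DS_graph_indicator.txt': "1\n1\n1\n2\n2\n",
             'DS_graph_labels.txt': "1\n-1\n",
             'DS_node_labels.txt': "0\n1\n0\n2\n2\n"}
    for name, text in files.items():
        with open(os.path.join(d, name), 'w') as f:
            f.write(text)
    data, y = loadTXT(d)
    print(len(data), y)              # 2 graphs with class labels [1, -1]
    print(data[0].nodes(data=True))  # nodes 0-2 with 'atom' labels
    print(data[1].edges())           # the edge between nodes 3 and 4
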
def loadDataset(filename, filename_y=None, extra_params=None):
    """Read graph data from filename and load them as NetworkX graphs.

    Parameters
    ----------
    filename : string
        The name of the file from where the dataset is read.
    filename_y : string
        The name of the file of the targets corresponding to graphs.
    extra_params : dict
        Extra parameters only designated to the '.mat' format.

    Returns
    -------
    data : List of NetworkX graphs.
    y : List
        Targets corresponding to graphs.

    Notes
    -----
    This function supports the following graph dataset formats:
    'ds': load data from a .ds file. See the comments of function loadFromDS for
        an example.
    'cxl': load data from a Graph eXchange Language file (.cxl file). See
        http://www.gupro.de/GXL/Introduction/background.html, 2019 for details.
    'sdf': load data from a structured data file (.sdf file). See
        http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx,
        2018 for details.
    'mat': load graph data from a MATLAB (up to version 7.1) .mat file. See
        README in downloadable file in http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/,
        2018 for details.
    'txt': load graph data from a special .txt file. See
        https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets,
        2019 for details. Note here filename is the name of any .txt file in
        the dataset directory.
    """
    extension = splitext(filename)[1][1:]
    if extension == "ds":
        data, y = loadFromDS(filename, filename_y)
    elif extension == "cxl":
        import xml.etree.ElementTree as ET
        dirname_dataset = dirname(filename)
        tree = ET.parse(filename)
        root = tree.getroot()
        data = []
        y = []
        for graph in root.iter('graph'):
            mol_filename = graph.attrib['file']
            mol_class = graph.attrib['class']
            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
            y.append(mol_class)
    elif extension == 'xml':
        data, y = loadFromXML(filename, extra_params)
    elif extension == "sdf":
        import numpy as np
        from tqdm import tqdm
        import sys
        data = loadSDF(filename)
        y_raw = open(filename_y).read().splitlines()
        y_raw.pop(0)
        tmp0 = []
        tmp1 = []
        for i in range(0, len(y_raw)):
            tmp = y_raw[i].split(',')
            tmp0.append(tmp[0])
            tmp1.append(tmp[1].strip())
        y = []
        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
            try:
                y.append(tmp1[tmp0.index(data[i].name)].strip())
            except ValueError:  # if data[i].name not in tmp0
                data[i] = []
        data = list(filter(lambda a: a != [], data))
    elif extension == "mat":
        data, y = loadMAT(filename, extra_params)
    elif extension == 'txt':
        dirname_dataset = dirname(filename)
        data, y = loadTXT(dirname_dataset)
    return data, y
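
# Hypothetical calls (paths are illustrative); the loader is chosen purely from the
# file extension, as in the dispatch above:
#   Gn, y = loadDataset('../../datasets/MAO/dataset.ds')     # .ds  -> loadFromDS
#   Gn, y = loadDataset('../../datasets/MUTAG/MUTAG_A.txt')  # .txt -> loadTXT on its directory
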
def loadFromXML(filename, extra_params):
    import xml.etree.ElementTree as ET
    if extra_params:
        dirname_dataset = extra_params
    else:
        dirname_dataset = dirname(filename)
    tree = ET.parse(filename)
    root = tree.getroot()
    data = []
    y = []
    for graph in root.iter('graph'):
        mol_filename = graph.attrib['file']
        mol_class = graph.attrib['class']
        data.append(loadGXL(dirname_dataset + '/' + mol_filename))
        y.append(mol_class)
    return data, y
def loadFromDS(filename, filename_y):
    """Load data from a .ds file.

    Possible graph formats include:
    '.ct': see function loadCT for details.
    '.gxl': see function loadGXL for details.

    Note these graph formats are detected automatically from the extensions of
    the graph files.
    """
    dirname_dataset = dirname(filename)
    data = []
    y = []
    content = open(filename).read().splitlines()
    extension = splitext(content[0].split(' ')[0])[1][1:]
    if filename_y is None or filename_y == '':
        if extension == 'ct':
            for i in range(0, len(content)):
                tmp = content[i].split(' ')
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
                y.append(float(tmp[1]))
        elif extension == 'gxl':
            for i in range(0, len(content)):
                tmp = content[i].split(' ')
                # remove the '#'s in file names
                data.append(
                    loadGXL(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
                y.append(float(tmp[1]))
    else:  # y is in a separate file
        if extension == 'ct':
            for i in range(0, len(content)):
                tmp = content[i]
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
        elif extension == 'gxl':
            for i in range(0, len(content)):
                tmp = content[i]
                # remove the '#'s in file names
                data.append(
                    loadGXL(dirname_dataset + '/' + tmp.replace('#', '', 1)))
        content_y = open(filename_y).read().splitlines()
        # assume entries in filename and filename_y have the same order.
        for item in content_y:
            tmp = item.split(' ')
            # assume the 3rd entry in a line is y (for Alkane dataset)
            y.append(float(tmp[2]))
    return data, y
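
# The loadDataset docstring points here for a .ds example. A .ds file is a plain text
# list with one graph file per line; when filename_y is not given, the target value
# follows the file name (the contents below are illustrative):
#   #molecule_001.ct 157.5
#   #molecule_002.ct 174.0
# The first '#' in each file name is stripped, and paths are resolved relative to
# the directory containing the .ds file.
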
def saveDataset(Gn, y, gformat='gxl', group=None, filename='gfile', xparams=None):
    """Save a list of graphs.
    """
    import os
    dirname_ds = os.path.dirname(filename)
    if dirname_ds != '':
        dirname_ds += '/'
        if not os.path.exists(dirname_ds):
            os.makedirs(dirname_ds)

    if xparams is not None and 'graph_dir' in xparams:
        graph_dir = xparams['graph_dir'] + '/'
        if not os.path.exists(graph_dir):
            os.makedirs(graph_dir)
    else:
        graph_dir = dirname_ds

    if group == 'xml' and gformat == 'gxl':
        with open(filename + '.xml', 'w') as fgroup:
            fgroup.write("<?xml version=\"1.0\"?>")
            fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
            fgroup.write("\n<GraphCollection>")
            for idx, g in enumerate(Gn):
                fname_tmp = "graph" + str(idx) + ".gxl"
                saveGXL(g, graph_dir + fname_tmp, method=xparams['method'])
                fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
            fgroup.write("\n</GraphCollection>")
            fgroup.close()
if __name__ == '__main__':
#    ### Load dataset from .ds file.
#    # .ct files.
#    ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds',
#          'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'}
#    Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y'])
##    ds = {'name': 'Acyclic', 'dataset': '../../datasets/acyclic/dataset_bps.ds'}  # node symb
##    Gn, y = loadDataset(ds['dataset'])
##    ds = {'name': 'MAO', 'dataset': '../../datasets/MAO/dataset.ds'}  # node/edge symb
##    Gn, y = loadDataset(ds['dataset'])
##    ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'}  # unlabeled
##    Gn, y = loadDataset(ds['dataset'])
#    print(Gn[1].nodes(data=True))
#    print(Gn[1].edges(data=True))
#    print(y[1])

#    # .gxl file.
#    ds = {'name': 'monoterpenoides',
#          'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'}  # node/edge symb
#    Gn, y = loadDataset(ds['dataset'])
#    print(Gn[1].nodes(data=True))
#    print(Gn[1].edges(data=True))
#    print(y[1])

    ### Convert graphs from one format to another.
    # .gxl file.
    import networkx as nx
    ds = {'name': 'monoterpenoides',
          'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'}  # node/edge symb
    Gn, y = loadDataset(ds['dataset'])
    y = [int(i) for i in y]
    print(Gn[1].nodes(data=True))
    print(Gn[1].edges(data=True))
    print(y[1])

    # Convert each graph to the NetworkX format recognized by the library gedlib.
    Gn_new = []
    for G in Gn:
        G_new = nx.Graph()
        for nd, attrs in G.nodes(data=True):
            G_new.add_node(str(nd), chem=attrs['atom'])
        for nd1, nd2, attrs in G.edges(data=True):
            G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
#            G_new.add_edge(str(nd1), str(nd2))
        Gn_new.append(G_new)
    print(Gn_new[1].nodes(data=True))
    print(Gn_new[1].edges(data=True))
    print(Gn_new[1])

    filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides'
    xparams = {'method': 'gedlib'}
    saveDataset(Gn_new, y, gformat='gxl', group='xml', filename=filename, xparams=xparams)

#    ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
#          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
#    Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
#    saveDataset(Gn, y, group='xml', filename='temp/temp')

A Python package for graph kernels, graph edit distances and the graph pre-image problem.