
graphfiles.py 15 kB

  1. """ Utilities function to manage graph files
  2. """
  3. def loadCT(filename):
  4. """load data from .ct file.
  5. Notes
  6. ------
  7. a typical example of data in .ct is like this:
  8. 3 2 <- number of nodes and edges
  9. 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
  10. 0.0000 0.0000 0.0000 C
  11. 0.0000 0.0000 0.0000 O
  12. 1 3 1 1 <- each line describes an edge : to, from,?, label
  13. 2 3 1 1
  14. """
  15. import networkx as nx
  16. from os.path import basename
  17. g = nx.Graph()
  18. with open(filename) as f:
  19. content = f.read().splitlines()
  20. g = nx.Graph(
  21. name=str(content[0]),
  22. filename=basename(filename)) # set name of the graph
  23. tmp = content[1].split(" ")
  24. if tmp[0] == '':
  25. nb_nodes = int(tmp[1]) # number of the nodes
  26. nb_edges = int(tmp[2]) # number of the edges
  27. else:
  28. nb_nodes = int(tmp[0])
  29. nb_edges = int(tmp[1])
  30. # patch for compatibility : label will be removed later
  31. for i in range(0, nb_nodes):
  32. tmp = content[i + 2].split(" ")
  33. tmp = [x for x in tmp if x != '']
  34. g.add_node(i, atom=tmp[3], label=tmp[3])
  35. for i in range(0, nb_edges):
  36. tmp = content[i + g.number_of_nodes() + 2].split(" ")
  37. tmp = [x for x in tmp if x != '']
  38. g.add_edge(
  39. int(tmp[0]) - 1,
  40. int(tmp[1]) - 1,
  41. bond_type=tmp[3].strip(),
  42. label=tmp[3].strip())
  43. # for i in range(0, nb_edges):
  44. # tmp = content[i + g.number_of_nodes() + 2]
  45. # tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
  46. # g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  47. # bond_type=tmp[3].strip(), label=tmp[3].strip())
  48. return g
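
# Minimal usage sketch added alongside loadCT to show how the returned graph
# is typically inspected. It is not part of the original graphfiles.py, and
# 'data/molecule.ct' is a hypothetical path.
def _example_loadCT():
    """Load one .ct file and print its graph, node and edge attributes."""
    g = loadCT('data/molecule.ct')
    print(g.graph['name'], g.graph['filename'])
    print(g.nodes(data=True))  # node attributes: 'atom' and 'label'
    print(g.edges(data=True))  # edge attributes: 'bond_type' and 'label'
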
def loadGXL(filename):
    """Load a graph from a GXL file."""
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET

    tree = ET.parse(filename)
    root = tree.getroot()
    index = 0
    g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
    dic = {}  # used to retrieve the incident nodes of edges
    for node in root.iter('node'):
        dic[node.attrib['id']] = index
        labels = {}
        for attr in node.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'chem' in labels:
            labels['label'] = labels['chem']
        g.add_node(index, **labels)
        index += 1
    for edge in root.iter('edge'):
        labels = {}
        for attr in edge.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'valence' in labels:
            labels['label'] = labels['valence']
        g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
    return g
def saveGXL(graph, filename):
    """Save a networkx graph to a GXL file."""
    import xml.etree.ElementTree as ET

    root_node = ET.Element('gxl')
    attr = dict()
    attr['id'] = graph.graph['name']
    attr['edgeids'] = 'true'
    attr['edgemode'] = 'undirected'
    graph_node = ET.SubElement(root_node, 'graph', attrib=attr)

    for v in graph:
        current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
        for attr in graph.nodes[v].keys():
            cur_attr = ET.SubElement(
                current_node, 'attr', attrib={'name': attr})
            cur_value = ET.SubElement(cur_attr,
                                      graph.nodes[v][attr].__class__.__name__)
            cur_value.text = str(graph.nodes[v][attr])

    for v1 in graph:
        for v2 in graph[v1]:
            if v1 < v2:  # undirected graph: write each edge only once
                cur_edge = ET.SubElement(
                    graph_node,
                    'edge',
                    attrib={
                        'from': str(v1),
                        'to': str(v2)
                    })
                for attr in graph[v1][v2].keys():
                    cur_attr = ET.SubElement(
                        cur_edge, 'attr', attrib={'name': attr})
                    cur_value = ET.SubElement(
                        cur_attr, graph[v1][v2][attr].__class__.__name__)
                    cur_value.text = str(graph[v1][v2][attr])

    tree = ET.ElementTree(root_node)
    tree.write(filename)
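
# Minimal round-trip sketch for the GXL helpers above. It is not part of the
# original graphfiles.py; 'data/molecule.gxl' and 'data/molecule_copy.gxl' are
# hypothetical paths.
def _example_gxl_roundtrip():
    """Load a GXL graph, inspect it, then write it back out."""
    g = loadGXL('data/molecule.gxl')
    print(g.graph['name'], g.number_of_nodes(), g.number_of_edges())
    # saveGXL reads the 'name' graph attribute, which loadGXL sets
    # from the GXL graph id.
    saveGXL(g, 'data/molecule_copy.gxl')
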
def loadSDF(filename):
    """Load data from a structured data file (.sdf file).

    Notes
    ------
    An SDF file contains a group of molecules, each represented in a way
    similar to the MOL format. See
    http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx,
    2018, for the detailed structure.
    """
    import networkx as nx
    from tqdm import tqdm
    import sys

    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index

            g = nx.Graph(name=content[index].strip())  # set the name of the graph

            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of nodes
            nb_edges = int(tmp[3:6])  # number of edges

            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())

            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(
                    int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())

            data.append(g)

            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # molecule separator
                index += 1
            index += 1

            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
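
# Minimal usage sketch for loadSDF. It is not part of the original
# graphfiles.py, and 'data/molecules.sdf' is a hypothetical path.
def _example_loadSDF():
    """Load an SDF file holding several molecules and inspect the first one."""
    graphs = loadSDF('data/molecules.sdf')
    print(len(graphs), 'molecules loaded')
    print(graphs[0].name)              # molecule name from the first record
    print(graphs[0].nodes(data=True))  # node attribute: 'atom'
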
def loadMAT(filename, extra_params):
    """Load graph data from a MATLAB (up to version 7.1) .mat file.

    Notes
    ------
    A MAT file contains a struct array containing the graphs, and a column
    vector lx containing a class label for each graph. See the README in the
    archive downloadable from http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/,
    2018, for the detailed structure.
    """
    from scipy.io import loadmat
    import numpy as np
    import networkx as nx

    data = []
    content = loadmat(filename)
    order = extra_params['am_sp_al_nl_el']
    for key, value in content.items():
        if key[0] == 'l':  # class labels
            y = np.transpose(value)[0].tolist()
        elif key[0] != '_':
            # if the adjacency matrix is not compressed / an edge label exists
            if order[1] == 0:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set the name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    el = item[order[4]][0][0][0]  # edge labels
                    for edge in el:
                        g.add_edge(
                            edge[0] - 1, edge[1] - 1, bond_type=str(edge[2]))
                    data.append(g)
            else:
                for i, item in enumerate(value[0]):
                    g = nx.Graph(name=i)  # set the name of the graph
                    nl = np.transpose(item[order[3]][0][0][0])  # node labels
                    for index, label in enumerate(nl[0]):
                        g.add_node(index, atom=str(label))
                    sam = item[order[0]]  # sparse adjacency matrix
                    index_no0 = sam.nonzero()
                    for col, row in zip(index_no0[0], index_no0[1]):
                        g.add_edge(col, row)
                    data.append(g)
    return data, y
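
# Minimal usage sketch for loadMAT. It is not part of the original
# graphfiles.py; the path and the 'am_sp_al_nl_el' order below are
# hypothetical and must match the field layout of the struct array stored in
# the actual .mat file.
def _example_loadMAT():
    """Load a MATLAB dataset with an explicit field-order parameter."""
    graphs, targets = loadMAT(
        'data/dataset.mat',
        extra_params={'am_sp_al_nl_el': [0, 0, 2, 1, 3]})  # illustrative order
    print(len(graphs), 'graphs; first target:', targets[0])
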
def loadTXT(dirname_dataset):
    """Load graph data from separate .txt files.

    Notes
    ------
    The graph data is loaded from separate files. See the README in the
    archive downloadable from http://tiny.cc/PK_MLJ_data, 2018, for the
    detailed structure.
    """
    import networkx as nx
    from os import listdir

    # locate the data files
    for name in listdir(dirname_dataset):
        if '_A' in name:
            fam = dirname_dataset + '/' + name
        elif '_graph_indicator' in name:
            fgi = dirname_dataset + '/' + name
        elif '_graph_labels' in name:
            fgl = dirname_dataset + '/' + name
        elif '_node_labels' in name:
            fnl = dirname_dataset + '/' + name
        elif '_edge_labels' in name:
            fel = dirname_dataset + '/' + name
        elif '_edge_attributes' in name:
            fea = dirname_dataset + '/' + name
        elif '_node_attributes' in name:
            fna = dirname_dataset + '/' + name
        elif '_graph_attributes' in name:
            fga = dirname_dataset + '/' + name
        # this is supposed to be the node attributes; keep it as the last 'elif'
        elif '_attributes' in name:
            fna = dirname_dataset + '/' + name

    content_gi = open(fgi).read().splitlines()  # graph indicator
    content_am = open(fam).read().splitlines()  # adjacency matrix
    content_gl = open(fgl).read().splitlines()  # graph class labels

    # create graphs and add nodes
    data = [nx.Graph(name=i) for i in range(0, len(content_gl))]
    if 'fnl' in locals():
        content_nl = open(fnl).read().splitlines()  # node labels
        for i, line in enumerate(content_gi):
            # convert to int first in case of unexpected blanks
            data[int(line) - 1].add_node(i, atom=str(int(content_nl[i])))
    else:
        for i, line in enumerate(content_gi):
            data[int(line) - 1].add_node(i)

    # add edges
    for line in content_am:
        tmp = line.split(',')
        n1 = int(tmp[0]) - 1
        n2 = int(tmp[1]) - 1
        # edge weights are ignored here.
        g = int(content_gi[n1]) - 1
        data[g].add_edge(n1, n2)

    # add edge labels
    if 'fel' in locals():
        content_el = open(fel).read().splitlines()
        for index, line in enumerate(content_el):
            label = line.strip()
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['bond_type'] = label

    # add node attributes
    if 'fna' in locals():
        content_na = open(fna).read().splitlines()
        for i, line in enumerate(content_na):
            attrs = [i.strip() for i in line.split(',')]
            g = int(content_gi[i]) - 1
            data[g].nodes[i]['attributes'] = attrs

    # add edge attributes
    if 'fea' in locals():
        content_ea = open(fea).read().splitlines()
        for index, line in enumerate(content_ea):
            attrs = [i.strip() for i in line.split(',')]
            n = [int(i) - 1 for i in content_am[index].split(',')]
            g = int(content_gi[n[0]]) - 1
            data[g].edges[n[0], n[1]]['attributes'] = attrs

    # load class labels y
    y = [int(i) for i in content_gl]
    return data, y
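
# Minimal usage sketch for loadTXT. It is not part of the original
# graphfiles.py; 'data/MUTAG' is a hypothetical directory that would hold the
# *_A.txt, *_graph_indicator.txt, ... files described in the docstring above.
def _example_loadTXT():
    """Load a dataset stored as separate .txt files."""
    graphs, targets = loadTXT('data/MUTAG')
    print(len(graphs), 'graphs loaded; class of the first graph:', targets[0])
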
def loadDataset(filename, filename_y=None, extra_params=None):
    """Load the file list of a dataset and return the graphs and targets."""
    from os.path import dirname, splitext

    dirname_dataset = dirname(filename)
    extension = splitext(filename)[1][1:]
    data = []
    y = []
    if extension == "ds":
        content = open(filename).read().splitlines()
        if filename_y is None or filename_y == '':
            for i in range(0, len(content)):
                tmp = content[i].split(' ')
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
                y.append(float(tmp[1]))
        else:  # y is given in a separate file
            for i in range(0, len(content)):
                tmp = content[i]
                # remove the '#'s in file names
                data.append(
                    loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1)))
            content_y = open(filename_y).read().splitlines()
            # assume entries in filename and filename_y have the same order.
            for item in content_y:
                tmp = item.split(' ')
                # assume the 3rd entry in a line is y (for the Alkane dataset)
                y.append(float(tmp[2]))
    elif extension == "cxl":
        import xml.etree.ElementTree as ET

        tree = ET.parse(filename)
        root = tree.getroot()
        data = []
        y = []
        for graph in root.iter('print'):
            mol_filename = graph.attrib['file']
            mol_class = graph.attrib['class']
            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
            y.append(mol_class)
    elif extension == "sdf":
        from tqdm import tqdm
        import sys

        data = loadSDF(filename)

        y_raw = open(filename_y).read().splitlines()
        y_raw.pop(0)
        tmp0 = []
        tmp1 = []
        for i in range(0, len(y_raw)):
            tmp = y_raw[i].split(',')
            tmp0.append(tmp[0])
            tmp1.append(tmp[1].strip())

        y = []
        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
            try:
                y.append(tmp1[tmp0.index(data[i].name)].strip())
            except ValueError:  # data[i].name is not in tmp0
                data[i] = []
        data = list(filter(lambda a: a != [], data))
    elif extension == "mat":
        data, y = loadMAT(filename, extra_params)
    elif extension == 'txt':
        data, y = loadTXT(dirname_dataset)
    return data, y
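
As a quick end-to-end illustration, the sketch below calls loadDataset on a dataset list file. It is only a sketch, not part of graphfiles.py: it assumes the module is importable as graphfiles, and the path 'datasets/Acyclic/dataset_bps.ds' is a placeholder for an existing .ds file whose lines reference .ct files and target values, which is the layout loadDataset expects for that extension.

from graphfiles import loadDataset

# load graphs and their target values from a hypothetical .ds list file
graphs, y = loadDataset('datasets/Acyclic/dataset_bps.ds')
print(len(graphs), 'graphs loaded')
print('first target value:', y[0])
print(graphs[0].nodes(data=True))  # node attributes set by loadCT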

A Python package for graph kernels, graph edit distances and the graph pre-image problem.