
graphfiles.py 7.6 kB

"""Utility functions to manage graph files.
"""


def loadCT(filename):
    """Load data from a .ct file.

    Notes
    -----
    The first line of a .ct file holds the graph name; the data that follows
    typically looks like this:

        3 2                       <- number of nodes and edges
        0.0000 0.0000 0.0000 C    <- each line describes a node (x, y, z + label)
        0.0000 0.0000 0.0000 C
        0.0000 0.0000 0.0000 O
        1 3 1 1                   <- each line describes an edge: to, from, ?, label
        2 3 1 1
    """
    import networkx as nx
    from os.path import basename

    g = nx.Graph()
    with open(filename) as f:
        content = f.read().splitlines()
        g = nx.Graph(name=str(content[0]),
                     filename=basename(filename))  # set name of the graph
        tmp = content[1].split(" ")
        if tmp[0] == '':
            nb_nodes = int(tmp[1])  # number of nodes
            nb_edges = int(tmp[2])  # number of edges
        else:
            nb_nodes = int(tmp[0])
            nb_edges = int(tmp[1])
        # patch for compatibility: label will be removed later
        for i in range(0, nb_nodes):
            tmp = content[i + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_node(i, atom=tmp[3], label=tmp[3])
        for i in range(0, nb_edges):
            tmp = content[i + g.number_of_nodes() + 2].split(" ")
            tmp = [x for x in tmp if x != '']
            g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
                       bond_type=tmp[3].strip(), label=tmp[3].strip())
        # Alternative fixed-width parsing of edge lines (kept for reference):
        # for i in range(0, nb_edges):
        #     tmp = content[i + g.number_of_nodes() + 2]
        #     tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
        #     g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
        #                bond_type=tmp[3].strip(), label=tmp[3].strip())
    return g
def loadGXL(filename):
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET

    tree = ET.parse(filename)
    root = tree.getroot()
    index = 0
    g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
    dic = {}  # used to retrieve incident nodes of edges
    for node in root.iter('node'):
        dic[node.attrib['id']] = index
        labels = {}
        for attr in node.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'chem' in labels:
            labels['label'] = labels['chem']
        g.add_node(index, **labels)
        index += 1
    for edge in root.iter('edge'):
        labels = {}
        for attr in edge.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'valence' in labels:
            labels['label'] = labels['valence']
        g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
    return g
def saveGXL(graph, filename):
    import xml.etree.ElementTree as ET

    root_node = ET.Element('gxl')
    attr = dict()
    attr['id'] = graph.graph['name']
    attr['edgeids'] = 'true'
    attr['edgemode'] = 'undirected'
    graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
    for v in graph:
        current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
        for attr in graph.nodes[v].keys():
            cur_attr = ET.SubElement(
                current_node, 'attr', attrib={'name': attr})
            cur_value = ET.SubElement(
                cur_attr, graph.nodes[v][attr].__class__.__name__)
            cur_value.text = graph.nodes[v][attr]
    for v1 in graph:
        for v2 in graph[v1]:
            if v1 < v2:  # undirected graph: write each edge only once
                cur_edge = ET.SubElement(
                    graph_node, 'edge', attrib={'from': str(v1), 'to': str(v2)})
                for attr in graph[v1][v2].keys():
                    cur_attr = ET.SubElement(
                        cur_edge, 'attr', attrib={'name': attr})
                    cur_value = ET.SubElement(
                        cur_attr, graph[v1][v2][attr].__class__.__name__)
                    cur_value.text = str(graph[v1][v2][attr])
    tree = ET.ElementTree(root_node)
    tree.write(filename)
def loadSDF(filename):
    """Load data from a structured data file (.sdf file).

    Notes
    -----
    An SDF file contains a group of molecules, each represented in a way
    similar to the MOL format. See
    http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx
    (2018) for the detailed structure.
    """
    import networkx as nx
    from os.path import basename
    from tqdm import tqdm
    import sys

    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip())  # set name of the graph
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of nodes
            nb_edges = int(tmp[3:6])  # number of edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
                           bond_type=tmp[2].strip())
            data.append(g)
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # record separator
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
def loadDataset(filename, filename_y=''):
    """Load the file list of a dataset.
    """
    from os.path import dirname, splitext

    dirname_dataset = dirname(filename)
    extension = splitext(filename)[1][1:]
    data = []
    y = []
    if extension == "ds":
        content = open(filename).read().splitlines()
        for i in range(0, len(content)):
            tmp = content[i].split(' ')
            # remove the '#'s in file names
            data.append(loadCT(dirname_dataset + '/' +
                               tmp[0].replace('#', '', 1)))
            y.append(float(tmp[1]))
    elif extension == "cxl":
        import xml.etree.ElementTree as ET

        tree = ET.parse(filename)
        root = tree.getroot()
        data = []
        y = []
        for graph in root.iter('print'):
            mol_filename = graph.attrib['file']
            mol_class = graph.attrib['class']
            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
            y.append(mol_class)
    elif extension == "sdf":
        import numpy as np
        from tqdm import tqdm
        import sys

        data = loadSDF(filename)
        y_raw = open(filename_y).read().splitlines()
        y_raw.pop(0)
        tmp0 = []
        tmp1 = []
        for i in range(0, len(y_raw)):
            tmp = y_raw[i].split(',')
            tmp0.append(tmp[0])
            tmp1.append(tmp[1].strip())
        y = []
        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
            try:
                y.append(tmp1[tmp0.index(data[i].name)].strip())
            except ValueError:  # data[i].name not in tmp0
                data[i] = []
        data = list(filter(lambda a: a != [], data))
    return data, y
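
To make the .ct layout documented in loadCT concrete, here is a small round-trip sketch: it writes the three-atom example from the docstring to a temporary file and loads it back. This is illustrative only; it assumes the module above is importable as `graphfiles`, and the molecule name and file are made up.

# Illustrative sketch, not part of graphfiles.py.
import tempfile

from graphfiles import loadCT  # assumes graphfiles.py is on the Python path

ct_text = """toy molecule
3 2
0.0000 0.0000 0.0000 C
0.0000 0.0000 0.0000 C
0.0000 0.0000 0.0000 O
1 3 1 1
2 3 1 1
"""

# Write the example to a temporary .ct file (hypothetical data).
with tempfile.NamedTemporaryFile('w', suffix='.ct', delete=False) as f:
    f.write(ct_text)
    path = f.name

g = loadCT(path)
print(g.graph['name'])       # 'toy molecule'
print(g.nodes(data=True))    # each node carries 'atom' and 'label' attributes
print(g.edges(data=True))    # each edge carries 'bond_type' and 'label'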

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.
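
As a rough usage sketch of the loaders above (the dataset paths below are hypothetical placeholders, not files shipped with the package):

from graphfiles import loadDataset, loadGXL, saveGXL

# Load a dataset described by a .ds file, where each line is
# "<ct-file> <target>"; returns the graphs and their targets.
# Paths are hypothetical; adjust to your local dataset layout.
graphs, targets = loadDataset('datasets/Acyclic/dataset_bps.ds')
print(len(graphs), 'graphs; first target =', targets[0])

# GXL files can also be read and written individually.
g = loadGXL('datasets/MAO/molecule_0.gxl')
saveGXL(g, '/tmp/molecule_0_copy.gxl')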