
graphfiles.py 7.4 kB

  1. """ Utilities function to manage graph files
  2. """
  3. def loadCT(filename):
  4. """load data from .ct file.
  5. nn
  6. Notes
  7. ------
  8. a typical example of data in .ct is like this:
  9. 3 2 <- number of nodes and edges
  10. 0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
  11. 0.0000 0.0000 0.0000 C
  12. 0.0000 0.0000 0.0000 O
  13. 1 3 1 1 <- each line describes an edge : to, from,?, label
  14. 2 3 1 1
  15. """
  16. import networkx as nx
  17. from os.path import basename
  18. g = nx.Graph()
  19. with open(filename) as f:
  20. content = f.read().splitlines()
  21. g = nx.Graph(name=str(content[0]), filename=basename(filename)) # set name of the graph
  22. tmp = content[1].split(" ")
  23. if tmp[0] == '':
  24. nb_nodes = int(tmp[1]) # number of the nodes
  25. nb_edges = int(tmp[2]) # number of the edges
  26. else:
  27. nb_nodes = int(tmp[0])
  28. nb_edges = int(tmp[1])
  29. # patch for compatibility : label will be removed later
  30. for i in range(0, nb_nodes):
  31. tmp = content[i + 2].split(" ")
  32. tmp = [x for x in tmp if x != '']
  33. g.add_node(i, atom=tmp[3], label=tmp[3])
  34. for i in range(0, nb_edges):
  35. tmp = content[i + g.number_of_nodes() + 2].split(" ")
  36. tmp = [x for x in tmp if x != '']
  37. g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  38. bond_type=tmp[3].strip(), label=tmp[3].strip())
  39. # for i in range(0, nb_edges):
  40. # tmp = content[i + g.number_of_nodes() + 2]
  41. # tmp = [tmp[i:i+3] for i in range(0, len(tmp), 3)]
  42. # g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1,
  43. # bond_type=tmp[3].strip(), label=tmp[3].strip())
  44. return g
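# Illustrative usage sketch for loadCT. The file name 'molecule.ct' is a hypothetical
# example, assumed to follow the format shown in the docstring above:
#
#     g = loadCT('molecule.ct')
#     print(g.graph['name'], g.number_of_nodes(), g.number_of_edges())
#     print(g.nodes(data=True))  # node attributes: 'atom' and 'label'
#     print(g.edges(data=True))  # edge attributes: 'bond_type' and 'label'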
def loadGXL(filename):
    """Load a graph from a GXL file."""
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET

    tree = ET.parse(filename)
    root = tree.getroot()
    index = 0
    g = nx.Graph(filename=basename(filename), name=root[0].attrib['id'])
    dic = {}  # maps GXL node ids to integer indices, to retrieve incident nodes of edges
    for node in root.iter('node'):
        dic[node.attrib['id']] = index
        labels = {}
        for attr in node.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'chem' in labels:
            labels['label'] = labels['chem']
        g.add_node(index, **labels)
        index += 1
    for edge in root.iter('edge'):
        labels = {}
        for attr in edge.iter('attr'):
            labels[attr.attrib['name']] = attr[0].text
        if 'valence' in labels:
            labels['label'] = labels['valence']
        g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
    return g
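# Illustrative usage sketch for loadGXL. The file name 'molecule.gxl' is hypothetical:
#
#     g = loadGXL('molecule.gxl')
#     # node ids from the GXL file are remapped to consecutive integers starting at 0;
#     # 'chem' (node) and 'valence' (edge) attributes, when present, are mirrored into 'label'
#     print(g.nodes(data=True))
#     print(g.edges(data=True))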
def saveGXL(graph, filename):
    """Save a networkx graph to a GXL file."""
    import xml.etree.ElementTree as ET

    root_node = ET.Element('gxl')
    attr = dict()
    attr['id'] = graph.graph['name']
    attr['edgeids'] = 'true'
    attr['edgemode'] = 'undirected'
    graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
    for v in graph:
        current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
        for attr in graph.nodes[v].keys():
            cur_attr = ET.SubElement(current_node, 'attr', attrib={'name': attr})
            cur_value = ET.SubElement(cur_attr, graph.nodes[v][attr].__class__.__name__)
            cur_value.text = str(graph.nodes[v][attr])  # cast to str so ElementTree can serialise it
    for v1 in graph:
        for v2 in graph[v1]:
            if v1 < v2:  # undirected graph: write each edge only once
                cur_edge = ET.SubElement(graph_node, 'edge',
                                         attrib={'from': str(v1), 'to': str(v2)})
                for attr in graph[v1][v2].keys():
                    cur_attr = ET.SubElement(cur_edge, 'attr', attrib={'name': attr})
                    cur_value = ET.SubElement(cur_attr, graph[v1][v2][attr].__class__.__name__)
                    cur_value.text = str(graph[v1][v2][attr])
    tree = ET.ElementTree(root_node)
    tree.write(filename)
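# Illustrative round-trip sketch for saveGXL and loadGXL. File and attribute names are
# hypothetical; note that saveGXL expects the graph to carry a 'name' graph attribute,
# which it writes as the GXL graph id:
#
#     import networkx as nx
#     g = nx.Graph(name='toy')
#     g.add_node(0, atom='C')
#     g.add_node(1, atom='O')
#     g.add_edge(0, 1, bond_type='1')
#     saveGXL(g, 'toy.gxl')
#     g2 = loadGXL('toy.gxl')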
def loadSDF(filename):
    """Load data from a structured data file (.sdf file).

    Notes
    ------
    An SDF file contains a group of molecules, each represented in a way similar to the MOL format.
    See http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx (2018) for the detailed structure.
    """
    import networkx as nx
    from tqdm import tqdm
    import sys

    data = []
    with open(filename) as f:
        content = f.read().splitlines()
        index = 0
        pbar = tqdm(total=len(content) + 1, desc='load SDF', file=sys.stdout)
        while index < len(content):
            index_old = index
            g = nx.Graph(name=content[index].strip())  # set the name of the graph
            tmp = content[index + 3]
            nb_nodes = int(tmp[:3])  # number of nodes
            nb_edges = int(tmp[3:6])  # number of edges
            for i in range(0, nb_nodes):
                tmp = content[i + index + 4]
                g.add_node(i, atom=tmp[31:34].strip())
            for i in range(0, nb_edges):
                tmp = content[i + index + g.number_of_nodes() + 4]
                tmp = [tmp[i:i + 3] for i in range(0, len(tmp), 3)]
                g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1, bond_type=tmp[2].strip())
            data.append(g)
            index += 4 + g.number_of_nodes() + g.number_of_edges()
            while content[index].strip() != '$$$$':  # separator between molecules
                index += 1
            index += 1
            pbar.update(index - index_old)
        pbar.update(1)
        pbar.close()
    return data
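# Illustrative usage sketch for loadSDF. The file name 'compounds.sdf' is hypothetical:
#
#     graphs = loadSDF('compounds.sdf')
#     print(len(graphs))      # one networkx graph per molecule in the file
#     print(graphs[0].name)   # molecule name, taken from the first line of each record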
def loadDataset(filename, filename_y=''):
    """Load a file list of the dataset.
    """
    from os.path import dirname, splitext

    dirname_dataset = dirname(filename)
    extension = splitext(filename)[1][1:]
    data = []
    y = []
    if extension == "ds":
        content = open(filename).read().splitlines()
        for i in range(0, len(content)):
            tmp = content[i].split(' ')
            # remove the leading '#' in file names
            data.append(loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1)))
            y.append(float(tmp[1]))
    elif extension == "cxl":
        import xml.etree.ElementTree as ET

        tree = ET.parse(filename)
        root = tree.getroot()
        data = []
        y = []
        for graph in root.iter('print'):
            mol_filename = graph.attrib['file']
            mol_class = graph.attrib['class']
            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
            y.append(mol_class)
    elif extension == "sdf":
        from tqdm import tqdm
        import sys

        data = loadSDF(filename)
        y_raw = open(filename_y).read().splitlines()
        y_raw.pop(0)  # drop the CSV header line
        tmp0 = []
        tmp1 = []
        for i in range(0, len(y_raw)):
            tmp = y_raw[i].split(',')
            tmp0.append(tmp[0])
            tmp1.append(tmp[1].strip())
        y = []
        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
            try:
                y.append(tmp1[tmp0.index(data[i].name)].strip())
            except ValueError:  # data[i].name is not in tmp0: drop this graph
                data[i] = []
        data = list(filter(lambda a: a != [], data))
    return data, y
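# Illustrative usage sketch for loadDataset. The path is hypothetical; a .ds file lists one
# '<ct file> <numeric target>' pair per line, a .cxl file lists GXL files with class labels,
# and for an .sdf dataset the targets are read from the separate CSV passed as filename_y:
#
#     data, y = loadDataset('datasets/MAO/dataset.ds')
#     print(len(data), len(y))  # list of networkx graphs and the corresponding targets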

A Python package for graph kernels, graph edit distances and the graph pre-image problem.