You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

median_linlin.py 7.5 kB

(Line-number gutter for lines 1–215 omitted.)
  1. import sys
  2. import pathlib
  3. import numpy as np
  4. import networkx as nx
  5. from gedlibpy import librariesImport, gedlibpy
  6. sys.path.insert(0, "/home/bgauzere/dev/optim-graphes/")
  7. import gklearn
  8. def replace_graph_in_env(script, graph, old_id, label='median'):
  9. """
  10. Replace a graph in script
  11. If old_id is -1, add a new graph to the environnemt
  12. """
  13. if(old_id > -1):
  14. script.PyClearGraph(old_id)
  15. new_id = script.PyAddGraph(label)
  16. for i in graph.nodes():
  17. script.PyAddNode(new_id,str(i),graph.node[i]) # !! strings are required bt gedlib
  18. for e in graph.edges:
  19. script.PyAddEdge(new_id, str(e[0]),str(e[1]), {})
  20. script.PyInitEnv()
  21. script.PySetMethod("IPFP", "")
  22. script.PyInitMethod()
  23. return new_id
  24. #Dessin median courrant
  25. def draw_Letter_graph(graph):
  26. import numpy as np
  27. import networkx as nx
  28. import matplotlib.pyplot as plt
  29. plt.figure()
  30. pos = {}
  31. for n in graph.nodes:
  32. pos[n] = np.array([float(graph.node[n]['x']),float(graph.node[n]['y'])])
  33. nx.draw_networkx(graph,pos)
  34. plt.show()
  35. #compute new mappings
  36. def update_mappings(script,median_id,listID):
  37. med_distances = {}
  38. med_mappings = {}
  39. sod = 0
  40. for i in range(0,len(listID)):
  41. script.PyRunMethod(median_id,listID[i])
  42. med_distances[i] = script.PyGetUpperBound(median_id,listID[i])
  43. med_mappings[i] = script.PyGetForwardMap(median_id,listID[i])
  44. sod += med_distances[i]
  45. return med_distances, med_mappings, sod
  46. def calcul_Sij(all_mappings, all_graphs,i,j):
  47. s_ij = 0
  48. for k in range(0,len(all_mappings)):
  49. cur_graph = all_graphs[k]
  50. cur_mapping = all_mappings[k]
  51. size_graph = cur_graph.order()
  52. if ((cur_mapping[i] < size_graph) and
  53. (cur_mapping[j] < size_graph) and
  54. (cur_graph.has_edge(cur_mapping[i], cur_mapping[j]) == True)):
  55. s_ij += 1
  56. return s_ij
  57. # def update_median_nodes_L1(median,listIdSet,median_id,dataset, mappings):
  58. # from scipy.stats.mstats import gmean
  59. # for i in median.nodes():
  60. # for k in listIdSet:
  61. # vectors = [] #np.zeros((len(listIdSet),2))
  62. # if(k != median_id):
  63. # phi_i = mappings[k][i]
  64. # if(phi_i < dataset[k].order()):
  65. # vectors.append([float(dataset[k].node[phi_i]['x']),float(dataset[k].node[phi_i]['y'])])
  66. # new_labels = gmean(vectors)
  67. # median.node[i]['x'] = str(new_labels[0])
  68. # median.node[i]['y'] = str(new_labels[1])
  69. # return median
  70. def update_median_nodes(median,dataset,mappings):
  71. #update node attributes
  72. for i in median.nodes():
  73. nb_sub=0
  74. mean_label = {'x' : 0, 'y' : 0}
  75. for k in range(0,len(mappings)):
  76. phi_i = mappings[k][i]
  77. if ( phi_i < dataset[k].order() ):
  78. nb_sub += 1
  79. mean_label['x'] += 0.75*float(dataset[k].node[phi_i]['x'])
  80. mean_label['y'] += 0.75*float(dataset[k].node[phi_i]['y'])
  81. median.node[i]['x'] = str((1/0.75)*(mean_label['x']/nb_sub))
  82. median.node[i]['y'] = str((1/0.75)*(mean_label['y']/nb_sub))
  83. return median
  84. def update_median_edges(dataset, mappings, median, cei=0.425,cer=0.425):
  85. #for letter high, ceir = 1.7, alpha = 0.75
  86. size_dataset = len(dataset)
  87. ratio_cei_cer = cer/(cei + cer)
  88. threshold = size_dataset*ratio_cei_cer
  89. order_graph_median = median.order()
  90. for i in range(0,order_graph_median):
  91. for j in range(i+1,order_graph_median):
  92. s_ij = calcul_Sij(mappings,dataset,i,j)
  93. if(s_ij > threshold):
  94. median.add_edge(i,j)
  95. else:
  96. if(median.has_edge(i,j)):
  97. median.remove_edge(i,j)
  98. return median
  99. def compute_median(script, listID, dataset,verbose=False):
  100. """Compute a graph median of a dataset according to an environment
  101. Parameters
  102. script : An gedlib initialized environnement
  103. listID (list): a list of ID in script: encodes the dataset
  104. dataset (list): corresponding graphs in networkX format. We assume that graph
  105. listID[i] corresponds to dataset[i]
  106. Returns:
  107. A networkX graph, which is the median, with corresponding sod
  108. """
  109. print(len(listID))
  110. median_set_index, median_set_sod = compute_median_set(script, listID)
  111. print(median_set_index)
  112. print(median_set_sod)
  113. sods = []
  114. #Ajout median dans environnement
  115. set_median = dataset[median_set_index].copy()
  116. median = dataset[median_set_index].copy()
  117. cur_med_id = replace_graph_in_env(script,median,-1)
  118. med_distances, med_mappings, cur_sod = update_mappings(script,cur_med_id,listID)
  119. sods.append(cur_sod)
  120. if(verbose):
  121. print(cur_sod)
  122. ite_max = 50
  123. old_sod = cur_sod * 2
  124. ite = 0
  125. epsilon = 0.001
  126. best_median
  127. while((ite < ite_max) and (np.abs(old_sod - cur_sod) > epsilon )):
  128. median = update_median_nodes(median,dataset, med_mappings)
  129. median = update_median_edges(dataset,med_mappings,median)
  130. cur_med_id = replace_graph_in_env(script,median,cur_med_id)
  131. med_distances, med_mappings, cur_sod = update_mappings(script,cur_med_id,listID)
  132. sods.append(cur_sod)
  133. if(verbose):
  134. print(cur_sod)
  135. ite += 1
  136. return median, cur_sod, sods, set_median
  137. draw_Letter_graph(median)
  138. def compute_median_set(script,listID):
  139. 'Returns the id in listID corresponding to median set'
  140. #Calcul median set
  141. N=len(listID)
  142. map_id_to_index = {}
  143. map_index_to_id = {}
  144. for i in range(0,len(listID)):
  145. map_id_to_index[listID[i]] = i
  146. map_index_to_id[i] = listID[i]
  147. distances = np.zeros((N,N))
  148. for i in listID:
  149. for j in listID:
  150. script.PyRunMethod(i,j)
  151. distances[map_id_to_index[i],map_id_to_index[j]] = script.PyGetUpperBound(i,j)
  152. median_set_index = np.argmin(np.sum(distances,0))
  153. sod = np.min(np.sum(distances,0))
  154. return median_set_index, sod
  155. def _convertGraph(G):
  156. """Convert a graph to the proper NetworkX format that can be
  157. recognized by library gedlibpy.
  158. """
  159. G_new = nx.Graph()
  160. for nd, attrs in G.nodes(data=True):
  161. G_new.add_node(str(nd), chem=attrs['atom'])
  162. # G_new.add_node(str(nd), x=str(attrs['attributes'][0]),
  163. # y=str(attrs['attributes'][1]))
  164. for nd1, nd2, attrs in G.edges(data=True):
  165. G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type'])
  166. # G_new.add_edge(str(nd1), str(nd2))
  167. return G_new
  168. if __name__ == "__main__":
  169. #Chargement du dataset
  170. gedlibpy.PyLoadGXLGraph('/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/', '/home/bgauzere/dev/gedlib/data/collections/Letter_Z.xml')
  171. gedlibpy.PySetEditCost("LETTER")
  172. gedlibpy.PyInitEnv()
  173. gedlibpy.PySetMethod("IPFP", "")
  174. gedlibpy.PyInitMethod()
  175. dataset,my_y = gklearn.utils.graphfiles.loadDataset("/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/Letter_Z.cxl")
  176. listID = gedlibpy.PyGetAllGraphIds()
  177. median, sod = compute_median(gedlibpy,listID,dataset,verbose=True)
  178. print(sod)
  179. draw_Letter_graph(median)

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.