You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

median.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import sys
  2. sys.path.insert(0, "../")
  3. #import pathlib
  4. import numpy as np
  5. import networkx as nx
  6. import time
  7. from gedlibpy import librariesImport, gedlibpy
  8. #import script
  9. sys.path.insert(0, "/home/bgauzere/dev/optim-graphes/")
  10. import gklearn
  11. from gklearn.utils.graphfiles import loadDataset
  def replace_graph_in_env(script, graph, old_id, label='median'):
      """Insert ``graph`` into the gedlib environment ``script``.

      If ``old_id`` is > -1, that graph is cleared from the environment first;
      with ``old_id == -1`` the graph is simply added as a new one.

      Parameters
      ----------
      script : gedlib environment exposing the ``Py*`` API used below.
      graph : networkx graph. Each node's attribute dict is passed as its
          label; edges are added with an empty label dict.
      old_id : int
          Id of the graph to clear, or -1 to clear nothing.
      label : str
          Name given to the new graph (passed to ``PyAddGraph``).

      Returns
      -------
      int
          The id returned by ``script.PyAddGraph`` for the new graph.

      After insertion the environment is re-initialised and the IPFP method is
      selected and initialised again, so ``PyRunMethod`` can use the new id.
      """
      if old_id > -1:
          script.PyClearGraph(old_id)
      new_id = script.PyAddGraph(label)
      # gedlib requires node identifiers as strings. ``nodes(data=True)`` is
      # used because ``Graph.node`` was removed in networkx 2.4.
      for node, attrs in graph.nodes(data=True):
          script.PyAddNode(new_id, str(node), attrs)
      for u, v in graph.edges():
          script.PyAddEdge(new_id, str(u), str(v), {})
      script.PyInitEnv()
      script.PySetMethod("IPFP", "")
      script.PyInitMethod()
      return new_id
  # Draw the current median graph.
  def draw_Letter_graph(graph, savepath=''):
      """Plot a Letter graph, using node coordinates as drawing positions.

      A node's position comes from its ``'x'``/``'y'`` attributes when both are
      present (these are what ``update_median_nodes`` rewrites on a median).
      Otherwise it comes from the first two entries of its ``'attributes'``
      sequence. Values are converted with ``float``.

      Parameters
      ----------
      graph : networkx graph whose nodes carry coordinates as described above.
      savepath : str
          If non-empty, the figure is also saved in EPS format to
          ``savepath + str(time.time()) + '.eps'`` before being shown.
      """
      import matplotlib.pyplot as plt

      fig = plt.figure()
      pos = {}
      # ``nodes(data=True)`` instead of ``graph.node`` (removed in networkx 2.4).
      for n, attrs in graph.nodes(data=True):
          if 'x' in attrs and 'y' in attrs:
              x, y = attrs['x'], attrs['y']
          else:
              x, y = attrs['attributes'][0], attrs['attributes'][1]
          pos[n] = np.array([float(x), float(y)])
      nx.draw_networkx(graph, pos)
      if savepath != '':
          plt.savefig(savepath + str(time.time()) + '.eps', format='eps', dpi=300)
      plt.show()
      # Close rather than merely clear, so repeated calls don't pile up open figures.
      plt.close(fig)
  # Compute new mappings from the median to every graph of the dataset.
  def update_mappings(script, median_id, listID):
      """Run the GED method from ``median_id`` to each graph in ``listID``.

      Returns a triple ``(distances, mappings, sod)``:
      - ``distances[k]`` is ``PyGetUpperBound(median_id, listID[k])``;
      - ``mappings[k]`` is ``PyGetForwardMap(median_id, listID[k])``;
      - ``sod`` is the sum of all those upper bounds.
      Both dicts are keyed by the position ``k`` in ``listID``.
      """
      med_distances, med_mappings = {}, {}
      sod = 0
      for pos, graph_id in enumerate(listID):
          script.PyRunMethod(median_id, graph_id)
          dist = script.PyGetUpperBound(median_id, graph_id)
          med_distances[pos] = dist
          med_mappings[pos] = script.PyGetForwardMap(median_id, graph_id)
          sod += dist
      return med_distances, med_mappings, sod
  def calcul_Sij(all_mappings, all_graphs, i, j):
      """Count the graphs in which median nodes ``i`` and ``j`` map onto an edge.

      For each k in ``range(len(all_mappings))``, ``all_mappings[k]`` is indexed
      by median node and gives that node's image in ``all_graphs[k]``. A pair
      counts only if both images are smaller than the graph's order and the
      graph has an edge between them. Presumably larger images stand for
      gedlib's dummy (deleted) node; confirm against gedlib.
      """
      def _maps_onto_edge(k):
          graph = all_graphs[k]
          phi = all_mappings[k]
          order = graph.order()
          # Both endpoints must land on real nodes before the edge is queried.
          return phi[i] < order and phi[j] < order and graph.has_edge(phi[i], phi[j])

      # Index access (not iteration) because callers pass a dict keyed 0..n-1.
      return sum(1 for k in range(len(all_mappings)) if _maps_onto_edge(k))
  65. # def update_median_nodes_L1(median,listIdSet,median_id,dataset, mappings):
  66. # from scipy.stats.mstats import gmean
  67. # for i in median.nodes():
  68. # for k in listIdSet:
  69. # vectors = [] #np.zeros((len(listIdSet),2))
  70. # if(k != median_id):
  71. # phi_i = mappings[k][i]
  72. # if(phi_i < dataset[k].order()):
  73. # vectors.append([float(dataset[k].node[phi_i]['x']),float(dataset[k].node[phi_i]['y'])])
  74. # new_labels = gmean(vectors)
  75. # median.node[i]['x'] = str(new_labels[0])
  76. # median.node[i]['y'] = str(new_labels[1])
  77. # return median
  def update_median_nodes(median, dataset, mappings):
      """Set each median node's ``'x'``/``'y'`` label to the mean of its images.

      For every node ``i`` of ``median``, the ``'x'``/``'y'`` attributes of node
      ``mappings[k][i]`` in ``dataset[k]`` are averaged over all
      k in ``range(len(mappings))`` for which that image is smaller than
      ``dataset[k].order()``. The means are stored back on the median node as
      strings.

      If a node has no such image at all, its current label is left unchanged.
      Previously this case raised ZeroDivisionError.

      Parameters
      ----------
      median : networkx graph, modified in place.
      dataset : list of networkx graphs whose nodes carry ``'x'``/``'y'``.
      mappings : indexable by k (e.g. the dict built by ``update_mappings``),
          each entry indexable by median node.

      Returns
      -------
      The same ``median`` object.
      """
      for i in median.nodes():
          sum_x = sum_y = 0.0
          nb_sub = 0
          for k in range(len(mappings)):
              phi_i = mappings[k][i]
              target = dataset[k]
              if phi_i < target.order():
                  # ``nodes[...]`` instead of ``node[...]`` (removed in networkx 2.4).
                  attrs = target.nodes[phi_i]
                  sum_x += float(attrs['x'])
                  sum_y += float(attrs['y'])
                  nb_sub += 1
          if nb_sub == 0:
              # No real node to average over: keep the current label.
              continue
          median.nodes[i]['x'] = str(sum_x / nb_sub)
          median.nodes[i]['y'] = str(sum_y / nb_sub)
      return median
  def update_median_edges(dataset, mappings, median, cei=0.425, cer=0.425):
      """Recompute the median's edges from the current node mappings.

      Median nodes are taken to be ``0 .. median.order() - 1``. For every pair
      ``i < j``, ``calcul_Sij`` counts the graphs in which both nodes map onto
      an edge. The edge ``(i, j)`` is then present in the median if and only if
      that count exceeds ``len(dataset) * cer / (cei + cer)``. ``median`` is
      modified in place and returned.
      """
      # Author's note: for Letter HIGH, ceir = 1.7, alpha = 0.75.
      threshold = len(dataset) * (cer / (cei + cer))
      n_nodes = median.order()
      for i in range(n_nodes):
          for j in range(i + 1, n_nodes):
              keep = calcul_Sij(mappings, dataset, i, j) > threshold
              if keep:
                  median.add_edge(i, j)
              elif median.has_edge(i, j):
                  median.remove_edge(i, j)
      return median
  def compute_median(script, listID, dataset, verbose=False, ite_max=50, epsilon=0.001):
      """Approximate a median graph of ``dataset`` in the gedlib environment ``script``.

      Starts from the set median returned by ``compute_median_set`` and then
      repeats the following steps:
      1. update node labels (``update_median_nodes``);
      2. update edges (``update_median_edges``);
      3. reinsert the median into the environment;
      4. recompute the mappings and the sum of distances (SOD).
      It stops when the SOD changes by at most ``epsilon`` between two
      consecutive iterations, or after ``ite_max`` iterations.

      Parameters
      ----------
      script : an initialized gedlib environment
      listID : list
          Ids of the dataset graphs in ``script``; ``listID[i]`` must correspond
          to ``dataset[i]``.
      dataset : list of networkx graphs
      verbose : bool
          If True, print the dataset size, the set-median index and SOD, and the
          SOD after every step.
      ite_max : int
          Maximum number of update iterations.
      epsilon : float
          Convergence tolerance on the SOD change between iterations.

      Returns
      -------
      tuple
          ``(median, sod, sods, set_median)``:
          - the final median graph and its SOD;
          - the list of SODs (initial set median first);
          - a copy of the set median graph.
      """
      median_set_index, median_set_sod = compute_median_set(script, listID)
      if verbose:
          print(len(listID))
          print(median_set_index)
          print(median_set_sod)

      sods = []
      # Add the median to the environment, starting from the set median; keep an
      # untouched copy of it to return as well.
      set_median = dataset[median_set_index].copy()
      median = dataset[median_set_index].copy()
      cur_med_id = replace_graph_in_env(script, median, -1)
      _, med_mappings, cur_sod = update_mappings(script, cur_med_id, listID)
      sods.append(cur_sod)
      if verbose:
          print(cur_sod)

      # Makes the first convergence test pass unless the initial SOD is ~0.
      old_sod = cur_sod * 2
      ite = 0
      while ite < ite_max and abs(old_sod - cur_sod) > epsilon:
          median = update_median_nodes(median, dataset, med_mappings)
          median = update_median_edges(dataset, med_mappings, median)
          cur_med_id = replace_graph_in_env(script, median, cur_med_id)
          # Remember the previous SOD so the stopping test compares consecutive
          # iterations (it was never updated before, so the loop always ran
          # ite_max times).
          old_sod = cur_sod
          _, med_mappings, cur_sod = update_mappings(script, cur_med_id, listID)
          sods.append(cur_sod)
          if verbose:
              print(cur_sod)
          ite += 1
      return median, cur_sod, sods, set_median
  def compute_median_set(script, listID):
      """Find the set median among the graphs ``listID`` of the environment ``script``.

      The GED method is run on every ordered pair (including ``a == b``). Each
      ``D[a, b]`` is ``PyGetUpperBound(listID[a], listID[b])``. The set median
      is the graph minimising the sum of distances *from* it to all graphs,
      i.e. the row sums of ``D``. This matches how ``update_mappings`` computes
      the SOD of a median (upper bounds from the median to every graph).

      Returns
      -------
      (int, float)
          The *index* into ``listID`` (not the graph id) of the set median,
          and its sum of distances.

      Raises
      ------
      ValueError
          If ``listID`` is empty.
      """
      n = len(listID)
      if n == 0:
          raise ValueError("compute_median_set: listID is empty")
      distances = np.zeros((n, n))
      for a, id_a in enumerate(listID):
          for b, id_b in enumerate(listID):
              script.PyRunMethod(id_a, id_b)
              distances[a, b] = script.PyGetUpperBound(id_a, id_b)
      # Row a holds the distances from listID[a] to every graph.
      sods = distances.sum(axis=1)
      median_set_index = int(np.argmin(sods))
      return median_set_index, sods[median_set_index]
  if __name__ == "__main__":
      # NOTE(review): `script` was never bound in this module (the
      # `#import script` at the top is commented out), so this block raised
      # NameError. Presumably the legacy gedlib `script` binding providing the
      # Py* API used below is intended; confirm for your gedlib install.
      import script

      # Load the dataset.
      script.PyLoadGXLGraph('/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/', '/home/bgauzere/dev/gedlib/data/collections/Letter_Z.xml')
      script.PySetEditCost("LETTER")
      script.PyInitEnv()
      script.PySetMethod("IPFP", "")
      script.PyInitMethod()
      dataset, _ = gklearn.utils.graphfiles.loadDataset("/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/Letter_Z.cxl")
      listID = script.PyGetAllGraphIds()
      # compute_median returns four values; unpacking into two raised ValueError.
      median, sod, sods, set_median = compute_median(script, listID, dataset, verbose=True)
      print(sod)
      draw_Letter_graph(median)
  175. #if __name__ == '__main__':
  176. # # test draw_Letter_graph
  177. # ds = {'name': 'Letter-high', 'dataset': '../datasets/Letter-high/Letter-high_A.txt',
  178. # 'extra_params': {}} # node nsymb
  179. # Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  180. # print(y_all)
  181. # for g in Gn:
  182. # draw_Letter_graph(g)

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.