You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

median_benoit.py 6.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import sys
  2. import pathlib
  3. import numpy as np
  4. import networkx as nx
  5. import librariesImport
  6. import script
  7. sys.path.insert(0, "/home/bgauzere/dev/optim-graphes/")
  8. import gklearn
  9. def replace_graph_in_env(script, graph, old_id, label='median'):
  10. """
  11. Replace a graph in script
  12. If old_id is -1, add a new graph to the environnemt
  13. """
  14. if(old_id > -1):
  15. script.PyClearGraph(old_id)
  16. new_id = script.PyAddGraph(label)
  17. for i in graph.nodes():
  18. script.PyAddNode(new_id,str(i),graph.node[i]) # !! strings are required bt gedlib
  19. for e in graph.edges:
  20. script.PyAddEdge(new_id, str(e[0]),str(e[1]), {})
  21. script.PyInitEnv()
  22. script.PySetMethod("IPFP", "")
  23. script.PyInitMethod()
  24. return new_id
  25. #Dessin median courrant
  26. def draw_Letter_graph(graph):
  27. import numpy as np
  28. import networkx as nx
  29. import matplotlib.pyplot as plt
  30. plt.figure()
  31. pos = {}
  32. for n in graph.nodes:
  33. pos[n] = np.array([float(graph.node[n]['x']),float(graph.node[n]['y'])])
  34. nx.draw_networkx(graph,pos)
  35. plt.show()
  36. #compute new mappings
  37. def update_mappings(script,median_id,listID):
  38. med_distances = {}
  39. med_mappings = {}
  40. sod = 0
  41. for i in range(0,len(listID)):
  42. script.PyRunMethod(median_id,listID[i])
  43. med_distances[i] = script.PyGetUpperBound(median_id,listID[i])
  44. med_mappings[i] = script.PyGetForwardMap(median_id,listID[i])
  45. sod += med_distances[i]
  46. return med_distances, med_mappings, sod
  47. def calcul_Sij(all_mappings, all_graphs,i,j):
  48. s_ij = 0
  49. for k in range(0,len(all_mappings)):
  50. cur_graph = all_graphs[k]
  51. cur_mapping = all_mappings[k]
  52. size_graph = cur_graph.order()
  53. if ((cur_mapping[i] < size_graph) and
  54. (cur_mapping[j] < size_graph) and
  55. (cur_graph.has_edge(cur_mapping[i], cur_mapping[j]) == True)):
  56. s_ij += 1
  57. return s_ij
  58. # def update_median_nodes_L1(median,listIdSet,median_id,dataset, mappings):
  59. # from scipy.stats.mstats import gmean
  60. # for i in median.nodes():
  61. # for k in listIdSet:
  62. # vectors = [] #np.zeros((len(listIdSet),2))
  63. # if(k != median_id):
  64. # phi_i = mappings[k][i]
  65. # if(phi_i < dataset[k].order()):
  66. # vectors.append([float(dataset[k].node[phi_i]['x']),float(dataset[k].node[phi_i]['y'])])
  67. # new_labels = gmean(vectors)
  68. # median.node[i]['x'] = str(new_labels[0])
  69. # median.node[i]['y'] = str(new_labels[1])
  70. # return median
  71. def update_median_nodes(median,dataset,mappings):
  72. #update node attributes
  73. for i in median.nodes():
  74. nb_sub=0
  75. mean_label = {'x' : 0, 'y' : 0}
  76. for k in range(0,len(mappings)):
  77. phi_i = mappings[k][i]
  78. if ( phi_i < dataset[k].order() ):
  79. nb_sub += 1
  80. mean_label['x'] += 0.75*float(dataset[k].node[phi_i]['x'])
  81. mean_label['y'] += 0.75*float(dataset[k].node[phi_i]['y'])
  82. median.node[i]['x'] = str((1/0.75)*(mean_label['x']/nb_sub))
  83. median.node[i]['y'] = str((1/0.75)*(mean_label['y']/nb_sub))
  84. return median
  85. def update_median_edges(dataset, mappings, median, cei=0.425,cer=0.425):
  86. #for letter high, ceir = 1.7, alpha = 0.75
  87. size_dataset = len(dataset)
  88. ratio_cei_cer = cer/(cei + cer)
  89. threshold = size_dataset*ratio_cei_cer
  90. order_graph_median = median.order()
  91. for i in range(0,order_graph_median):
  92. for j in range(i+1,order_graph_median):
  93. s_ij = calcul_Sij(mappings,dataset,i,j)
  94. if(s_ij > threshold):
  95. median.add_edge(i,j)
  96. else:
  97. if(median.has_edge(i,j)):
  98. median.remove_edge(i,j)
  99. return median
  100. def compute_median(script, listID, dataset,verbose=False):
  101. """Compute a graph median of a dataset according to an environment
  102. Parameters
  103. script : An gedlib initialized environnement
  104. listID (list): a list of ID in script: encodes the dataset
  105. dataset (list): corresponding graphs in networkX format. We assume that graph
  106. listID[i] corresponds to dataset[i]
  107. Returns:
  108. A networkX graph, which is the median, with corresponding sod
  109. """
  110. print(len(listID))
  111. median_set_index, median_set_sod = compute_median_set(script, listID)
  112. print(median_set_index)
  113. print(median_set_sod)
  114. sods = []
  115. #Ajout median dans environnement
  116. set_median = dataset[median_set_index].copy()
  117. median = dataset[median_set_index].copy()
  118. cur_med_id = replace_graph_in_env(script,median,-1)
  119. med_distances, med_mappings, cur_sod = update_mappings(script,cur_med_id,listID)
  120. sods.append(cur_sod)
  121. if(verbose):
  122. print(cur_sod)
  123. ite_max = 50
  124. old_sod = cur_sod * 2
  125. ite = 0
  126. epsilon = 0.001
  127. best_median
  128. while((ite < ite_max) and (np.abs(old_sod - cur_sod) > epsilon )):
  129. median = update_median_nodes(median,dataset, med_mappings)
  130. median = update_median_edges(dataset,med_mappings,median)
  131. cur_med_id = replace_graph_in_env(script,median,cur_med_id)
  132. med_distances, med_mappings, cur_sod = update_mappings(script,cur_med_id,listID)
  133. sods.append(cur_sod)
  134. if(verbose):
  135. print(cur_sod)
  136. ite += 1
  137. return median, cur_sod, sods, set_median
  138. draw_Letter_graph(median)
  139. def compute_median_set(script,listID):
  140. 'Returns the id in listID corresponding to median set'
  141. #Calcul median set
  142. N=len(listID)
  143. map_id_to_index = {}
  144. map_index_to_id = {}
  145. for i in range(0,len(listID)):
  146. map_id_to_index[listID[i]] = i
  147. map_index_to_id[i] = listID[i]
  148. distances = np.zeros((N,N))
  149. for i in listID:
  150. for j in listID:
  151. script.PyRunMethod(i,j)
  152. distances[map_id_to_index[i],map_id_to_index[j]] = script.PyGetUpperBound(i,j)
  153. median_set_index = np.argmin(np.sum(distances,0))
  154. sod = np.min(np.sum(distances,0))
  155. return median_set_index, sod
  156. if __name__ == "__main__":
  157. #Chargement du dataset
  158. script.PyLoadGXLGraph('/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/', '/home/bgauzere/dev/gedlib/data/collections/Letter_Z.xml')
  159. script.PySetEditCost("LETTER")
  160. script.PyInitEnv()
  161. script.PySetMethod("IPFP", "")
  162. script.PyInitMethod()
  163. dataset,my_y = gklearn.utils.graphfiles.loadDataset("/home/bgauzere/dev/gedlib/data/datasets/Letter/HIGH/Letter_Z.cxl")
  164. listID = script.PyGetAllGraphIds()
  165. median, sod = compute_median(script,listID,dataset,verbose=True)
  166. print(sod)
  167. draw_Letter_graph(median)

A Python package for graph kernels, graph edit distances and graph pre-image problem.