You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opt.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import ot
  2. import sys
  3. import pathlib
  4. sys.path.insert(0, "../")
  5. from pygraph.utils.graphfiles import loadDataset
  6. from pygraph.ged.costfunctions import ConstantCostFunction
  7. from pygraph.utils.utils import getSPLengths
  8. from tqdm import tqdm
  9. import numpy as np
  10. from scipy.optimize import linear_sum_assignment
  11. from pygraph.ged.GED import ged
  12. import scipy
  13. def pad(C, n):
  14. C_pad = np.zeros((n, n))
  15. C_pad[:C.shape[0], :C.shape[1]] = C
  16. return C_pad
  17. if (__name__ == "__main__"):
  18. ds_filename = "/home/bgauzere/work/Datasets/Acyclic/dataset_bps.ds"
  19. dataset, y = loadDataset(ds_filename)
  20. cf = ConstantCostFunction(1, 3, 1, 3)
  21. N = len(dataset)
  22. pairs = list()
  23. ged_distances = list() #np.zeros((N, N))
  24. gw_distances = list() #np.zeros((N, N))
  25. for i in tqdm(range(0, N)):
  26. for j in tqdm(range(i, N)):
  27. G1 = dataset[i]
  28. G2 = dataset[j]
  29. n = G1.number_of_nodes()
  30. m = G2.number_of_nodes()
  31. if(n == m):
  32. C1 = getSPLengths(G1)
  33. C2 = getSPLengths(G2)
  34. C1 /= C1.max()
  35. C2 /= C2.max()
  36. dim = max(n, m)
  37. if(n < m):
  38. C1 = pad(C1, dim)
  39. elif (m < n):
  40. C2 = pad(C2, dim)
  41. p = ot.unif(dim)
  42. q = ot.unif(dim)
  43. gw = ot.gromov_wasserstein(C1, C2, p, q,
  44. 'square_loss', epsilon=5e-3)
  45. row_ind, col_ind = linear_sum_assignment(-gw)
  46. rho = col_ind
  47. varrho = row_ind[np.argsort(col_ind)]
  48. pairs.append((i,j))
  49. gw_distances.append(ged(G1, G2, cf=cf, rho=rho, varrho=varrho)[0])
  50. ged_distances.append(ged(G1, G2, cf=cf)[0])
  51. print("Moyenne sur Riesen : {}".format(np.mean(ged_distances)))
  52. print("Moyenne sur GW : {} ".format(np.mean(gw_distances)))
  53. np.save("distances_riesen", ged_distances)
  54. np.save("distances_gw", gw_distances)

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.