You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_preimage_random.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Thu Sep 5 15:59:00 2019
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import matplotlib.pyplot as plt
  10. import time
  11. import random
  12. #from tqdm import tqdm
  13. #import os
  14. import sys
  15. sys.path.insert(0, "../")
  16. from pygraph.utils.graphfiles import loadDataset
  17. from preimage_random import preimage_random
  18. from ged import ged_median
  19. from utils import compute_kernel, get_same_item_indices, remove_edges
  20. ###############################################################################
  21. # tests on different values on grid of median-sets and k.
  22. def test_preimage_random_grid_k_median_nb():
  23. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
  24. 'extra_params': {}} # node/edge symb
  25. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  26. # Gn = Gn[0:50]
  27. remove_edges(Gn)
  28. gkernel = 'marginalizedkernel'
  29. lmbda = 0.03 # termination probalility
  30. r_max = 5 # iteration limit for pre-image.
  31. l = 500 # update limit for random generation
  32. # alpha_range = np.linspace(0.5, 0.5, 1)
  33. # k = 5 # k nearest neighbors
  34. # parameters for GED function
  35. ged_cost='CHEM_1'
  36. ged_method='IPFP'
  37. saveGXL='gedlib'
  38. # number of graphs; we what to compute the median of these graphs.
  39. nb_median_range = [2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
  40. # number of nearest neighbors.
  41. k_range = [5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100]
  42. # find out all the graphs classified to positive group 1.
  43. idx_dict = get_same_item_indices(y_all)
  44. Gn = [Gn[i] for i in idx_dict[1]]
  45. # # compute Gram matrix.
  46. # time0 = time.time()
  47. # km = compute_kernel(Gn, gkernel, True)
  48. # time_km = time.time() - time0
  49. # # write Gram matrix to file.
  50. # np.savez('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm', gm=km, gmtime=time_km)
  51. time_list = []
  52. dis_ks_min_list = []
  53. sod_gs_list = []
  54. sod_gs_min_list = []
  55. nb_updated_list = []
  56. g_best = []
  57. for idx_nb, nb_median in enumerate(nb_median_range):
  58. print('\n-------------------------------------------------------')
  59. print('number of median graphs =', nb_median)
  60. random.seed(1)
  61. idx_rdm = random.sample(range(len(Gn)), nb_median)
  62. print('graphs chosen:', idx_rdm)
  63. Gn_median = [Gn[idx].copy() for idx in idx_rdm]
  64. # for g in Gn_median:
  65. # nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
  66. ## plt.savefig("results/preimage_mix/mutag.png", format="PNG")
  67. # plt.show()
  68. # plt.clf()
  69. ###################################################################
  70. gmfile = np.load('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz')
  71. km_tmp = gmfile['gm']
  72. time_km = gmfile['gmtime']
  73. # modify mixed gram matrix.
  74. km = np.zeros((len(Gn) + nb_median, len(Gn) + nb_median))
  75. for i in range(len(Gn)):
  76. for j in range(i, len(Gn)):
  77. km[i, j] = km_tmp[i, j]
  78. km[j, i] = km[i, j]
  79. for i in range(len(Gn)):
  80. for j, idx in enumerate(idx_rdm):
  81. km[i, len(Gn) + j] = km[i, idx]
  82. km[len(Gn) + j, i] = km[i, idx]
  83. for i, idx1 in enumerate(idx_rdm):
  84. for j, idx2 in enumerate(idx_rdm):
  85. km[len(Gn) + i, len(Gn) + j] = km[idx1, idx2]
  86. ###################################################################
  87. alpha_range = [1 / nb_median] * nb_median
  88. time_list.append([])
  89. dis_ks_min_list.append([])
  90. sod_gs_list.append([])
  91. sod_gs_min_list.append([])
  92. nb_updated_list.append([])
  93. g_best.append([])
  94. for k in k_range:
  95. print('\n++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n')
  96. print('k =', k)
  97. time0 = time.time()
  98. dhat, ghat, nb_updated = preimage_random(Gn, Gn_median, alpha_range,
  99. range(len(Gn), len(Gn) + nb_median), km, k, r_max, l, gkernel)
  100. time_total = time.time() - time0 + time_km
  101. print('time: ', time_total)
  102. time_list[idx_nb].append(time_total)
  103. print('\nsmallest distance in kernel space: ', dhat)
  104. dis_ks_min_list[idx_nb].append(dhat)
  105. g_best[idx_nb].append(ghat)
  106. print('\nnumber of updates of the best graph: ', nb_updated)
  107. nb_updated_list[idx_nb].append(nb_updated)
  108. # show the best graph and save it to file.
  109. print('the shortest distance is', dhat)
  110. print('one of the possible corresponding pre-images is')
  111. nx.draw(ghat, labels=nx.get_node_attributes(ghat, 'atom'),
  112. with_labels=True)
  113. plt.savefig('results/preimage_random/mutag_median_nb' + str(nb_median) +
  114. '_k' + str(k) + '.png', format="PNG")
  115. # plt.show()
  116. plt.clf()
  117. # print(ghat_list[0].nodes(data=True))
  118. # print(ghat_list[0].edges(data=True))
  119. # compute the corresponding sod in graph space.
  120. sod_tmp, _ = ged_median([ghat], Gn_median, ged_cost=ged_cost,
  121. ged_method=ged_method, saveGXL=saveGXL)
  122. sod_gs_list[idx_nb].append(sod_tmp)
  123. sod_gs_min_list[idx_nb].append(np.min(sod_tmp))
  124. print('\nsmallest sod in graph space: ', np.min(sod_tmp))
  125. print('\nsods in graph space: ', sod_gs_list)
  126. print('\nsmallest sod in graph space for each set of median graphs and k: ',
  127. sod_gs_min_list)
  128. print('\nsmallest distance in kernel space for each set of median graphs and k: ',
  129. dis_ks_min_list)
  130. print('\nnumber of updates of the best graph for each set of median graphs and k by IAM: ',
  131. nb_updated_list)
  132. print('\ntimes:', time_list)
  133. ###############################################################################
  134. # tests on different numbers of median-sets.
  135. def test_preimage_random_median_nb():
  136. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
  137. 'extra_params': {}} # node/edge symb
  138. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  139. # Gn = Gn[0:50]
  140. remove_edges(Gn)
  141. gkernel = 'marginalizedkernel'
  142. lmbda = 0.03 # termination probalility
  143. r_max = 5 # iteration limit for pre-image.
  144. l = 500 # update limit for random generation
  145. # alpha_range = np.linspace(0.5, 0.5, 1)
  146. k = 5 # k nearest neighbors
  147. # parameters for GED function
  148. ged_cost='CHEM_1'
  149. ged_method='IPFP'
  150. saveGXL='gedlib'
  151. # number of graphs; we what to compute the median of these graphs.
  152. nb_median_range = [2, 3, 4, 5, 10, 20, 30, 40, 50, 100]
  153. # find out all the graphs classified to positive group 1.
  154. idx_dict = get_same_item_indices(y_all)
  155. Gn = [Gn[i] for i in idx_dict[1]]
  156. # # compute Gram matrix.
  157. # time0 = time.time()
  158. # km = compute_kernel(Gn, gkernel, True)
  159. # time_km = time.time() - time0
  160. # # write Gram matrix to file.
  161. # np.savez('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm', gm=km, gmtime=time_km)
  162. time_list = []
  163. dis_ks_min_list = []
  164. sod_gs_list = []
  165. sod_gs_min_list = []
  166. nb_updated_list = []
  167. g_best = []
  168. for nb_median in nb_median_range:
  169. print('\n-------------------------------------------------------')
  170. print('number of median graphs =', nb_median)
  171. random.seed(1)
  172. idx_rdm = random.sample(range(len(Gn)), nb_median)
  173. print('graphs chosen:', idx_rdm)
  174. Gn_median = [Gn[idx].copy() for idx in idx_rdm]
  175. # for g in Gn_median:
  176. # nx.draw(g, labels=nx.get_node_attributes(g, 'atom'), with_labels=True)
  177. ## plt.savefig("results/preimage_mix/mutag.png", format="PNG")
  178. # plt.show()
  179. # plt.clf()
  180. ###################################################################
  181. gmfile = np.load('results/gram_matrix_marg_itr10_pq0.03_mutag_positive.gm.npz')
  182. km_tmp = gmfile['gm']
  183. time_km = gmfile['gmtime']
  184. # modify mixed gram matrix.
  185. km = np.zeros((len(Gn) + nb_median, len(Gn) + nb_median))
  186. for i in range(len(Gn)):
  187. for j in range(i, len(Gn)):
  188. km[i, j] = km_tmp[i, j]
  189. km[j, i] = km[i, j]
  190. for i in range(len(Gn)):
  191. for j, idx in enumerate(idx_rdm):
  192. km[i, len(Gn) + j] = km[i, idx]
  193. km[len(Gn) + j, i] = km[i, idx]
  194. for i, idx1 in enumerate(idx_rdm):
  195. for j, idx2 in enumerate(idx_rdm):
  196. km[len(Gn) + i, len(Gn) + j] = km[idx1, idx2]
  197. ###################################################################
  198. alpha_range = [1 / nb_median] * nb_median
  199. time0 = time.time()
  200. dhat, ghat, nb_updated = preimage_random(Gn, Gn_median, alpha_range,
  201. range(len(Gn), len(Gn) + nb_median), km, k, r_max, l, gkernel)
  202. time_total = time.time() - time0 + time_km
  203. print('time: ', time_total)
  204. time_list.append(time_total)
  205. print('\nsmallest distance in kernel space: ', dhat)
  206. dis_ks_min_list.append(dhat)
  207. g_best.append(ghat)
  208. print('\nnumber of updates of the best graph: ', nb_updated)
  209. nb_updated_list.append(nb_updated)
  210. # show the best graph and save it to file.
  211. print('the shortest distance is', dhat)
  212. print('one of the possible corresponding pre-images is')
  213. nx.draw(ghat, labels=nx.get_node_attributes(ghat, 'atom'),
  214. with_labels=True)
  215. plt.savefig('results/preimage_random/mutag_median_nb' + str(nb_median) +
  216. '.png', format="PNG")
  217. # plt.show()
  218. plt.clf()
  219. # print(ghat_list[0].nodes(data=True))
  220. # print(ghat_list[0].edges(data=True))
  221. # compute the corresponding sod in graph space.
  222. sod_tmp, _ = ged_median([ghat], Gn_median, ged_cost=ged_cost,
  223. ged_method=ged_method, saveGXL=saveGXL)
  224. sod_gs_list.append(sod_tmp)
  225. sod_gs_min_list.append(np.min(sod_tmp))
  226. print('\nsmallest sod in graph space: ', np.min(sod_tmp))
  227. print('\nsods in graph space: ', sod_gs_list)
  228. print('\nsmallest sod in graph space for each set of median graphs: ', sod_gs_min_list)
  229. print('\nsmallest distance in kernel space for each set of median graphs: ',
  230. dis_ks_min_list)
  231. print('\nnumber of updates of the best graph for each set of median graphs: ',
  232. nb_updated_list)
  233. print('\ntimes:', time_list)
  234. ###############################################################################
  235. # test on the combination of the two randomly chosen graphs. (the same as in the
  236. # random pre-image paper.)
  237. def test_random_preimage_2combination():
  238. ds = {'name': 'MUTAG', 'dataset': '../datasets/MUTAG/MUTAG_A.txt',
  239. 'extra_params': {}} # node/edge symb
  240. Gn, y_all = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
  241. # Gn = Gn[0:12]
  242. remove_edges(Gn)
  243. gkernel = 'marginalizedkernel'
  244. # dis_mat, dis_max, dis_min, dis_mean = kernel_distance_matrix(Gn, gkernel=gkernel)
  245. # print(dis_max, dis_min, dis_mean)
  246. lmbda = 0.03 # termination probalility
  247. r_max = 10 # iteration limit for pre-image.
  248. l = 500
  249. alpha_range = np.linspace(0, 1, 11)
  250. k = 5 # k nearest neighbors
  251. # randomly select two molecules
  252. np.random.seed(1)
  253. idx_gi = [187, 167] # np.random.randint(0, len(Gn), 2)
  254. g1 = Gn[idx_gi[0]].copy()
  255. g2 = Gn[idx_gi[1]].copy()
  256. # nx.draw(g1, labels=nx.get_node_attributes(g1, 'atom'), with_labels=True)
  257. # plt.savefig("results/random_preimage/mutag10.png", format="PNG")
  258. # plt.show()
  259. # nx.draw(g2, labels=nx.get_node_attributes(g2, 'atom'), with_labels=True)
  260. # plt.savefig("results/random_preimage/mutag11.png", format="PNG")
  261. # plt.show()
  262. ######################################################################
  263. # Gn_mix = [g.copy() for g in Gn]
  264. # Gn_mix.append(g1.copy())
  265. # Gn_mix.append(g2.copy())
  266. #
  267. ## g_tmp = iam([g1, g2])
  268. ## nx.draw_networkx(g_tmp)
  269. ## plt.show()
  270. #
  271. # # compute
  272. # time0 = time.time()
  273. # km = compute_kernel(Gn_mix, gkernel, True)
  274. # time_km = time.time() - time0
  275. ###################################################################
  276. idx1 = idx_gi[0]
  277. idx2 = idx_gi[1]
  278. gmfile = np.load('results/gram_matrix_marg_itr10_pq0.03.gm.npz')
  279. km = gmfile['gm']
  280. time_km = gmfile['gmtime']
  281. # modify mixed gram matrix.
  282. for i in range(len(Gn)):
  283. km[i, len(Gn)] = km[i, idx1]
  284. km[i, len(Gn) + 1] = km[i, idx2]
  285. km[len(Gn), i] = km[i, idx1]
  286. km[len(Gn) + 1, i] = km[i, idx2]
  287. km[len(Gn), len(Gn)] = km[idx1, idx1]
  288. km[len(Gn), len(Gn) + 1] = km[idx1, idx2]
  289. km[len(Gn) + 1, len(Gn)] = km[idx2, idx1]
  290. km[len(Gn) + 1, len(Gn) + 1] = km[idx2, idx2]
  291. ###################################################################
  292. time_list = []
  293. nb_updated_list = []
  294. g_best = []
  295. dis_ks_min_list = []
  296. # for each alpha
  297. for alpha in alpha_range:
  298. print('\n-------------------------------------------------------\n')
  299. print('alpha =', alpha)
  300. time0 = time.time()
  301. dhat, ghat, nb_updated = preimage_random(Gn, [g1, g2], [alpha, 1 - alpha],
  302. range(len(Gn), len(Gn) + 2), km,
  303. k, r_max, l, gkernel)
  304. time_total = time.time() - time0 + time_km
  305. print('time: ', time_total)
  306. time_list.append(time_total)
  307. dis_ks_min_list.append(dhat)
  308. g_best.append(ghat)
  309. nb_updated_list.append(nb_updated)
  310. # show best graphs and save them to file.
  311. for idx, item in enumerate(alpha_range):
  312. print('when alpha is', item, 'the shortest distance is', dis_ks_min_list[idx])
  313. print('one of the possible corresponding pre-images is')
  314. nx.draw(g_best[idx], labels=nx.get_node_attributes(g_best[idx], 'atom'),
  315. with_labels=True)
  316. plt.show()
  317. plt.savefig('results/random_preimage/mutag_alpha' + str(item) + '.png', format="PNG")
  318. plt.clf()
  319. print(g_best[idx].nodes(data=True))
  320. print(g_best[idx].edges(data=True))
  321. # # compute the corresponding sod in graph space. (alpha range not considered.)
  322. # sod_tmp, _ = median_distance(g_best[0], Gn_let)
  323. # sod_gs_list.append(sod_tmp)
  324. # sod_gs_min_list.append(np.min(sod_tmp))
  325. # sod_ks_min_list.append(sod_ks)
  326. # nb_updated_list.append(nb_updated)
  327. # print('\nsmallest sod in graph space for each alpha: ', sod_gs_min_list)
  328. print('\nsmallest distance in kernel space for each alpha: ', dis_ks_min_list)
  329. print('\nnumber of updates for each alpha: ', nb_updated_list)
  330. print('\ntimes:', time_list)
  331. ###############################################################################
  332. if __name__ == '__main__':
  333. ###############################################################################
  334. # test on the combination of the two randomly chosen graphs. (the same as in the
  335. # random pre-image paper.)
  336. # test_random_preimage_2combination()
  337. ###############################################################################
  338. # tests all algorithms on different numbers of median-sets.
  339. test_preimage_random_median_nb()
  340. ###############################################################################
  341. # tests all algorithms on different values on grid of median-sets and k.
  342. # test_preimage_random_grid_k_median_nb()

A Python package for graph kernels, graph edit distances and graph pre-image problem.