You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

random_preimage_generator.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri May 29 14:29:52 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import time
  9. import sys
  10. from tqdm import tqdm
  11. import multiprocessing
  12. import networkx as nx
  13. from multiprocessing import Pool
  14. from functools import partial
  15. from gklearn.preimage import PreimageGenerator
  16. from gklearn.preimage.utils import compute_k_dis
  17. from gklearn.utils import Timer
  18. from gklearn.utils.utils import get_graph_kernel_by_name
  19. # from gklearn.utils.dataset import Dataset
  20. class RandomPreimageGenerator(PreimageGenerator):
  21. def __init__(self, dataset=None):
  22. PreimageGenerator.__init__(self, dataset=dataset)
  23. # arguments to set.
  24. self._k = 5 # number of nearest neighbors of phi in D_N.
  25. self._r_max = 10 # maximum number of iterations.
  26. self._l = 500 # numbers of graphs generated for each graph in D_k U {g_i_hat}.
  27. self._alphas = None # weights of linear combinations of points in kernel space.
  28. self._parallel = True
  29. self._n_jobs = multiprocessing.cpu_count()
  30. self._time_limit_in_sec = 0
  31. self._max_itrs = 20
  32. # values to compute.
  33. self._runtime_generate_preimage = None
  34. self._runtime_total = None
  35. self._preimage = None
  36. self._best_from_dataset = None
  37. self._k_dis_preimage = None
  38. self._k_dis_dataset = None
  39. self._itrs = 0
  40. self._converged = False # @todo
  41. self._num_updates = 0
  42. # values that can be set or to be computed.
  43. self._gram_matrix_unnorm = None
  44. self._runtime_precompute_gm = None
  45. def set_options(self, **kwargs):
  46. self._kernel_options = kwargs.get('kernel_options', {})
  47. self._graph_kernel = kwargs.get('graph_kernel', None)
  48. self._verbose = kwargs.get('verbose', 2)
  49. self._k = kwargs.get('k', 5)
  50. self._r_max = kwargs.get('r_max', 10)
  51. self._l = kwargs.get('l', 500)
  52. self._alphas = kwargs.get('alphas', None)
  53. self._parallel = kwargs.get('parallel', True)
  54. self._n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
  55. self._time_limit_in_sec = kwargs.get('time_limit_in_sec', 0)
  56. self._max_itrs = kwargs.get('max_itrs', 20)
  57. self._gram_matrix_unnorm = kwargs.get('gram_matrix_unnorm', None)
  58. self._runtime_precompute_gm = kwargs.get('runtime_precompute_gm', None)
    def run(self):
        """Run the random pre-image generation.

        Steps (timings are recorded along the way):
          1. build the graph kernel and obtain the Gram matrix of the dataset
             (computed here, or taken from ``self._gram_matrix_unnorm``);
          2. find the k nearest neighbors of phi (the ``self._alphas``-weighted
             combination of dataset graphs in kernel space) among D_N;
          3. iteratively generate random variants of the current best graphs,
             keeping any candidate whose kernel distance to phi does not grow.

        Results are stored on the instance (``self._preimage``,
        ``self._k_dis_preimage``, runtimes, ...); read them via ``get_results()``.
        """
        self._graph_kernel = get_graph_kernel_by_name(self._kernel_options['name'],
                    node_labels=self._dataset.node_labels,
                    edge_labels=self._dataset.edge_labels,
                    node_attrs=self._dataset.node_attrs,
                    edge_attrs=self._dataset.edge_attrs,
                    ds_infos=self._dataset.get_dataset_infos(keys=['directed']),
                    kernel_options=self._kernel_options)

        # record start time.
        start = time.time()

        # 1. precompute gram matrix.
        if self._gram_matrix_unnorm is None:
            gram_matrix, run_time = self._graph_kernel.compute(self._dataset.graphs, **self._kernel_options)
            self._gram_matrix_unnorm = self._graph_kernel.gram_matrix_unnorm
            end_precompute_gm = time.time()
            self._runtime_precompute_gm = end_precompute_gm - start
        else:
            if self._runtime_precompute_gm is None:
                raise Exception('Parameter "runtime_precompute_gm" must be given when using pre-computed Gram matrix.')
            self._graph_kernel.gram_matrix_unnorm = self._gram_matrix_unnorm
            if self._kernel_options['normalize']:
                self._graph_kernel.gram_matrix = self._graph_kernel.normalize_gm(np.copy(self._gram_matrix_unnorm))
            else:
                self._graph_kernel.gram_matrix = np.copy(self._gram_matrix_unnorm)
            end_precompute_gm = time.time()
            # Shift the start time backwards so the total runtime includes the
            # externally measured Gram-matrix precomputation.
            start -= self._runtime_precompute_gm

        # 2. compute k nearest neighbors of phi in D_N.
        if self._verbose >= 2:
            print('\nstart computing k nearest neighbors of phi in D_N...\n')
        D_N = self._dataset.graphs
        if self._alphas is None:
            # Default: phi is the unweighted mean of the dataset in kernel space.
            self._alphas = [1 / len(D_N)] * len(D_N)
        k_dis_list = []  # distance between g_star and each graph.
        # term3 = alpha^T * K * alpha; constant shared by all distance computations.
        term3 = 0
        for i1, a1 in enumerate(self._alphas):
            for i2, a2 in enumerate(self._alphas):
                term3 += a1 * a2 * self._graph_kernel.gram_matrix[i1, i2]
        for idx in range(len(D_N)):
            k_dis_list.append(compute_k_dis(idx, range(0, len(D_N)), self._alphas, self._graph_kernel.gram_matrix, term3=term3, withterm3=True))

        # sort.
        sort_idx = np.argsort(k_dis_list)
        dis_gs = [k_dis_list[idis] for idis in sort_idx[0:self._k]]  # the k shortest distances.
        nb_best = len(np.argwhere(dis_gs == dis_gs[0]).flatten().tolist())
        g0hat_list = [D_N[idx].copy() for idx in sort_idx[0:nb_best]]  # the nearest neighbors of phi in D_N
        self._best_from_dataset = g0hat_list[0]  # get the first best graph if there are multiple.
        self._k_dis_dataset = dis_gs[0]

        if self._k_dis_dataset == 0:  # get the exact pre-image.
            end_generate_preimage = time.time()
            self._runtime_generate_preimage = end_generate_preimage - end_precompute_gm
            self._runtime_total = end_generate_preimage - start
            self._preimage = self._best_from_dataset.copy()
            self._k_dis_preimage = self._k_dis_dataset
            if self._verbose:
                print()
                print('=============================================================================')
                print('The exact pre-image is found from the input dataset.')
                print('-----------------------------------------------------------------------------')
                print('Distance in kernel space for the best graph from dataset and for preimage:', self._k_dis_dataset)
                print('Time to pre-compute Gram matrix:', self._runtime_precompute_gm)
                print('Time to generate pre-images:', self._runtime_generate_preimage)
                print('Total time:', self._runtime_total)
                print('=============================================================================')
                print()
            return

        dhat = dis_gs[0]  # the nearest distance
        Gk = [D_N[ig].copy() for ig in sort_idx[0:self._k]]  # the k nearest neighbors
        Gs_nearest = [nx.convert_node_labels_to_integers(g) for g in Gk]  # [g.copy() for g in Gk]

        # 3. start iterations.
        if self._verbose >= 2:
            print('starting iterations...')
        gihat_list = []  # current best generated graph (at most one element).
        dihat_list = []  # its distance to phi (at most one element).
        r = 0  # consecutive non-improving rounds.
        dis_of_each_itr = [dhat]
        if self._parallel:
            # Trials are parallelized at this level; disable kernel-level parallelism.
            self._kernel_options['parallel'] = None
        self._itrs = 0
        self._num_updates = 0
        timer = Timer(self._time_limit_in_sec)
        while not self._termination_criterion_met(timer, self._itrs, r):
            print('\n- r =', r)
            found = False
            dis_bests = dis_gs + dihat_list
            # compute numbers of edges to be inserted/deleted.
            # @todo what if the log is negative? how to choose alpha (scalar)?
            fdgs_list = np.array(dis_bests)
            if np.min(fdgs_list) < 1:  # in case the log is negative.
                fdgs_list /= np.min(fdgs_list)
            fdgs_list = [int(item) for item in np.ceil(np.log(fdgs_list))]
            if np.min(fdgs_list) < 1:  # in case the log is smaller than 1.
                fdgs_list = np.array(fdgs_list) + 1
            # Expand the number of modifications to increase the possibility of
            # finding a better graph: grow fdgs_max while the number of
            # distinct modification sets (nb_modif, a binomial coefficient
            # built incrementally) stays below the number of trials self._l.
            nb_vpairs_list = [nx.number_of_nodes(g) * (nx.number_of_nodes(g) - 1) for g in (Gs_nearest + gihat_list)]
            nb_vpairs_min = np.min(nb_vpairs_list)
            idx_fdgs_max = np.argmax(fdgs_list)
            fdgs_max_old = fdgs_list[idx_fdgs_max]
            fdgs_max = fdgs_max_old
            nb_modif = 1
            for idx, nb in enumerate(range(nb_vpairs_min, nb_vpairs_min - fdgs_max, -1)):
                nb_modif *= nb / (fdgs_max - idx)
            while fdgs_max < nb_vpairs_min and nb_modif < self._l:
                fdgs_max += 1
                nb_modif *= (nb_vpairs_min - fdgs_max + 1) / fdgs_max
            nb_increase = int(fdgs_max - fdgs_max_old)
            if nb_increase > 0:
                # NOTE(review): this bumps every entry by 1 (numpy broadcast),
                # regardless of how large nb_increase is — presumably intended.
                fdgs_list += 1

            for ig, gs in enumerate(Gs_nearest + gihat_list):
                if self._verbose >= 2:
                    print('-- computing', ig + 1, 'graphs out of', len(Gs_nearest) + len(gihat_list))
                gnew, dhat, found = self._generate_l_graphs(gs, fdgs_list[ig], dhat, ig, found, term3)

            if found:
                r = 0
                gihat_list = [gnew]
                dihat_list = [dhat]
            else:
                r += 1

            dis_of_each_itr.append(dhat)
            self._itrs += 1
            if self._verbose >= 2:
                print('Total number of iterations is', self._itrs, '.')
                print('The preimage is updated', self._num_updates, 'times.')
                print('The shortest distances for previous iterations are', dis_of_each_itr, '.')

        # get results and print.
        end_generate_preimage = time.time()
        self._runtime_generate_preimage = end_generate_preimage - end_precompute_gm
        self._runtime_total = end_generate_preimage - start
        self._preimage = (g0hat_list[0] if len(gihat_list) == 0 else gihat_list[0])
        self._k_dis_preimage = dhat
        if self._verbose:
            print()
            print('=============================================================================')
            print('Finished generation of preimages.')
            print('-----------------------------------------------------------------------------')
            print('Distance in kernel space for the best graph from dataset:', self._k_dis_dataset)
            print('Distance in kernel space for the preimage:', self._k_dis_preimage)
            print('Total number of iterations for optimizing:', self._itrs)
            print('Total number of updating preimage:', self._num_updates)
            print('Time to pre-compute Gram matrix:', self._runtime_precompute_gm)
            print('Time to generate pre-images:', self._runtime_generate_preimage)
            print('Total time:', self._runtime_total)
            print('=============================================================================')
            print()
  201. def _generate_l_graphs(self, g_init, fdgs, dhat, ig, found, term3):
  202. if self._parallel:
  203. gnew, dhat, found = self._generate_l_graphs_parallel(g_init, fdgs, dhat, ig, found, term3)
  204. else:
  205. gnew, dhat, found = self._generate_l_graphs_series(g_init, fdgs, dhat, ig, found, term3)
  206. return gnew, dhat, found
  207. def _generate_l_graphs_series(self, g_init, fdgs, dhat, ig, found, term3):
  208. gnew = None
  209. updated = False
  210. for trial in range(0, self._l):
  211. if self._verbose >= 2:
  212. print('---', trial + 1, 'trial out of', self._l)
  213. gtemp, dnew = self._do_trial(g_init, fdgs, term3, trial)
  214. # get the better graph preimage.
  215. if dnew <= dhat: # @todo: the new distance is smaller or also equal?
  216. if dhat - dnew > 1e-6:
  217. if self._verbose >= 2:
  218. print('trial =', str(trial))
  219. print('\nI am smaller!')
  220. print('index (as in D_k U {gihat} =', str(ig))
  221. print('distance:', dhat, '->', dnew)
  222. updated = True
  223. else:
  224. if self._verbose >= 2:
  225. print('I am equal!')
  226. dhat = dnew
  227. gnew = gtemp.copy()
  228. found = True # found better or equally good graph.
  229. if updated:
  230. self._num_updates += 1
  231. return gnew, dhat, found
  232. def _generate_l_graphs_parallel(self, g_init, fdgs, dhat, ig, found, term3):
  233. gnew = None
  234. len_itr = self._l
  235. gnew_list = [None] * len_itr
  236. dnew_list = [None] * len_itr
  237. itr = range(0, len_itr)
  238. n_jobs = multiprocessing.cpu_count()
  239. if len_itr < 100 * n_jobs:
  240. chunksize = int(len_itr / n_jobs) + 1
  241. else:
  242. chunksize = 100
  243. do_fun = partial(self._generate_graph_parallel, g_init, fdgs, term3)
  244. pool = Pool(processes=n_jobs)
  245. if self._verbose >= 2:
  246. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  247. desc='Generating l graphs', file=sys.stdout)
  248. else:
  249. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  250. for idx, gnew, dnew in iterator:
  251. gnew_list[idx] = gnew
  252. dnew_list[idx] = dnew
  253. pool.close()
  254. pool.join()
  255. # check if get the better graph preimage.
  256. idx_min = np.argmin(dnew_list)
  257. dnew = dnew_list[idx_min]
  258. if dnew <= dhat: # @todo: the new distance is smaller or also equal?
  259. if dhat - dnew > 1e-6: # @todo: use a proportion and watch out for 0.
  260. if self._verbose >= 2:
  261. print('I am smaller!')
  262. print('index (as in D_k U {gihat}) =', str(ig))
  263. print('distance:', dhat, '->', dnew, '\n')
  264. self._num_updates += 1
  265. else:
  266. if self._verbose >= 2:
  267. print('I am equal!')
  268. dhat = dnew
  269. gnew = gnew_list[idx_min]
  270. found = True # found better graph.
  271. return gnew, dhat, found
  272. def _generate_graph_parallel(self, g_init, fdgs, term3, itr):
  273. trial = itr
  274. gtemp, dnew = self._do_trial(g_init, fdgs, term3, trial)
  275. return trial, gtemp, dnew
    def _do_trial(self, g_init, fdgs, term3, trial):
        """Randomly toggle ``fdgs`` edges of ``g_init`` and measure the result.

        A fresh RandomState seeded from (trial, wall clock) picks ordered node
        pairs; each selected pair has its edge inserted if absent, removed if
        present. The kernel between the modified graph and every dataset graph
        is then computed and a bordered Gram matrix is assembled to evaluate
        the distance to phi.

        Returns ``(gtemp, dnew)``: the modified graph and its kernel-space
        distance to phi.
        """
        # add and delete edges.
        gtemp = g_init.copy()
        seed = (trial + int(time.time())) % (2 ** 32 - 1)
        rdm_state = np.random.RandomState(seed=seed)
        # which edges to change.
        # @todo: should we use just half of the adjacency matrix for undirected graphs?
        # nb_vpairs counts ordered node pairs excluding self-pairs: n * (n - 1).
        nb_vpairs = nx.number_of_nodes(g_init) * (nx.number_of_nodes(g_init) - 1)
        # @todo: what if fdgs is bigger than nb_vpairs?
        idx_change = rdm_state.randint(0, high=nb_vpairs, size=(fdgs if
            fdgs < nb_vpairs else nb_vpairs))
        # print(idx_change)
        for item in idx_change:
            # Decode the flat pair index into (node1, node2).
            node1 = int(item / (nx.number_of_nodes(g_init) - 1))
            node2 = (item - node1 * (nx.number_of_nodes(g_init) - 1))
            if node2 >= node1:  # skip the self pair.
                node2 += 1
            # @todo: is the randomness correct?
            # Toggle the edge.
            if not gtemp.has_edge(node1, node2):
                gtemp.add_edge(node1, node2)
            else:
                gtemp.remove_edge(node1, node2)

        # compute new distances.
        kernels_to_gtmp, _ = self._graph_kernel.compute(gtemp, self._dataset.graphs, **self._kernel_options)
        kernel_gtmp, _ = self._graph_kernel.compute(gtemp, gtemp, **self._kernel_options)
        if self._kernel_options['normalize']:
            # Cosine-normalize against the unnormalized diagonal; the self-kernel becomes 1.
            kernels_to_gtmp = [kernels_to_gtmp[i] / np.sqrt(self._gram_matrix_unnorm[i, i] * kernel_gtmp) for i in range(len(kernels_to_gtmp))]  # normalize
            kernel_gtmp = 1
        # @todo: not correct kernel value
        # Border the dataset Gram matrix with gtemp's kernel values: gtemp
        # becomes row/column 0 of the extended matrix.
        gram_with_gtmp = np.concatenate((np.array([kernels_to_gtmp]), np.copy(self._graph_kernel.gram_matrix)), axis=0)
        gram_with_gtmp = np.concatenate((np.array([[kernel_gtmp] + kernels_to_gtmp]).T, gram_with_gtmp), axis=1)
        dnew = compute_k_dis(0, range(1, 1 + len(self._dataset.graphs)), self._alphas, gram_with_gtmp, term3=term3, withterm3=True)
        return gtemp, dnew
  309. def get_results(self):
  310. results = {}
  311. results['runtime_precompute_gm'] = self._runtime_precompute_gm
  312. results['runtime_generate_preimage'] = self._runtime_generate_preimage
  313. results['runtime_total'] = self._runtime_total
  314. results['k_dis_dataset'] = self._k_dis_dataset
  315. results['k_dis_preimage'] = self._k_dis_preimage
  316. results['itrs'] = self._itrs
  317. results['num_updates'] = self._num_updates
  318. return results
  319. def _termination_criterion_met(self, timer, itr, r):
  320. if timer.expired() or (itr >= self._max_itrs if self._max_itrs >= 0 else False):
  321. # if self._state == AlgorithmState.TERMINATED:
  322. # self._state = AlgorithmState.INITIALIZED
  323. return True
  324. return (r >= self._r_max if self._r_max >= 0 else False)
  325. # return converged or (itrs_without_update > self._max_itrs_without_update if self._max_itrs_without_update >= 0 else False)
    @property
    def preimage(self):
        """The generated pre-image graph (None until ``run()`` completes)."""
        return self._preimage
    @property
    def best_from_dataset(self):
        """The dataset graph closest to phi in kernel space (set by ``run()``)."""
        return self._best_from_dataset
    @property
    def gram_matrix_unnorm(self):
        """The unnormalized Gram matrix of the dataset (computed or supplied)."""
        return self._gram_matrix_unnorm

    @gram_matrix_unnorm.setter
    def gram_matrix_unnorm(self, value):
        # Allows reusing a precomputed Gram matrix across runs; see run().
        self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances and graph pre-image problem.