You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 5.9 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  11. from gklearn.utils import normalize_gram_matrix
  12. class GraphKernel(object):
  13. def __init__(self):
  14. self._graphs = None
  15. self._parallel = ''
  16. self._n_jobs = 0
  17. self._verbose = None
  18. self._normalize = True
  19. self._run_time = 0
  20. self._gram_matrix = None
  21. self._gram_matrix_unnorm = None
  22. def compute(self, *graphs, **kwargs):
  23. self._parallel = kwargs.get('parallel', 'imap_unordered')
  24. self._n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
  25. self._normalize = kwargs.get('normalize', True)
  26. self._verbose = kwargs.get('verbose', 2)
  27. if len(graphs) == 1:
  28. if not isinstance(graphs[0], list):
  29. raise Exception('Cannot detect graphs.')
  30. elif len(graphs[0]) == 0:
  31. raise Exception('The graph list given is empty. No computation was performed.')
  32. else:
  33. self._graphs = [g.copy() for g in graphs[0]]
  34. self._gram_matrix = self._compute_gram_matrix()
  35. self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  36. if self._normalize:
  37. self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
  38. return self._gram_matrix, self._run_time
  39. elif len(graphs) == 2:
  40. if self.is_graph(graphs[0]) and self.is_graph(graphs[1]):
  41. kernel = self._compute_single_kernel(graphs[0].copy(), graphs[1].copy())
  42. return kernel, self._run_time
  43. elif self.is_graph(graphs[0]) and isinstance(graphs[1], list):
  44. g1 = graphs[0].copy()
  45. g_list = [g.copy() for g in graphs[1]]
  46. kernel_list = self._compute_kernel_list(g1, g_list)
  47. return kernel_list, self._run_time
  48. elif isinstance(graphs[0], list) and self.is_graph(graphs[1]):
  49. g1 = graphs[1].copy()
  50. g_list = [g.copy() for g in graphs[0]]
  51. kernel_list = self._compute_kernel_list(g1, g_list)
  52. return kernel_list, self._run_time
  53. else:
  54. raise Exception('Cannot detect graphs.')
  55. elif len(graphs) == 0 and self._graphs is None:
  56. raise Exception('Please add graphs before computing.')
  57. else:
  58. raise Exception('Cannot detect graphs.')
  59. def normalize_gm(self, gram_matrix):
  60. import warnings
  61. warnings.warn('gklearn.kernels.graph_kernel.normalize_gm will be deprecated, use gklearn.utils.normalize_gram_matrix instead', DeprecationWarning)
  62. diag = gram_matrix.diagonal().copy()
  63. for i in range(len(gram_matrix)):
  64. for j in range(i, len(gram_matrix)):
  65. gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j])
  66. gram_matrix[j][i] = gram_matrix[i][j]
  67. return gram_matrix
  68. def compute_distance_matrix(self):
  69. if self._gram_matrix is None:
  70. raise Exception('Please compute the Gram matrix before computing distance matrix.')
  71. dis_mat = np.empty((len(self._gram_matrix), len(self._gram_matrix)))
  72. for i in range(len(self._gram_matrix)):
  73. for j in range(i, len(self._gram_matrix)):
  74. dis = self._gram_matrix[i, i] + self._gram_matrix[j, j] - 2 * self._gram_matrix[i, j]
  75. if dis < 0:
  76. if dis > -1e-10:
  77. dis = 0
  78. else:
  79. raise ValueError('The distance is negative.')
  80. dis_mat[i, j] = np.sqrt(dis)
  81. dis_mat[j, i] = dis_mat[i, j]
  82. dis_max = np.max(np.max(dis_mat))
  83. dis_min = np.min(np.min(dis_mat[dis_mat != 0]))
  84. dis_mean = np.mean(np.mean(dis_mat))
  85. return dis_mat, dis_max, dis_min, dis_mean
  86. def _compute_gram_matrix(self):
  87. start_time = time.time()
  88. if self._parallel == 'imap_unordered':
  89. gram_matrix = self._compute_gm_imap_unordered()
  90. elif self._parallel is None:
  91. gram_matrix = self._compute_gm_series()
  92. else:
  93. raise Exception('Parallel mode is not set correctly.')
  94. self._run_time = time.time() - start_time
  95. if self._verbose:
  96. print('Gram matrix of size %d built in %s seconds.'
  97. % (len(self._graphs), self._run_time))
  98. return gram_matrix
  99. def _compute_gm_series(self):
  100. pass
  101. def _compute_gm_imap_unordered(self):
  102. pass
  103. def _compute_kernel_list(self, g1, g_list):
  104. start_time = time.time()
  105. if self._parallel == 'imap_unordered':
  106. kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
  107. elif self._parallel is None:
  108. kernel_list = self._compute_kernel_list_series(g1, g_list)
  109. else:
  110. raise Exception('Parallel mode is not set correctly.')
  111. self._run_time = time.time() - start_time
  112. if self._verbose:
  113. print('Graph kernel bewteen a graph and a list of %d graphs built in %s seconds.'
  114. % (len(g_list), self._run_time))
  115. return kernel_list
  116. def _compute_kernel_list_series(self, g1, g_list):
  117. pass
  118. def _compute_kernel_list_imap_unordered(self, g1, g_list):
  119. pass
  120. def _compute_single_kernel(self, g1, g2):
  121. start_time = time.time()
  122. kernel = self._compute_single_kernel_series(g1, g2)
  123. self._run_time = time.time() - start_time
  124. if self._verbose:
  125. print('Graph kernel bewteen two graphs built in %s seconds.' % (self._run_time))
  126. return kernel
  127. def _compute_single_kernel_series(self, g1, g2):
  128. pass
  129. def is_graph(self, graph):
  130. if isinstance(graph, nx.Graph):
  131. return True
  132. if isinstance(graph, nx.DiGraph):
  133. return True
  134. if isinstance(graph, nx.MultiGraph):
  135. return True
  136. if isinstance(graph, nx.MultiDiGraph):
  137. return True
  138. return False
  139. @property
  140. def graphs(self):
  141. return self._graphs
  142. @property
  143. def parallel(self):
  144. return self._parallel
  145. @property
  146. def n_jobs(self):
  147. return self._n_jobs
  148. @property
  149. def verbose(self):
  150. return self._verbose
  151. @property
  152. def normalize(self):
  153. return self._normalize
  154. @property
  155. def run_time(self):
  156. return self._run_time
  157. @property
  158. def gram_matrix(self):
  159. return self._gram_matrix
  160. @gram_matrix.setter
  161. def gram_matrix(self, value):
  162. self._gram_matrix = value
  163. @property
  164. def gram_matrix_unnorm(self):
  165. return self._gram_matrix_unnorm
  166. @gram_matrix_unnorm.setter
  167. def gram_matrix_unnorm(self, value):
  168. self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances and the graph pre-image problem.