You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  class GraphKernel(object):
      """Base class for graph kernels.

      Concrete kernels subclass this and override the hook methods
      ``_compute_gm_series``, ``_compute_gm_imap_unordered``,
      ``_compute_kernel_list_series``, ``_compute_kernel_list_imap_unordered``
      and ``_compute_single_kernel_series``. This base class handles argument
      dispatch, copying of input graphs, timing, optional Gram-matrix
      normalization and derivation of a kernel-induced distance matrix.
      """

      def __init__(self):
          self._graphs = None  # copies of the graphs of the last Gram-matrix computation
          self._parallel = ''
          self._n_jobs = 0
          self._verbose = None
          self._normalize = True
          self._run_time = 0  # wall-clock seconds of the last computation
          self._gram_matrix = None  # normalized if self._normalize was set
          self._gram_matrix_unnorm = None  # raw Gram matrix before normalization

      def compute(self, *graphs, **kwargs):
          """Compute kernel values for the given graphs.

          Call forms:

          * ``compute(graph_list)``: the Gram matrix of all graphs in the list.
            The (optionally normalized) matrix and the raw matrix are also
            stored on the instance.
          * ``compute(g1, g2)``: the kernel between two graphs.
          * ``compute(g1, graph_list)`` or ``compute(graph_list, g1)``: the
            kernels between ``g1`` and every graph in the list.

          Every input graph is copied before being handed to the kernel hooks,
          so the hooks never receive the caller's objects.

          Keyword arguments:
              parallel: ``'imap_unordered'`` (default) or ``None`` for serial.
              n_jobs: stored on the instance; defaults to the CPU count. This
                  base class does not use it itself.
              normalize: normalize the Gram matrix (default True). Only
                  applies to the ``compute(graph_list)`` form.
              verbose: print timing information when truthy (default 2).

          Returns:
              ``(result, run_time)``, where ``result`` is the Gram matrix, the
              kernel list or the single kernel value, and ``run_time`` is the
              elapsed wall-clock time in seconds.

          Raises:
              Exception: if the arguments cannot be interpreted as graphs, the
                  graph list is empty, or the parallel mode is unknown.
              NotImplementedError: if the subclass lacks the required hook.
          """
          self._parallel = kwargs.get('parallel', 'imap_unordered')
          self._n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
          self._normalize = kwargs.get('normalize', True)
          self._verbose = kwargs.get('verbose', 2)

          if len(graphs) == 1:
              if not isinstance(graphs[0], list):
                  raise Exception('Cannot detect graphs.')
              if len(graphs[0]) == 0:
                  raise Exception('The graph list given is empty. No computation was performed.')
              self._graphs = [g.copy() for g in graphs[0]]
              self._gram_matrix = self.__compute_gram_matrix()
              # Keep the raw matrix: normalize_gm() modifies its argument in place.
              self._gram_matrix_unnorm = np.copy(self._gram_matrix)
              if self._normalize:
                  self._gram_matrix = self.normalize_gm(self._gram_matrix)
              return self._gram_matrix, self._run_time

          if len(graphs) == 2:
              first, second = graphs
              if self.is_graph(first) and self.is_graph(second):
                  # Copy like the other call forms so hooks cannot mutate the caller's graphs.
                  kernel = self.__compute_single_kernel(first.copy(), second.copy())
                  return kernel, self._run_time
              if self.is_graph(first) and isinstance(second, list):
                  g1, g_list = first, second
              elif isinstance(first, list) and self.is_graph(second):
                  g1, g_list = second, first
              else:
                  raise Exception('Cannot detect graphs.')
              kernel_list = self.__compute_kernel_list(g1.copy(), [g.copy() for g in g_list])
              return kernel_list, self._run_time

          if len(graphs) == 0 and self._graphs is None:
              raise Exception('Please add graphs before computing.')
          raise Exception('Cannot detect graphs.')

      def normalize_gm(self, gram_matrix):
          """Normalize a Gram matrix in place so that self-similarities become 1.

          Each entry becomes ``K[i, j] / sqrt(K[i, i] * K[j, j])``. Only the upper
          triangle (diagonal included) is computed. It is then mirrored into the
          lower triangle, so the result is symmetric. A zero diagonal entry
          yields inf/nan entries (numpy emits a RuntimeWarning).

          Args:
              gram_matrix: square 2-D numpy array; modified in place.

          Returns:
              The same array object, normalized.
          """
          diag = gram_matrix.diagonal().copy()
          rows, cols = np.triu_indices(len(gram_matrix))
          gram_matrix[rows, cols] = gram_matrix[rows, cols] / np.sqrt(diag[rows] * diag[cols])
          gram_matrix[cols, rows] = gram_matrix[rows, cols]
          return gram_matrix

      def compute_distance_matrix(self):
          """Derive a distance matrix from the stored Gram matrix.

          Uses the kernel-induced distance
          ``d(i, j) = sqrt(K[i, i] + K[j, j] - 2 * K[i, j])`` on
          ``self._gram_matrix``. Squared distances that are negative by less
          than ``1e-10`` are treated as floating-point round-off and clipped to 0.

          Returns:
              ``(dis_mat, dis_max, dis_min, dis_mean)``: the symmetric distance
              matrix, its maximum, its smallest non-zero entry (0.0 when every
              entry is zero, e.g. for a single graph), and its mean.

          Raises:
              Exception: if the Gram matrix has not been computed yet.
              ValueError: if a squared distance is negative beyond the tolerance.
          """
          # Check first: the matrix size must not be read from unset state.
          if self._gram_matrix is None:
              raise Exception('Please compute the Gram matrix before computing distance matrix.')
          gram = self._gram_matrix
          n = len(gram)
          diag = gram.diagonal()
          rows, cols = np.triu_indices(n)
          sq_dis = diag[rows] + diag[cols] - 2 * gram[rows, cols]
          if np.any(sq_dis <= -1e-10):
              raise ValueError('The distance is negative.')
          sq_dis[sq_dis < 0] = 0  # tiny negatives are round-off noise
          dis_mat = np.empty((n, n))
          dis_mat[rows, cols] = np.sqrt(sq_dis)
          dis_mat[cols, rows] = dis_mat[rows, cols]

          nonzero = dis_mat[dis_mat != 0]
          dis_max = np.max(dis_mat)
          # min() of an empty array raises, so fall back to 0.0 when all distances are 0.
          dis_min = np.min(nonzero) if nonzero.size else 0.0
          dis_mean = np.mean(dis_mat)
          return dis_mat, dis_max, dis_min, dis_mean

      def __compute_gram_matrix(self):
          """Run the Gram-matrix hook selected by ``self._parallel``, timing it."""
          start_time = time.time()
          if self._parallel == 'imap_unordered':
              gram_matrix = self._compute_gm_imap_unordered()
          elif self._parallel is None:
              gram_matrix = self._compute_gm_series()
          else:
              raise Exception('Parallel mode is not set correctly.')
          self._run_time = time.time() - start_time
          if self._verbose:
              print('Gram matrix of size %d built in %s seconds.'
                    % (len(self._graphs), self._run_time))
          return gram_matrix

      def _compute_gm_series(self):
          """Subclass hook: return the Gram matrix of ``self._graphs`` (serial).

          The result is passed to ``np.copy`` and ``normalize_gm``, so it should
          be a square 2-D numpy array.
          """
          raise NotImplementedError('%s must implement _compute_gm_series().' % type(self).__name__)

      def _compute_gm_imap_unordered(self):
          """Subclass hook: return the Gram matrix of ``self._graphs`` (parallel)."""
          raise NotImplementedError('%s must implement _compute_gm_imap_unordered().' % type(self).__name__)

      def __compute_kernel_list(self, g1, g_list):
          """Run the kernel-list hook selected by ``self._parallel``, timing it."""
          start_time = time.time()
          if self._parallel == 'imap_unordered':
              kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
          elif self._parallel is None:
              kernel_list = self._compute_kernel_list_series(g1, g_list)
          else:
              raise Exception('Parallel mode is not set correctly.')
          self._run_time = time.time() - start_time
          if self._verbose:
              print('Graph kernel between a graph and a list of %d graphs built in %s seconds.'
                    % (len(g_list), self._run_time))
          return kernel_list

      def _compute_kernel_list_series(self, g1, g_list):
          """Subclass hook: return the kernels between ``g1`` and each graph of ``g_list`` (serial)."""
          raise NotImplementedError('%s must implement _compute_kernel_list_series().' % type(self).__name__)

      def _compute_kernel_list_imap_unordered(self, g1, g_list):
          """Subclass hook: return the kernels between ``g1`` and each graph of ``g_list`` (parallel)."""
          raise NotImplementedError('%s must implement _compute_kernel_list_imap_unordered().' % type(self).__name__)

      def __compute_single_kernel(self, g1, g2):
          """Run the single-kernel hook, timing it."""
          start_time = time.time()
          kernel = self._compute_single_kernel_series(g1, g2)
          self._run_time = time.time() - start_time
          if self._verbose:
              print('Graph kernel between two graphs built in %s seconds.' % (self._run_time))
          return kernel

      def _compute_single_kernel_series(self, g1, g2):
          """Subclass hook: return the kernel value between ``g1`` and ``g2``."""
          raise NotImplementedError('%s must implement _compute_single_kernel_series().' % type(self).__name__)

      def is_graph(self, graph):
          """Return True if ``graph`` is a networkx graph of any kind.

          ``nx.DiGraph``, ``nx.MultiGraph`` and ``nx.MultiDiGraph`` all subclass
          ``nx.Graph``, so one isinstance check covers every graph type.
          """
          return isinstance(graph, nx.Graph)

      @property
      def graphs(self):
          return self._graphs

      @property
      def parallel(self):
          return self._parallel

      @property
      def n_jobs(self):
          return self._n_jobs

      @property
      def verbose(self):
          return self._verbose

      @property
      def normalize(self):
          return self._normalize

      @property
      def run_time(self):
          return self._run_time

      @property
      def gram_matrix(self):
          return self._gram_matrix

      @property
      def gram_matrix_unnorm(self):
          return self._gram_matrix_unnorm

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.