You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 15 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  11. # from abc import ABC, abstractmethod
  12. from sklearn.base import BaseEstimator # , TransformerMixin
  13. from sklearn.utils.validation import check_is_fitted # check_X_y, check_array,
  14. from sklearn.exceptions import NotFittedError
  15. from gklearn.utils import normalize_gram_matrix
  16. class GraphKernel(BaseEstimator): #, ABC):
  17. """The basic graph kernel class.
  18. Attributes
  19. ----------
  20. _graphs : list
  21. Stores the input graphs on fit input data.
  22. Default format of the list objects is `NetworkX` graphs.
  23. **We don't guarantee that the input graphs remain unchanged during the
  24. computation.**
  25. References
  26. ----------
  27. https://ysig.github.io/GraKeL/0.1a8/_modules/grakel/kernels/kernel.html#Kernel.
  28. """
  29. def __init__(self, parallel=None, n_jobs=None, chunksize=None, normalize=True, verbose=2):
  30. """`__init__` for `GraphKernel` object."""
  31. # @todo: the default settings of the parameters are different from those in the self.compute method.
  32. # self._graphs = None
  33. self.parallel = parallel
  34. self.n_jobs = n_jobs
  35. self.chunksize = chunksize
  36. self.normalize = normalize
  37. self.verbose = verbose
  38. # self._run_time = 0
  39. # self._gram_matrix = None
  40. # self._gram_matrix_unnorm = None
  41. ##########################################################################
  42. # The following is the 1st paradigm to compute kernel matrix, which is
  43. # compatible with `scikit-learn`.
  44. # -------------------------------------------------------------------
  45. # Special thanks to the "GraKeL" library for providing an excellent template!
  46. ##########################################################################
  47. def fit(self, X, y=None):
  48. """Fit a graph dataset for a transformer.
  49. Parameters
  50. ----------
  51. X : iterable
  52. DESCRIPTION.
  53. y : None, optional
  54. There is no need of a target in a transformer, yet the `scikit-learn`
  55. pipeline API requires this parameter.
  56. Returns
  57. -------
  58. object
  59. Returns self.
  60. """
  61. # self._is_tranformed = False
  62. # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used;
  63. self.clear_attributes()
  64. # X = check_array(X, accept_sparse=True)
  65. # Validate parameters for the transformer.
  66. self.validate_parameters()
  67. # Validate the input.
  68. self._graphs = self.validate_input(X)
  69. # self._X = X
  70. # self._kernel = self._get_kernel_instance()
  71. # Return the transformer.
  72. return self
  73. def transform(self, X):
  74. """Compute the graph kernel matrix between given and fitted data.
  75. Parameters
  76. ----------
  77. X : TYPE
  78. DESCRIPTION.
  79. Raises
  80. ------
  81. ValueError
  82. DESCRIPTION.
  83. Returns
  84. -------
  85. None.
  86. """
  87. # Check if method "fit" had been called.
  88. check_is_fitted(self, '_graphs')
  89. # Validate the input.
  90. Y = self.validate_input(X)
  91. # Transform: compute the graph kernel matrix.
  92. kernel_matrix = self.compute_kernel_matrix(Y)
  93. self._Y = Y
  94. # Self transform must appear before the diagonal call on normilization.
  95. self._is_transformed = True
  96. if self.normalize:
  97. X_diag, Y_diag = self.diagonals()
  98. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  99. try:
  100. kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
  101. except:
  102. raise
  103. finally:
  104. np.seterr(**old_settings)
  105. return kernel_matrix
  106. def fit_transform(self, X):
  107. """Fit and transform: compute Gram matrix on the same data.
  108. Parameters
  109. ----------
  110. X : list of graphs
  111. Input graphs.
  112. Returns
  113. -------
  114. gram_matrix : numpy array, shape = [len(X), len(X)]
  115. The Gram matrix of X.
  116. """
  117. self.fit(X)
  118. # Transform: compute Gram matrix.
  119. gram_matrix = self.compute_kernel_matrix()
  120. # Normalize.
  121. if self.normalize:
  122. self._X_diag = np.diagonal(gram_matrix).copy()
  123. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  124. try:
  125. gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
  126. except:
  127. raise
  128. finally:
  129. np.seterr(**old_settings)
  130. return gram_matrix
  131. def get_params(self):
  132. pass
  133. def set_params(self):
  134. pass
  135. def clear_attributes(self):
  136. if hasattr(self, '_X_diag'):
  137. delattr(self, '_X_diag')
  138. if hasattr(self, '_graphs'):
  139. delattr(self, '_graphs')
  140. if hasattr(self, '_Y'):
  141. delattr(self, '_Y')
  142. if hasattr(self, '_run_time'):
  143. delattr(self, '_run_time')
  144. def validate_parameters(self):
  145. """Validate all parameters for the transformer.
  146. Returns
  147. -------
  148. None.
  149. """
  150. if self.parallel is not None and self.parallel != 'imap_unordered':
  151. raise ValueError('Parallel mode is not set correctly.')
  152. if self.parallel == 'imap_unordered' and self.n_jobs is None:
  153. self.n_jobs = multiprocessing.cpu_count()
  154. def validate_input(self, X):
  155. """Validate the given input and raise errors if it is invalid.
  156. Parameters
  157. ----------
  158. X : list
  159. The input to check. Should be a list of graph.
  160. Raises
  161. ------
  162. ValueError
  163. Raise if the input is not correct.
  164. Returns
  165. -------
  166. X : list
  167. The input. A list of graph.
  168. """
  169. if X is None:
  170. raise ValueError('Please add graphs before computing.')
  171. elif not isinstance(X, list):
  172. raise ValueError('Cannot detect graphs.')
  173. elif len(X) == 0:
  174. raise ValueError('The graph list given is empty. No computation will be performed.')
  175. return X
  176. def compute_kernel_matrix(self, Y=None):
  177. """Compute the kernel matrix between a given target graphs (Y) and
  178. the fitted graphs (X / self._graphs) or the Gram matrix for the fitted
  179. graphs (X / self._graphs).
  180. Parameters
  181. ----------
  182. Y : list of graphs, optional
  183. The target graphs. The default is None. If None kernel is computed
  184. between X and itself.
  185. Returns
  186. -------
  187. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  188. The computed kernel matrix.
  189. """
  190. if Y is None:
  191. # Compute Gram matrix for self._graphs (X).
  192. kernel_matrix = self._compute_gram_matrix()
  193. # self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  194. else:
  195. # Compute kernel matrix between Y and self._graphs (X).
  196. start_time = time.time()
  197. if self.parallel == 'imap_unordered':
  198. kernel_matrix = self._compute_kernel_matrix_imap_unordered(Y)
  199. elif self.parallel is None:
  200. kernel_matrix = self._compute_kernel_matrix_series(Y)
  201. self._run_time = time.time() - start_time
  202. if self.verbose:
  203. print('Kernel matrix of size (%d, %d) built in %s seconds.'
  204. % (len(Y), len(self._graphs), self._run_time))
  205. return kernel_matrix
  206. def _compute_kernel_matrix_series(self, Y):
  207. """Compute the kernel matrix between a given target graphs (Y) and
  208. the fitted graphs (X / self._graphs) without parallelization.
  209. Parameters
  210. ----------
  211. Y : list of graphs, optional
  212. The target graphs.
  213. Returns
  214. -------
  215. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  216. The computed kernel matrix.
  217. """
  218. kernel_matrix = np.zeros((len(Y), len(self._graphs)))
  219. for i_y, g_y in enumerate(Y):
  220. for i_x, g_x in enumerate(self._graphs):
  221. kernel_matrix[i_y, i_x] = self.pairwise_kernel(g_y, g_x)
  222. return kernel_matrix
  223. def _compute_kernel_matrix_imap_unordered(self, Y):
  224. """Compute the kernel matrix between a given target graphs (Y) and
  225. the fitted graphs (X / self._graphs) using imap unordered parallelization.
  226. Parameters
  227. ----------
  228. Y : list of graphs, optional
  229. The target graphs.
  230. Returns
  231. -------
  232. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  233. The computed kernel matrix.
  234. """
  235. raise Exception('Parallelization for kernel matrix is not implemented.')
  236. def diagonals(self):
  237. """Compute the kernel matrix diagonals of the fit/transformed data.
  238. Returns
  239. -------
  240. X_diag : numpy array
  241. The diagonal of the kernel matrix between the fitted data.
  242. This consists of each element calculated with itself.
  243. Y_diag : numpy array
  244. The diagonal of the kernel matrix, of the transform.
  245. This consists of each element calculated with itself.
  246. """
  247. # Check if method "fit" had been called.
  248. check_is_fitted(self, ['_graphs'])
  249. # Check if the diagonals of X exist.
  250. try:
  251. check_is_fitted(self, ['_X_diag'])
  252. except NotFittedError:
  253. # Compute diagonals of X.
  254. self._X_diag = np.empty(shape=(len(self._graphs),))
  255. for i, x in enumerate(self._graphs):
  256. self._X_diag[i] = self.pairwise_kernel(x, x) # @todo: parallel?
  257. try:
  258. # If transform has happened, return both diagonals.
  259. check_is_fitted(self, ['_Y'])
  260. self._Y_diag = np.empty(shape=(len(self._Y),))
  261. for (i, y) in enumerate(self._Y):
  262. self._Y_diag[i] = self.pairwise_kernel(y, y) # @todo: parallel?
  263. return self._X_diag, self._Y_diag
  264. except NotFittedError:
  265. # Else just return both X_diag
  266. return self._X_diag
  267. # @abstractmethod
  268. def pairwise_kernel(self, x, y):
  269. """Compute pairwise kernel between two graphs.
  270. Parameters
  271. ----------
  272. x, y : NetworkX Graph.
  273. Graphs bewteen which the kernel is computed.
  274. Returns
  275. -------
  276. kernel: float
  277. The computed kernel.
  278. # Notes
  279. # -----
  280. # This method is abstract and must be implemented by a subclass.
  281. """
  282. raise NotImplementedError('Pairwise kernel computation is not implemented!')
  283. ##########################################################################
  284. # The following is the 2nd paradigm to compute kernel matrix. It is
  285. # simplified and not compatible with `scikit-learn`.
  286. ##########################################################################
  287. def compute(self, *graphs, **kwargs):
  288. self.parallel = kwargs.get('parallel', 'imap_unordered')
  289. self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
  290. self.normalize = kwargs.get('normalize', True)
  291. self.verbose = kwargs.get('verbose', 2)
  292. self.validate_parameters()
  293. if len(graphs) == 1:
  294. if not isinstance(graphs[0], list):
  295. raise Exception('Cannot detect graphs.')
  296. elif len(graphs[0]) == 0:
  297. raise Exception('The graph list given is empty. No computation was performed.')
  298. else:
  299. self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow.
  300. self._gram_matrix = self._compute_gram_matrix()
  301. self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  302. if self.normalize:
  303. self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
  304. return self._gram_matrix, self._run_time
  305. elif len(graphs) == 2:
  306. if self.is_graph(graphs[0]) and self.is_graph(graphs[1]):
  307. kernel = self._compute_single_kernel(graphs[0].copy(), graphs[1].copy())
  308. return kernel, self._run_time
  309. elif self.is_graph(graphs[0]) and isinstance(graphs[1], list):
  310. g1 = graphs[0].copy()
  311. g_list = [g.copy() for g in graphs[1]]
  312. kernel_list = self._compute_kernel_list(g1, g_list)
  313. return kernel_list, self._run_time
  314. elif isinstance(graphs[0], list) and self.is_graph(graphs[1]):
  315. g1 = graphs[1].copy()
  316. g_list = [g.copy() for g in graphs[0]]
  317. kernel_list = self._compute_kernel_list(g1, g_list)
  318. return kernel_list, self._run_time
  319. else:
  320. raise Exception('Cannot detect graphs.')
  321. elif len(graphs) == 0 and self._graphs is None:
  322. raise Exception('Please add graphs before computing.')
  323. else:
  324. raise Exception('Cannot detect graphs.')
  325. def normalize_gm(self, gram_matrix):
  326. import warnings
  327. warnings.warn('gklearn.kernels.graph_kernel.normalize_gm will be deprecated, use gklearn.utils.normalize_gram_matrix instead', DeprecationWarning)
  328. diag = gram_matrix.diagonal().copy()
  329. for i in range(len(gram_matrix)):
  330. for j in range(i, len(gram_matrix)):
  331. gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j])
  332. gram_matrix[j][i] = gram_matrix[i][j]
  333. return gram_matrix
  334. def compute_distance_matrix(self):
  335. if self._gram_matrix is None:
  336. raise Exception('Please compute the Gram matrix before computing distance matrix.')
  337. dis_mat = np.empty((len(self._gram_matrix), len(self._gram_matrix)))
  338. for i in range(len(self._gram_matrix)):
  339. for j in range(i, len(self._gram_matrix)):
  340. dis = self._gram_matrix[i, i] + self._gram_matrix[j, j] - 2 * self._gram_matrix[i, j]
  341. if dis < 0:
  342. if dis > -1e-10:
  343. dis = 0
  344. else:
  345. raise ValueError('The distance is negative.')
  346. dis_mat[i, j] = np.sqrt(dis)
  347. dis_mat[j, i] = dis_mat[i, j]
  348. dis_max = np.max(np.max(dis_mat))
  349. dis_min = np.min(np.min(dis_mat[dis_mat != 0]))
  350. dis_mean = np.mean(np.mean(dis_mat))
  351. return dis_mat, dis_max, dis_min, dis_mean
  352. def _compute_gram_matrix(self):
  353. start_time = time.time()
  354. if self.parallel == 'imap_unordered':
  355. gram_matrix = self._compute_gm_imap_unordered()
  356. elif self.parallel is None:
  357. gram_matrix = self._compute_gm_series()
  358. else:
  359. raise Exception('Parallel mode is not set correctly.')
  360. self._run_time = time.time() - start_time
  361. if self.verbose:
  362. print('Gram matrix of size %d built in %s seconds.'
  363. % (len(self._graphs), self._run_time))
  364. return gram_matrix
  365. def _compute_gm_series(self):
  366. pass
  367. def _compute_gm_imap_unordered(self):
  368. pass
  369. def _compute_kernel_list(self, g1, g_list):
  370. start_time = time.time()
  371. if self.parallel == 'imap_unordered':
  372. kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
  373. elif self.parallel is None:
  374. kernel_list = self._compute_kernel_list_series(g1, g_list)
  375. else:
  376. raise Exception('Parallel mode is not set correctly.')
  377. self._run_time = time.time() - start_time
  378. if self.verbose:
  379. print('Graph kernel bewteen a graph and a list of %d graphs built in %s seconds.'
  380. % (len(g_list), self._run_time))
  381. return kernel_list
  382. def _compute_kernel_list_series(self, g1, g_list):
  383. pass
  384. def _compute_kernel_list_imap_unordered(self, g1, g_list):
  385. pass
  386. def _compute_single_kernel(self, g1, g2):
  387. start_time = time.time()
  388. kernel = self._compute_single_kernel_series(g1, g2)
  389. self._run_time = time.time() - start_time
  390. if self.verbose:
  391. print('Graph kernel bewteen two graphs built in %s seconds.' % (self._run_time))
  392. return kernel
  393. def _compute_single_kernel_series(self, g1, g2):
  394. pass
  395. def is_graph(self, graph):
  396. if isinstance(graph, nx.Graph):
  397. return True
  398. if isinstance(graph, nx.DiGraph):
  399. return True
  400. if isinstance(graph, nx.MultiGraph):
  401. return True
  402. if isinstance(graph, nx.MultiDiGraph):
  403. return True
  404. return False
  405. @property
  406. def graphs(self):
  407. return self._graphs
  408. # @property
  409. # def parallel(self):
  410. # return self.parallel
  411. # @property
  412. # def n_jobs(self):
  413. # return self.n_jobs
  414. # @property
  415. # def verbose(self):
  416. # return self.verbose
  417. # @property
  418. # def normalize(self):
  419. # return self.normalize
  420. @property
  421. def run_time(self):
  422. return self._run_time
  423. @property
  424. def gram_matrix(self):
  425. return self._gram_matrix
  426. @gram_matrix.setter
  427. def gram_matrix(self, value):
  428. self._gram_matrix = value
  429. @property
  430. def gram_matrix_unnorm(self):
  431. return self._gram_matrix_unnorm
  432. @gram_matrix_unnorm.setter
  433. def gram_matrix_unnorm(self, value):
  434. self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.