You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  11. # from abc import ABC, abstractmethod
  12. from sklearn.base import BaseEstimator # , TransformerMixin
  13. from sklearn.utils.validation import check_is_fitted # check_X_y, check_array,
  14. from sklearn.exceptions import NotFittedError
  15. from gklearn.utils import normalize_gram_matrix
  16. class GraphKernel(BaseEstimator): #, ABC):
  17. """The basic graph kernel class.
  18. Attributes
  19. ----------
  20. _graphs : list
  21. Stores the input graphs on fit input data.
  22. Default format of the list objects is `NetworkX` graphs.
  23. **We don't guarantee that the input graphs remain unchanged during the
  24. computation.**
  25. References
  26. ----------
  27. https://ysig.github.io/GraKeL/0.1a8/_modules/grakel/kernels/kernel.html#Kernel.
  28. """
  29. def __init__(self, parallel=None, n_jobs=None, chunksize=None, normalize=True, verbose=2):
  30. """`__init__` for `GraphKernel` object."""
  31. # @todo: the default settings of the parameters are different from those in the self.compute method.
  32. # self._graphs = None
  33. self.parallel = parallel
  34. self.n_jobs = n_jobs
  35. self.chunksize = chunksize
  36. self.normalize = normalize
  37. self.verbose = verbose
  38. # self._run_time = 0
  39. # self._gram_matrix = None
  40. # self._gram_matrix_unnorm = None
  41. ##########################################################################
  42. # The following is the 1st paradigm to compute kernel matrix, which is
  43. # compatible with `scikit-learn`.
  44. # -------------------------------------------------------------------
  45. # Special thanks to the "GraKeL" library for providing an excellent template!
  46. ##########################################################################
  47. def fit(self, X, y=None):
  48. """Fit a graph dataset for a transformer.
  49. Parameters
  50. ----------
  51. X : iterable
  52. DESCRIPTION.
  53. y : None, optional
  54. There is no need of a target in a transformer, yet the `scikit-learn`
  55. pipeline API requires this parameter.
  56. Returns
  57. -------
  58. object
  59. Returns self.
  60. """
  61. # self._is_tranformed = False
  62. # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used;
  63. self.clear_attributes()
  64. # Validate parameters for the transformer.
  65. self.validate_parameters()
  66. # Validate the input.
  67. self._graphs = self.validate_input(X)
  68. # self._X = X
  69. # self._kernel = self._get_kernel_instance()
  70. # Return the transformer.
  71. return self
  72. def transform(self, X):
  73. """Compute the graph kernel matrix between given and fitted data.
  74. Parameters
  75. ----------
  76. X : TYPE
  77. DESCRIPTION.
  78. Raises
  79. ------
  80. ValueError
  81. DESCRIPTION.
  82. Returns
  83. -------
  84. None.
  85. """
  86. # Check if method "fit" had been called.
  87. check_is_fitted(self, '_graphs')
  88. # Validate the input.
  89. Y = self.validate_input(X)
  90. # Transform: compute the graph kernel matrix.
  91. kernel_matrix = self.compute_kernel_matrix(Y)
  92. self._Y = Y
  93. # Self transform must appear before the diagonal call on normilization.
  94. self._is_transformed = True
  95. if self.normalize:
  96. X_diag, Y_diag = self.diagonals()
  97. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  98. try:
  99. kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
  100. except:
  101. raise
  102. finally:
  103. np.seterr(**old_settings)
  104. return kernel_matrix
  105. def fit_transform(self, X):
  106. """Fit and transform: compute Gram matrix on the same data.
  107. Parameters
  108. ----------
  109. X : list of graphs
  110. Input graphs.
  111. Returns
  112. -------
  113. gram_matrix : numpy array, shape = [len(X), len(X)]
  114. The Gram matrix of X.
  115. """
  116. self.fit(X)
  117. # Transform: compute Gram matrix.
  118. gram_matrix = self.compute_kernel_matrix()
  119. # Normalize.
  120. if self.normalize:
  121. self._X_diag = np.diagonal(gram_matrix).copy()
  122. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  123. try:
  124. gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
  125. except:
  126. raise
  127. finally:
  128. np.seterr(**old_settings)
  129. return gram_matrix
  130. def get_params(self):
  131. pass
  132. def set_params(self):
  133. pass
  134. def clear_attributes(self):
  135. if hasattr(self, '_X_diag'):
  136. delattr(self, '_X_diag')
  137. if hasattr(self, '_graphs'):
  138. delattr(self, '_graphs')
  139. if hasattr(self, '_Y'):
  140. delattr(self, '_Y')
  141. if hasattr(self, '_run_time'):
  142. delattr(self, '_run_time')
  143. def validate_parameters(self):
  144. """Validate all parameters for the transformer.
  145. Returns
  146. -------
  147. None.
  148. """
  149. if self.parallel is not None and self.parallel != 'imap_unordered':
  150. raise ValueError('Parallel mode is not set correctly.')
  151. if self.parallel == 'imap_unordered' and self.n_jobs is None:
  152. self.n_jobs = multiprocessing.cpu_count()
  153. def validate_input(self, X):
  154. """Validate the given input and raise errors if it is invalid.
  155. Parameters
  156. ----------
  157. X : list
  158. The input to check. Should be a list of graph.
  159. Raises
  160. ------
  161. ValueError
  162. Raise if the input is not correct.
  163. Returns
  164. -------
  165. X : list
  166. The input. A list of graph.
  167. """
  168. if X is None:
  169. raise ValueError('Please add graphs before computing.')
  170. elif not isinstance(X, list):
  171. raise ValueError('Cannot detect graphs.')
  172. elif len(X) == 0:
  173. raise ValueError('The graph list given is empty. No computation will be performed.')
  174. return X
  175. def compute_kernel_matrix(self, Y=None):
  176. """Compute the kernel matrix between a given target graphs (Y) and
  177. the fitted graphs (X / self._graphs) or the Gram matrix for the fitted
  178. graphs (X / self._graphs).
  179. Parameters
  180. ----------
  181. Y : list of graphs, optional
  182. The target graphs. The default is None. If None kernel is computed
  183. between X and itself.
  184. Returns
  185. -------
  186. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  187. The computed kernel matrix.
  188. """
  189. if Y is None:
  190. # Compute Gram matrix for self._graphs (X).
  191. kernel_matrix = self._compute_gram_matrix()
  192. # self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  193. else:
  194. # Compute kernel matrix between Y and self._graphs (X).
  195. start_time = time.time()
  196. if self.parallel == 'imap_unordered':
  197. kernel_matrix = self._compute_kernel_matrix_imap_unordered(Y)
  198. elif self.parallel is None:
  199. kernel_matrix = self._compute_kernel_matrix_series(Y)
  200. self._run_time = time.time() - start_time
  201. if self.verbose:
  202. print('Kernel matrix of size (%d, %d) built in %s seconds.'
  203. % (len(Y), len(self._graphs), self._run_time))
  204. return kernel_matrix
  205. def _compute_kernel_matrix_series(self, Y):
  206. """Compute the kernel matrix between a given target graphs (Y) and
  207. the fitted graphs (X / self._graphs) without parallelization.
  208. Parameters
  209. ----------
  210. Y : list of graphs, optional
  211. The target graphs.
  212. Returns
  213. -------
  214. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  215. The computed kernel matrix.
  216. """
  217. kernel_matrix = np.zeros((len(Y), len(self._graphs)))
  218. for i_y, g_y in enumerate(Y):
  219. for i_x, g_x in enumerate(self._graphs):
  220. kernel_matrix[i_y, i_x] = self.pairwise_kernel(g_y, g_x)
  221. return kernel_matrix
  222. def _compute_kernel_matrix_imap_unordered(self, Y):
  223. """Compute the kernel matrix between a given target graphs (Y) and
  224. the fitted graphs (X / self._graphs) using imap unordered parallelization.
  225. Parameters
  226. ----------
  227. Y : list of graphs, optional
  228. The target graphs.
  229. Returns
  230. -------
  231. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  232. The computed kernel matrix.
  233. """
  234. raise Exception('Parallelization for kernel matrix is not implemented.')
  235. def diagonals(self):
  236. """Compute the kernel matrix diagonals of the fit/transformed data.
  237. Returns
  238. -------
  239. X_diag : numpy array
  240. The diagonal of the kernel matrix between the fitted data.
  241. This consists of each element calculated with itself.
  242. Y_diag : numpy array
  243. The diagonal of the kernel matrix, of the transform.
  244. This consists of each element calculated with itself.
  245. """
  246. # Check if method "fit" had been called.
  247. check_is_fitted(self, ['_graphs'])
  248. # Check if the diagonals of X exist.
  249. try:
  250. check_is_fitted(self, ['_X_diag'])
  251. except NotFittedError:
  252. # Compute diagonals of X.
  253. self._X_diag = np.empty(shape=(len(self._graphs),))
  254. for i, x in enumerate(self._graphs):
  255. self._X_diag[i] = self.pairwise_kernel(x, x) # @todo: parallel?
  256. try:
  257. # If transform has happened, return both diagonals.
  258. check_is_fitted(self, ['_Y'])
  259. self._Y_diag = np.empty(shape=(len(self._Y),))
  260. for (i, y) in enumerate(self._Y):
  261. self._Y_diag[i] = self.pairwise_kernel(y, y) # @todo: parallel?
  262. return self._X_diag, self._Y_diag
  263. except NotFittedError:
  264. # Else just return both X_diag
  265. return self._X_diag
  266. # @abstractmethod
  267. def pairwise_kernel(self, x, y):
  268. """Compute pairwise kernel between two graphs.
  269. Parameters
  270. ----------
  271. x, y : NetworkX Graph.
  272. Graphs bewteen which the kernel is computed.
  273. Returns
  274. -------
  275. kernel: float
  276. The computed kernel.
  277. # Notes
  278. # -----
  279. # This method is abstract and must be implemented by a subclass.
  280. """
  281. raise NotImplementedError('Pairwise kernel computation is not implemented!')
  282. ##########################################################################
  283. # The following is the 2nd paradigm to compute kernel matrix. It is
  284. # simplified and not compatible with `scikit-learn`.
  285. ##########################################################################
  286. def compute(self, *graphs, **kwargs):
  287. self.parallel = kwargs.get('parallel', 'imap_unordered')
  288. self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
  289. self.normalize = kwargs.get('normalize', True)
  290. self.verbose = kwargs.get('verbose', 2)
  291. self.copy_graphs = kwargs.get('copy_graphs', True)
  292. self.save_unnormed = kwargs.get('save_unnormed', True)
  293. self.validate_parameters()
  294. # If the inputs is a list of graphs.
  295. if len(graphs) == 1:
  296. if not isinstance(graphs[0], list):
  297. raise Exception('Cannot detect graphs.')
  298. elif len(graphs[0]) == 0:
  299. raise Exception('The graph list given is empty. No computation was performed.')
  300. else:
  301. if self.copy_graphs:
  302. self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow.
  303. else:
  304. self._graphs = graphs
  305. self._gram_matrix = self._compute_gram_matrix()
  306. if self.save_unnormed:
  307. self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  308. if self.normalize:
  309. self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
  310. return self._gram_matrix, self._run_time
  311. elif len(graphs) == 2:
  312. # If the inputs are two graphs.
  313. if self.is_graph(graphs[0]) and self.is_graph(graphs[1]):
  314. if self.copy_graphs:
  315. G0, G1 = graphs[0].copy(), graphs[1].copy()
  316. else:
  317. G0, G1 = graphs[0], graphs[1]
  318. kernel = self._compute_single_kernel(G0, G1)
  319. return kernel, self._run_time
  320. # If the inputs are a graph and a list of graphs.
  321. elif self.is_graph(graphs[0]) and isinstance(graphs[1], list):
  322. if self.copy_graphs:
  323. g1 = graphs[0].copy()
  324. g_list = [g.copy() for g in graphs[1]]
  325. kernel_list = self._compute_kernel_list(g1, g_list)
  326. else:
  327. kernel_list = self._compute_kernel_list(graphs[0], graphs[1])
  328. return kernel_list, self._run_time
  329. elif isinstance(graphs[0], list) and self.is_graph(graphs[1]):
  330. if self.copy_graphs:
  331. g1 = graphs[1].copy()
  332. g_list = [g.copy() for g in graphs[0]]
  333. kernel_list = self._compute_kernel_list(g1, g_list)
  334. else:
  335. kernel_list = self._compute_kernel_list(graphs[1], graphs[0])
  336. return kernel_list, self._run_time
  337. else:
  338. raise Exception('Cannot detect graphs.')
  339. elif len(graphs) == 0 and self._graphs is None:
  340. raise Exception('Please add graphs before computing.')
  341. else:
  342. raise Exception('Cannot detect graphs.')
  343. def normalize_gm(self, gram_matrix):
  344. import warnings
  345. warnings.warn('gklearn.kernels.graph_kernel.normalize_gm will be deprecated, use gklearn.utils.normalize_gram_matrix instead', DeprecationWarning)
  346. diag = gram_matrix.diagonal().copy()
  347. for i in range(len(gram_matrix)):
  348. for j in range(i, len(gram_matrix)):
  349. gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j])
  350. gram_matrix[j][i] = gram_matrix[i][j]
  351. return gram_matrix
  352. def compute_distance_matrix(self):
  353. if self._gram_matrix is None:
  354. raise Exception('Please compute the Gram matrix before computing distance matrix.')
  355. dis_mat = np.empty((len(self._gram_matrix), len(self._gram_matrix)))
  356. for i in range(len(self._gram_matrix)):
  357. for j in range(i, len(self._gram_matrix)):
  358. dis = self._gram_matrix[i, i] + self._gram_matrix[j, j] - 2 * self._gram_matrix[i, j]
  359. if dis < 0:
  360. if dis > -1e-10:
  361. dis = 0
  362. else:
  363. raise ValueError('The distance is negative.')
  364. dis_mat[i, j] = np.sqrt(dis)
  365. dis_mat[j, i] = dis_mat[i, j]
  366. dis_max = np.max(np.max(dis_mat))
  367. dis_min = np.min(np.min(dis_mat[dis_mat != 0]))
  368. dis_mean = np.mean(np.mean(dis_mat))
  369. return dis_mat, dis_max, dis_min, dis_mean
  370. def _compute_gram_matrix(self):
  371. start_time = time.time()
  372. if self.parallel == 'imap_unordered':
  373. gram_matrix = self._compute_gm_imap_unordered()
  374. elif self.parallel is None:
  375. gram_matrix = self._compute_gm_series()
  376. else:
  377. raise Exception('Parallel mode is not set correctly.')
  378. self._run_time = time.time() - start_time
  379. if self.verbose:
  380. print('Gram matrix of size %d built in %s seconds.'
  381. % (len(self._graphs), self._run_time))
  382. return gram_matrix
  383. def _compute_gm_series(self):
  384. pass
  385. def _compute_gm_imap_unordered(self):
  386. pass
  387. def _compute_kernel_list(self, g1, g_list):
  388. start_time = time.time()
  389. if self.parallel == 'imap_unordered':
  390. kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
  391. elif self.parallel is None:
  392. kernel_list = self._compute_kernel_list_series(g1, g_list)
  393. else:
  394. raise Exception('Parallel mode is not set correctly.')
  395. self._run_time = time.time() - start_time
  396. if self.verbose:
  397. print('Graph kernel bewteen a graph and a list of %d graphs built in %s seconds.'
  398. % (len(g_list), self._run_time))
  399. return kernel_list
  400. def _compute_kernel_list_series(self, g1, g_list):
  401. pass
  402. def _compute_kernel_list_imap_unordered(self, g1, g_list):
  403. pass
  404. def _compute_single_kernel(self, g1, g2):
  405. start_time = time.time()
  406. kernel = self._compute_single_kernel_series(g1, g2)
  407. self._run_time = time.time() - start_time
  408. if self.verbose:
  409. print('Graph kernel bewteen two graphs built in %s seconds.' % (self._run_time))
  410. return kernel
  411. def _compute_single_kernel_series(self, g1, g2):
  412. pass
  413. def is_graph(self, graph):
  414. if isinstance(graph, nx.Graph):
  415. return True
  416. if isinstance(graph, nx.DiGraph):
  417. return True
  418. if isinstance(graph, nx.MultiGraph):
  419. return True
  420. if isinstance(graph, nx.MultiDiGraph):
  421. return True
  422. return False
  423. @property
  424. def graphs(self):
  425. return self._graphs
  426. # @property
  427. # def parallel(self):
  428. # return self.parallel
  429. # @property
  430. # def n_jobs(self):
  431. # return self.n_jobs
  432. # @property
  433. # def verbose(self):
  434. # return self.verbose
  435. # @property
  436. # def normalize(self):
  437. # return self.normalize
  438. @property
  439. def run_time(self):
  440. return self._run_time
  441. @property
  442. def gram_matrix(self):
  443. return self._gram_matrix
  444. @gram_matrix.setter
  445. def gram_matrix(self, value):
  446. self._gram_matrix = value
  447. @property
  448. def gram_matrix_unnorm(self):
  449. return self._gram_matrix_unnorm
  450. @gram_matrix_unnorm.setter
  451. def gram_matrix_unnorm(self, value):
  452. self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances and graph pre-image problem.