You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  11. # from abc import ABC, abstractmethod
  12. from sklearn.base import BaseEstimator # , TransformerMixin
  13. from sklearn.utils.validation import check_is_fitted # check_X_y, check_array,
  14. from sklearn.exceptions import NotFittedError
  15. from gklearn.utils import normalize_gram_matrix
  16. class GraphKernel(BaseEstimator): #, ABC):
  17. """The basic graph kernel class.
  18. Attributes
  19. ----------
  20. _graphs : list
  21. Stores the input graphs on fit input data.
  22. Default format of the list objects is `NetworkX` graphs.
  23. **We don't guarantee that the input graphs remain unchanged during the
  24. computation.**
  25. References
  26. ----------
  27. https://ysig.github.io/GraKeL/0.1a8/_modules/grakel/kernels/kernel.html#Kernel.
  28. """
  29. def __init__(self,
  30. parallel=None,
  31. n_jobs=None,
  32. chunksize=None,
  33. normalize=True,
  34. copy_graphs=True, # make sure it is a full deep copy. and faster!
  35. verbose=2):
  36. """`__init__` for `GraphKernel` object."""
  37. # @todo: the default settings of the parameters are different from those in the self.compute method.
  38. # self._graphs = None
  39. self.parallel = parallel
  40. self.n_jobs = n_jobs
  41. self.chunksize = chunksize
  42. self.normalize = normalize
  43. self.copy_graphs = copy_graphs
  44. self.verbose = verbose
  45. # self._run_time = 0
  46. # self._gram_matrix = None
  47. # self._gram_matrix_unnorm = None
  48. ##########################################################################
  49. # The following is the 1st paradigm to compute kernel matrix, which is
  50. # compatible with `scikit-learn`.
  51. # -------------------------------------------------------------------
  52. # Special thanks to the "GraKeL" library for providing an excellent template!
  53. ##########################################################################
  54. def fit(self, X, y=None):
  55. """Fit a graph dataset for a transformer.
  56. Parameters
  57. ----------
  58. X : iterable
  59. DESCRIPTION.
  60. y : None, optional
  61. There is no need of a target in a transformer, yet the `scikit-learn`
  62. pipeline API requires this parameter.
  63. Returns
  64. -------
  65. object
  66. Returns self.
  67. """
  68. # self._is_tranformed = False
  69. # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used;
  70. self.clear_attributes()
  71. # Validate parameters for the transformer.
  72. self.validate_parameters()
  73. # Validate the input.
  74. self._graphs = self.validate_input(X)
  75. # self._X = X
  76. # self._kernel = self._get_kernel_instance()
  77. # Return the transformer.
  78. return self
  79. def transform(self, X=None, load_gm_train=False):
  80. """Compute the graph kernel matrix between given and fitted data.
  81. Parameters
  82. ----------
  83. X : TYPE
  84. DESCRIPTION.
  85. Raises
  86. ------
  87. ValueError
  88. DESCRIPTION.
  89. Returns
  90. -------
  91. None.
  92. """
  93. # If `load_gm_train`, load Gram matrix of training data.
  94. if load_gm_train:
  95. check_is_fitted(self, '_gm_train')
  96. self._is_transformed = True
  97. return self._gm_train # @todo: copy or not?
  98. # Check if method "fit" had been called.
  99. check_is_fitted(self, '_graphs')
  100. # Validate the input.
  101. Y = self.validate_input(X)
  102. # Transform: compute the graph kernel matrix.
  103. kernel_matrix = self.compute_kernel_matrix(Y)
  104. self._Y = Y
  105. # Self transform must appear before the diagonal call on normilization.
  106. self._is_transformed = True
  107. if self.normalize:
  108. X_diag, Y_diag = self.diagonals()
  109. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  110. try:
  111. kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
  112. except:
  113. raise
  114. finally:
  115. np.seterr(**old_settings)
  116. return kernel_matrix
  117. def fit_transform(self, X, save_gm_train=False):
  118. """Fit and transform: compute Gram matrix on the same data.
  119. Parameters
  120. ----------
  121. X : list of graphs
  122. Input graphs.
  123. Returns
  124. -------
  125. gram_matrix : numpy array, shape = [len(X), len(X)]
  126. The Gram matrix of X.
  127. """
  128. self.fit(X)
  129. # Transform: compute Gram matrix.
  130. gram_matrix = self.compute_kernel_matrix()
  131. # Normalize.
  132. if self.normalize:
  133. self._X_diag = np.diagonal(gram_matrix).copy()
  134. old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
  135. try:
  136. gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
  137. except:
  138. raise
  139. finally:
  140. np.seterr(**old_settings)
  141. if save_gm_train:
  142. self._gm_train = gram_matrix
  143. return gram_matrix
  144. def get_params(self):
  145. pass
  146. def set_params(self):
  147. pass
  148. def clear_attributes(self):
  149. if hasattr(self, '_X_diag'):
  150. delattr(self, '_X_diag')
  151. if hasattr(self, '_graphs'):
  152. delattr(self, '_graphs')
  153. if hasattr(self, '_Y'):
  154. delattr(self, '_Y')
  155. if hasattr(self, '_run_time'):
  156. delattr(self, '_run_time')
  157. def validate_parameters(self):
  158. """Validate all parameters for the transformer.
  159. Returns
  160. -------
  161. None.
  162. """
  163. if self.parallel is not None and self.parallel != 'imap_unordered':
  164. raise ValueError('Parallel mode is not set correctly.')
  165. if self.parallel == 'imap_unordered' and self.n_jobs is None:
  166. self.n_jobs = multiprocessing.cpu_count()
  167. def validate_input(self, X):
  168. """Validate the given input and raise errors if it is invalid.
  169. Parameters
  170. ----------
  171. X : list
  172. The input to check. Should be a list of graph.
  173. Raises
  174. ------
  175. ValueError
  176. Raise if the input is not correct.
  177. Returns
  178. -------
  179. X : list
  180. The input. A list of graph.
  181. """
  182. if X is None:
  183. raise ValueError('Please add graphs before computing.')
  184. elif not isinstance(X, list):
  185. raise ValueError('Cannot detect graphs.')
  186. elif len(X) == 0:
  187. raise ValueError('The graph list given is empty. No computation will be performed.')
  188. return X
  189. def compute_kernel_matrix(self, Y=None):
  190. """Compute the kernel matrix between a given target graphs (Y) and
  191. the fitted graphs (X / self._graphs) or the Gram matrix for the fitted
  192. graphs (X / self._graphs).
  193. Parameters
  194. ----------
  195. Y : list of graphs, optional
  196. The target graphs. The default is None. If None kernel is computed
  197. between X and itself.
  198. Returns
  199. -------
  200. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  201. The computed kernel matrix.
  202. """
  203. if Y is None:
  204. # Compute Gram matrix for self._graphs (X).
  205. kernel_matrix = self._compute_gram_matrix()
  206. # self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  207. else:
  208. # Compute kernel matrix between Y and self._graphs (X).
  209. start_time = time.time()
  210. if self.parallel == 'imap_unordered':
  211. kernel_matrix = self._compute_kernel_matrix_imap_unordered(Y)
  212. elif self.parallel is None:
  213. Y_copy = ([g.copy() for g in Y] if self.copy_graphs else Y)
  214. graphs_copy = ([g.copy() for g in self._graphs] if self.copy_graphs else self._graphs)
  215. kernel_matrix = self._compute_kernel_matrix_series(Y_copy, graphs_copy)
  216. self._run_time = time.time() - start_time
  217. if self.verbose:
  218. print('Kernel matrix of size (%d, %d) built in %s seconds.'
  219. % (len(Y), len(self._graphs), self._run_time))
  220. return kernel_matrix
  221. def _compute_kernel_matrix_series(self, X, Y):
  222. """Compute the kernel matrix between two sets of graphs (X and Y) without parallelization.
  223. Parameters
  224. ----------
  225. X, Y : list of graphs
  226. The input graphs.
  227. Returns
  228. -------
  229. kernel_matrix : numpy array, shape = [n_X, n_Y]
  230. The computed kernel matrix.
  231. """
  232. kernel_matrix = np.zeros((len(X), len(Y)))
  233. for i_x, g_x in enumerate(X):
  234. for i_y, g_y in enumerate(Y):
  235. kernel_matrix[i_x, i_y] = self.pairwise_kernel(g_x, g_y)
  236. return kernel_matrix
  237. def _compute_kernel_matrix_imap_unordered(self, Y):
  238. """Compute the kernel matrix between a given target graphs (Y) and
  239. the fitted graphs (X / self._graphs) using imap unordered parallelization.
  240. Parameters
  241. ----------
  242. Y : list of graphs, optional
  243. The target graphs.
  244. Returns
  245. -------
  246. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  247. The computed kernel matrix.
  248. """
  249. raise Exception('Parallelization for kernel matrix is not implemented.')
  250. def diagonals(self):
  251. """Compute the kernel matrix diagonals of the fit/transformed data.
  252. Returns
  253. -------
  254. X_diag : numpy array
  255. The diagonal of the kernel matrix between the fitted data.
  256. This consists of each element calculated with itself.
  257. Y_diag : numpy array
  258. The diagonal of the kernel matrix, of the transform.
  259. This consists of each element calculated with itself.
  260. """
  261. # Check if method "fit" had been called.
  262. check_is_fitted(self, ['_graphs'])
  263. # Check if the diagonals of X exist.
  264. try:
  265. check_is_fitted(self, ['_X_diag'])
  266. except NotFittedError:
  267. # Compute diagonals of X.
  268. self._X_diag = np.empty(shape=(len(self._graphs),))
  269. graphs = ([g.copy() for g in self._graphs] if self.copy_graphs else self._graphs)
  270. for i, x in enumerate(graphs):
  271. self._X_diag[i] = self.pairwise_kernel(x, x) # @todo: parallel?
  272. try:
  273. # If transform has happened, return both diagonals.
  274. check_is_fitted(self, ['_Y'])
  275. self._Y_diag = np.empty(shape=(len(self._Y),))
  276. Y = ([g.copy() for g in self._Y] if self.copy_graphs else self._Y)
  277. for (i, y) in enumerate(Y):
  278. self._Y_diag[i] = self.pairwise_kernel(y, y) # @todo: parallel?
  279. return self._X_diag, self._Y_diag
  280. except NotFittedError:
  281. # Else just return both X_diag
  282. return self._X_diag
  283. # @abstractmethod
  284. def pairwise_kernel(self, x, y):
  285. """Compute pairwise kernel between two graphs.
  286. Parameters
  287. ----------
  288. x, y : NetworkX Graph.
  289. Graphs bewteen which the kernel is computed.
  290. Returns
  291. -------
  292. kernel: float
  293. The computed kernel.
  294. # Notes
  295. # -----
  296. # This method is abstract and must be implemented by a subclass.
  297. """
  298. raise NotImplementedError('Pairwise kernel computation is not implemented!')
  299. ##########################################################################
  300. # The following is the 2nd paradigm to compute kernel matrix. It is
  301. # simplified and not compatible with `scikit-learn`.
  302. ##########################################################################
  303. def compute(self, *graphs, **kwargs):
  304. self.parallel = kwargs.get('parallel', 'imap_unordered')
  305. self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
  306. self.normalize = kwargs.get('normalize', True)
  307. self.verbose = kwargs.get('verbose', 2)
  308. self.copy_graphs = kwargs.get('copy_graphs', True)
  309. self.save_unnormed = kwargs.get('save_unnormed', True)
  310. self.validate_parameters()
  311. # If the inputs is a list of graphs.
  312. if len(graphs) == 1:
  313. if not isinstance(graphs[0], list):
  314. raise Exception('Cannot detect graphs.')
  315. elif len(graphs[0]) == 0:
  316. raise Exception('The graph list given is empty. No computation was performed.')
  317. else:
  318. if self.copy_graphs:
  319. self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow.
  320. else:
  321. self._graphs = graphs
  322. self._gram_matrix = self._compute_gram_matrix()
  323. if self.save_unnormed:
  324. self._gram_matrix_unnorm = np.copy(self._gram_matrix)
  325. if self.normalize:
  326. self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
  327. return self._gram_matrix, self._run_time
  328. elif len(graphs) == 2:
  329. # If the inputs are two graphs.
  330. if self.is_graph(graphs[0]) and self.is_graph(graphs[1]):
  331. if self.copy_graphs:
  332. G0, G1 = graphs[0].copy(), graphs[1].copy()
  333. else:
  334. G0, G1 = graphs[0], graphs[1]
  335. kernel = self._compute_single_kernel(G0, G1)
  336. return kernel, self._run_time
  337. # If the inputs are a graph and a list of graphs.
  338. elif self.is_graph(graphs[0]) and isinstance(graphs[1], list):
  339. if self.copy_graphs:
  340. g1 = graphs[0].copy()
  341. g_list = [g.copy() for g in graphs[1]]
  342. kernel_list = self._compute_kernel_list(g1, g_list)
  343. else:
  344. kernel_list = self._compute_kernel_list(graphs[0], graphs[1])
  345. return kernel_list, self._run_time
  346. elif isinstance(graphs[0], list) and self.is_graph(graphs[1]):
  347. if self.copy_graphs:
  348. g1 = graphs[1].copy()
  349. g_list = [g.copy() for g in graphs[0]]
  350. kernel_list = self._compute_kernel_list(g1, g_list)
  351. else:
  352. kernel_list = self._compute_kernel_list(graphs[1], graphs[0])
  353. return kernel_list, self._run_time
  354. else:
  355. raise Exception('Cannot detect graphs.')
  356. elif len(graphs) == 0 and self._graphs is None:
  357. raise Exception('Please add graphs before computing.')
  358. else:
  359. raise Exception('Cannot detect graphs.')
  360. def normalize_gm(self, gram_matrix):
  361. import warnings
  362. warnings.warn('gklearn.kernels.graph_kernel.normalize_gm will be deprecated, use gklearn.utils.normalize_gram_matrix instead', DeprecationWarning)
  363. diag = gram_matrix.diagonal().copy()
  364. for i in range(len(gram_matrix)):
  365. for j in range(i, len(gram_matrix)):
  366. gram_matrix[i][j] /= np.sqrt(diag[i] * diag[j])
  367. gram_matrix[j][i] = gram_matrix[i][j]
  368. return gram_matrix
  369. def compute_distance_matrix(self):
  370. if self._gram_matrix is None:
  371. raise Exception('Please compute the Gram matrix before computing distance matrix.')
  372. dis_mat = np.empty((len(self._gram_matrix), len(self._gram_matrix)))
  373. for i in range(len(self._gram_matrix)):
  374. for j in range(i, len(self._gram_matrix)):
  375. dis = self._gram_matrix[i, i] + self._gram_matrix[j, j] - 2 * self._gram_matrix[i, j]
  376. if dis < 0:
  377. if dis > -1e-10:
  378. dis = 0
  379. else:
  380. raise ValueError('The distance is negative.')
  381. dis_mat[i, j] = np.sqrt(dis)
  382. dis_mat[j, i] = dis_mat[i, j]
  383. dis_max = np.max(np.max(dis_mat))
  384. dis_min = np.min(np.min(dis_mat[dis_mat != 0]))
  385. dis_mean = np.mean(np.mean(dis_mat))
  386. return dis_mat, dis_max, dis_min, dis_mean
  387. def _compute_gram_matrix(self):
  388. start_time = time.time()
  389. if self.parallel == 'imap_unordered':
  390. gram_matrix = self._compute_gm_imap_unordered()
  391. elif self.parallel is None:
  392. graphs = ([g.copy() for g in self._graphs] if self.copy_graphs else self._graphs)
  393. gram_matrix = self._compute_gm_series(graphs)
  394. else:
  395. raise Exception('Parallel mode is not set correctly.')
  396. self._run_time = time.time() - start_time
  397. if self.verbose:
  398. print('Gram matrix of size %d built in %s seconds.'
  399. % (len(self._graphs), self._run_time))
  400. return gram_matrix
  401. def _compute_gm_series(self, graphs):
  402. pass
  403. def _compute_gm_imap_unordered(self, graphs):
  404. pass
  405. def _compute_kernel_list(self, g1, g_list):
  406. start_time = time.time()
  407. if self.parallel == 'imap_unordered':
  408. kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
  409. elif self.parallel is None:
  410. kernel_list = self._compute_kernel_list_series(g1, g_list)
  411. else:
  412. raise Exception('Parallel mode is not set correctly.')
  413. self._run_time = time.time() - start_time
  414. if self.verbose:
  415. print('Graph kernel bewteen a graph and a list of %d graphs built in %s seconds.'
  416. % (len(g_list), self._run_time))
  417. return kernel_list
  418. def _compute_kernel_list_series(self, g1, g_list):
  419. pass
  420. def _compute_kernel_list_imap_unordered(self, g1, g_list):
  421. pass
  422. def _compute_single_kernel(self, g1, g2):
  423. start_time = time.time()
  424. kernel = self._compute_single_kernel_series(g1, g2)
  425. self._run_time = time.time() - start_time
  426. if self.verbose:
  427. print('Graph kernel bewteen two graphs built in %s seconds.' % (self._run_time))
  428. return kernel
  429. def _compute_single_kernel_series(self, g1, g2):
  430. pass
  431. def is_graph(self, graph):
  432. if isinstance(graph, nx.Graph):
  433. return True
  434. if isinstance(graph, nx.DiGraph):
  435. return True
  436. if isinstance(graph, nx.MultiGraph):
  437. return True
  438. if isinstance(graph, nx.MultiDiGraph):
  439. return True
  440. return False
  441. @property
  442. def graphs(self):
  443. return self._graphs
  444. # @property
  445. # def parallel(self):
  446. # return self.parallel
  447. # @property
  448. # def n_jobs(self):
  449. # return self.n_jobs
  450. # @property
  451. # def verbose(self):
  452. # return self.verbose
  453. # @property
  454. # def normalize(self):
  455. # return self.normalize
  456. @property
  457. def run_time(self):
  458. return self._run_time
  459. @property
  460. def gram_matrix(self):
  461. return self._gram_matrix
  462. @gram_matrix.setter
  463. def gram_matrix(self, value):
  464. self._gram_matrix = value
  465. @property
  466. def gram_matrix_unnorm(self):
  467. return self._gram_matrix_unnorm
  468. @gram_matrix_unnorm.setter
  469. def gram_matrix_unnorm(self, value):
  470. self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.