You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_kernel.py 15 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 30 11:52:47 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import networkx as nx
  9. import multiprocessing
  10. import time
  11. # from abc import ABC, abstractmethod
  12. from sklearn.base import BaseEstimator # , TransformerMixin
  13. from sklearn.utils.validation import check_is_fitted # check_X_y, check_array,
  14. from sklearn.exceptions import NotFittedError
  15. from gklearn.utils import normalize_gram_matrix
  class GraphKernel(BaseEstimator):
      """Base class for graph kernels.

      Attributes
      ----------
      _graphs : list
          The graphs handed to `fit` (or to `compute`). The list items are
          `NetworkX` graphs by default.
          **We don't guarantee that the input graphs remain unchanged during the
          computation.**

      References
      ----------
      https://ysig.github.io/GraKeL/0.1a8/_modules/grakel/kernels/kernel.html#Kernel.
      """

      def __init__(self, parallel=None, n_jobs=None, chunksize=None, normalize=True, verbose=2):
          """Store the constructor arguments on the instance, unchanged.

          Parameters
          ----------
          parallel : None or str, optional
              Parallelization mode. `validate_parameters` accepts only None
              and 'imap_unordered'.
          n_jobs : int or None, optional
              Number of jobs; filled in by `validate_parameters` when
              `parallel` is 'imap_unordered' and this is None.
          chunksize : optional
              Stored as given; not read anywhere in this base class.
          normalize : bool, optional
              Whether computed kernel matrices are normalized.
          verbose : int, optional
              A truthy value enables the timing messages.
          """
          # @todo: these defaults differ from the ones `self.compute` falls
          # back to (there `parallel` defaults to 'imap_unordered' and
          # `n_jobs` to the CPU count).
          self.verbose = verbose
          self.normalize = normalize
          self.chunksize = chunksize
          self.n_jobs = n_jobs
          self.parallel = parallel
  41. ##########################################################################
  42. # The following is the 1st paradigm to compute kernel matrix, which is
  43. # compatible with `scikit-learn`.
  44. # -------------------------------------------------------------------
  45. # Special thanks to the "GraKeL" library for providing an excellent template!
  46. ##########################################################################
  def fit(self, X, y=None):
      """Fit the transformer on a graph dataset.

      Parameters
      ----------
      X : list
          The graphs to fit on; checked by `validate_input`.
      y : None, optional
          Unused. A transformer needs no target, but the `scikit-learn`
          pipeline API expects the parameter to exist.

      Returns
      -------
      object
          Returns self.
      """
      # Start from a clean slate: drop whatever a previous fit/transform left
      # on the estimator (@todo: unless warm_start is used).
      self.clear_attributes()

      # Parameters are checked before the input so that a bad configuration
      # is reported first.
      self.validate_parameters()

      fitted_graphs = self.validate_input(X)
      self._graphs = fitted_graphs

      return self
  def transform(self, X):
      """Compute the graph kernel matrix between given and fitted data.

      Parameters
      ----------
      X : list
          The target graphs; checked by `validate_input`.

      Raises
      ------
      NotFittedError
          Raised by `check_is_fitted` if `_graphs` is missing, i.e. `fit`
          was not called.
      ValueError
          Raised by `validate_input` if `X` is not a non-empty list.

      Returns
      -------
      kernel_matrix : numpy array
          The matrix returned by `compute_kernel_matrix` for `X`, divided
          element-wise by sqrt(outer(Y_diag, X_diag)) when `self.normalize`
          is truthy.
      """
      # Refuse to run before `fit`.
      check_is_fitted(self, '_graphs')

      targets = self.validate_input(X)
      kernel_matrix = self.compute_kernel_matrix(targets)

      # `diagonals` looks at `_Y` to decide whether to return the target
      # diagonal too, so these must be set before the normalization below.
      self._Y = targets
      self._is_transformed = True

      if not self.normalize:
          return kernel_matrix

      X_diag, Y_diag = self.diagonals()
      scale = np.sqrt(np.outer(Y_diag, X_diag))
      kernel_matrix /= scale
      return kernel_matrix
  def fit_transform(self, X, y=None):
      """Fit and transform: compute Gram matrix on the same data.

      Parameters
      ----------
      X : list of graphs
          Input graphs.
      y : None, optional
          Forwarded to `fit`, which ignores it. Accepted so that callers
          following the `scikit-learn` convention `fit_transform(X, y)`
          (e.g. a pipeline) do not fail with a TypeError.

      Returns
      -------
      gram_matrix : numpy array, shape = [len(X), len(X)]
          The Gram matrix of X, normalized if `self.normalize` is truthy.
      """
      self.fit(X, y)

      # Transform: compute Gram matrix of the fitted graphs.
      gram_matrix = self.compute_kernel_matrix()

      # Keep a copy of the unnormalized diagonal: the division below happens
      # in place and would otherwise overwrite it.
      self._X_diag = np.diagonal(gram_matrix).copy()
      if self.normalize:
          gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))

      return gram_matrix
  def get_params(self, deep=True):
      """Get the parameters of this estimator.

      Delegates to the parent class implementation instead of shadowing it
      with a no-op that returns None.

      Parameters
      ----------
      deep : bool, optional
          Passed through to the parent implementation. The default is True.

      Returns
      -------
      params : dict
          Whatever the parent implementation returns (parameter names mapped
          to their values).
      """
      return super(GraphKernel, self).get_params(deep=deep)
  def set_params(self, **params):
      """Set the parameters of this estimator.

      Delegates to the parent class implementation instead of shadowing it
      with a no-op that accepts no parameters.

      Parameters
      ----------
      **params : dict
          Estimator parameters, passed through to the parent implementation.

      Returns
      -------
      object
          Whatever the parent implementation returns (the estimator itself).
      """
      return super(GraphKernel, self).set_params(**params)
  def clear_attributes(self):
      """Delete the attributes `_X_diag`, `_graphs`, `_Y` and `_run_time`.

      Attributes that are not present are skipped silently.

      Returns
      -------
      None.
      """
      for attr_name in ('_X_diag', '_graphs', '_Y', '_run_time'):
          if hasattr(self, attr_name):
              delattr(self, attr_name)
  def validate_parameters(self):
      """Check the parallelization settings of the transformer.

      Only None and 'imap_unordered' are accepted for `self.parallel`. When
      the mode is 'imap_unordered' and `self.n_jobs` is None, `self.n_jobs`
      is set to the number of CPUs.

      Raises
      ------
      ValueError
          If `self.parallel` is neither None nor 'imap_unordered'.

      Returns
      -------
      None.
      """
      mode = self.parallel
      if not (mode is None or mode == 'imap_unordered'):
          raise ValueError('Parallel mode is not set correctly.')

      if mode == 'imap_unordered' and self.n_jobs is None:
          self.n_jobs = multiprocessing.cpu_count()
  def validate_input(self, X):
      """Check that the input is a non-empty list and hand it back.

      Parameters
      ----------
      X : list
          The input to check. Should be a list of graphs. Only the container
          is inspected, not its items.

      Raises
      ------
      ValueError
          If `X` is None, is not a list, or is an empty list.

      Returns
      -------
      X : list
          The very same object that was passed in.
      """
      if X is None:
          raise ValueError('Please add graphs before computing.')
      if not isinstance(X, list):
          raise ValueError('Cannot detect graphs.')
      if len(X) == 0:
          message = 'The graph list given is empty. No computation will be performed.'
          raise ValueError(message)
      return X
  def compute_kernel_matrix(self, Y=None):
      """Compute the kernel matrix between a given target graphs (Y) and
      the fitted graphs (X / self._graphs) or the Gram matrix for the fitted
      graphs (X / self._graphs).

      Parameters
      ----------
      Y : list of graphs, optional
          The target graphs. The default is None. If None kernel is computed
          between X and itself.

      Raises
      ------
      ValueError
          If `Y` is given and `self.parallel` is neither None nor
          'imap_unordered'.

      Returns
      -------
      kernel_matrix : numpy array, shape = [n_targets, n_inputs]
          The computed kernel matrix.
      """
      if Y is None:
          # Gram matrix of the fitted graphs; timing and reporting are done
          # inside `_compute_gram_matrix`.
          return self._compute_gram_matrix()

      # Pick the implementation up front so that an unsupported mode fails
      # with a clear message instead of leaving `kernel_matrix` unbound.
      if self.parallel == 'imap_unordered':
          compute_matrix = self._compute_kernel_matrix_imap_unordered
      elif self.parallel is None:
          compute_matrix = self._compute_kernel_matrix_series
      else:
          raise ValueError('Parallel mode is not set correctly.')

      start_time = time.time()
      kernel_matrix = compute_matrix(Y)
      self._run_time = time.time() - start_time

      if self.verbose:
          print('Kernel matrix of size (%d, %d) built in %s seconds.'
                % (len(Y), len(self._graphs), self._run_time))

      return kernel_matrix
  194. def _compute_kernel_matrix_series(self, Y):
  195. """Compute the kernel matrix between a given target graphs (Y) and
  196. the fitted graphs (X / self._graphs) without parallelization.
  197. Parameters
  198. ----------
  199. Y : list of graphs, optional
  200. The target graphs.
  201. Returns
  202. -------
  203. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  204. The computed kernel matrix.
  205. """
  206. kernel_matrix = np.zeros((len(Y), len(self._graphs)))
  207. for i_y, g_y in enumerate(Y):
  208. for i_x, g_x in enumerate(self._graphs):
  209. kernel_matrix[i_y, i_x] = self.pairwise_kernel(g_y, g_x)
  210. return kernel_matrix
  211. def _compute_kernel_matrix_imap_unordered(self, Y):
  212. """Compute the kernel matrix between a given target graphs (Y) and
  213. the fitted graphs (X / self._graphs) using imap unordered parallelization.
  214. Parameters
  215. ----------
  216. Y : list of graphs, optional
  217. The target graphs.
  218. Returns
  219. -------
  220. kernel_matrix : numpy array, shape = [n_targets, n_inputs]
  221. The computed kernel matrix.
  222. """
  223. raise Exception('Parallelization for kernel matrix is not implemented.')
  def diagonals(self):
      """Compute the kernel matrix diagonals of the fit/transformed data.

      Returns
      -------
      X_diag : numpy array
          Kernel of every fitted graph with itself. Reused from `_X_diag`
          when that attribute already exists.
      Y_diag : numpy array
          Kernel of every transformed graph with itself. Only returned (as
          the second item of a tuple) when `_Y` exists; otherwise `X_diag`
          is returned alone.
      """
      # `fit` must have been called.
      check_is_fitted(self, ['_graphs'])

      # Compute the diagonal of X only if it is not available yet.
      try:
          check_is_fitted(self, ['_X_diag'])
      except NotFittedError:
          self._X_diag = np.empty(shape=(len(self._graphs),))
          for idx, graph in enumerate(self._graphs):
              self._X_diag[idx] = self.pairwise_kernel(graph, graph)  # @todo: parallel?

      try:
          # `_Y` exists only after `transform`; in that case the diagonal of
          # Y is recomputed and both diagonals are returned.
          check_is_fitted(self, ['_Y'])
          self._Y_diag = np.empty(shape=(len(self._Y),))
          for idx, graph in enumerate(self._Y):
              self._Y_diag[idx] = self.pairwise_kernel(graph, graph)  # @todo: parallel?
          return self._X_diag, self._Y_diag
      except NotFittedError:
          # No transform yet: return the diagonal of X alone.
          return self._X_diag
  255. # @abstractmethod
  def pairwise_kernel(self, x, y):
      """Compute the kernel between two graphs.

      This base implementation does no computation and always raises.

      Parameters
      ----------
      x, y : NetworkX Graph.
          Graphs between which the kernel is computed.

      Raises
      ------
      NotImplementedError
          Always.

      Returns
      -------
      kernel: float
          The computed kernel.
      """
      message = 'Pairwise kernel computation is not implemented!'
      raise NotImplementedError(message)
  271. ##########################################################################
  272. # The following is the 2nd paradigm to compute kernel matrix. It is
  273. # simplified and not compatible with `scikit-learn`.
  274. ##########################################################################
  def compute(self, *graphs, **kwargs):
      """Compute kernels for the given graphs (2nd, non-`scikit-learn` paradigm).

      Parameters
      ----------
      *graphs
          Either one list of graphs (Gram matrix of the list), two graphs
          (a single kernel), or one graph and one list of graphs in either
          order (kernels between the graph and every graph of the list).
          Every graph is copied with its `copy()` method before use.
      **kwargs
          `parallel` (default 'imap_unordered'), `n_jobs` (default: CPU
          count), `normalize` (default True) and `verbose` (default 2).
          They overwrite the attributes of the same name on the instance.

      Raises
      ------
      ValueError
          From `validate_parameters` if `parallel` is not supported.
      Exception
          If the positional arguments match none of the supported forms.

      Returns
      -------
      tuple
          The Gram matrix, kernel list or single kernel, followed by
          `self._run_time`.
      """
      self.parallel = kwargs.get('parallel', 'imap_unordered')
      self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
      self.normalize = kwargs.get('normalize', True)
      self.verbose = kwargs.get('verbose', 2)
      self.validate_parameters()

      if len(graphs) == 1:
          graph_list = graphs[0]
          if not isinstance(graph_list, list):
              raise Exception('Cannot detect graphs.')
          if len(graph_list) == 0:
              raise Exception('The graph list given is empty. No computation was performed.')

          self._graphs = [g.copy() for g in graph_list]  # @todo: might be very slow.
          self._gram_matrix = self._compute_gram_matrix()
          # Keep the raw matrix around before (optionally) normalizing.
          self._gram_matrix_unnorm = np.copy(self._gram_matrix)
          if self.normalize:
              self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
          return self._gram_matrix, self._run_time

      if len(graphs) == 2:
          first, second = graphs
          if self.is_graph(first) and self.is_graph(second):
              kernel = self._compute_single_kernel(first.copy(), second.copy())
              return kernel, self._run_time

          # One graph against a list of graphs, given in either order.
          if self.is_graph(first) and isinstance(second, list):
              single, many = first, second
          elif isinstance(first, list) and self.is_graph(second):
              single, many = second, first
          else:
              raise Exception('Cannot detect graphs.')

          g1 = single.copy()
          g_list = [g.copy() for g in many]
          kernel_list = self._compute_kernel_list(g1, g_list)
          return kernel_list, self._run_time

      # `_graphs` is not set by `__init__`, so read it defensively: a missing
      # attribute must produce the intended message, not an AttributeError.
      if len(graphs) == 0 and getattr(self, '_graphs', None) is None:
          raise Exception('Please add graphs before computing.')

      raise Exception('Cannot detect graphs.')
  def normalize_gm(self, gram_matrix):
      """Normalize a Gram matrix in place (deprecated).

      Emits a `DeprecationWarning` pointing to
      `gklearn.utils.normalize_gram_matrix`. Each entry [i, j] of the upper
      triangle is divided by sqrt(diag[i] * diag[j]), using the diagonal as it
      was on entry, and then mirrored to [j, i].

      Parameters
      ----------
      gram_matrix : numpy array
          The square matrix to normalize; it is modified in place.

      Returns
      -------
      gram_matrix : numpy array
          The same object that was passed in.
      """
      import warnings
      warnings.warn('gklearn.kernels.graph_kernel.normalize_gm will be deprecated, use gklearn.utils.normalize_gram_matrix instead', DeprecationWarning)

      # Copy the diagonal first: it is overwritten while normalizing.
      diag = gram_matrix.diagonal().copy()
      size = len(gram_matrix)
      for row in range(size):
          for col in range(row, size):
              normalized = gram_matrix[row][col] / np.sqrt(diag[row] * diag[col])
              gram_matrix[row][col] = normalized
              gram_matrix[col][row] = normalized
      return gram_matrix
  def compute_distance_matrix(self):
      """Compute pairwise distances from the stored Gram matrix.

      The distance between items i and j is
      sqrt(K[i, i] + K[j, j] - 2 * K[i, j]), computed from the upper
      triangle of `self._gram_matrix` and mirrored to the lower one.

      Raises
      ------
      Exception
          If no Gram matrix is available (`_gram_matrix` is missing or None).
      ValueError
          If a squared distance is -1e-10 or lower. Negative values closer
          to zero are treated as rounding noise and clamped to 0.

      Returns
      -------
      dis_mat : numpy array
          The symmetric distance matrix.
      dis_max : float
          The largest distance.
      dis_min : float
          The smallest non-zero distance, or 0.0 when every distance is zero
          (e.g. a single graph).
      dis_mean : float
          The mean over all entries of `dis_mat`.
      """
      # `_gram_matrix` is only set by `compute` or the property setter, so
      # it may not exist at all.
      gram_matrix = getattr(self, '_gram_matrix', None)
      if gram_matrix is None:
          raise Exception('Please compute the Gram matrix before computing distance matrix.')

      size = len(gram_matrix)
      dis_mat = np.empty((size, size))
      for i in range(size):
          for j in range(i, size):
              dis = gram_matrix[i, i] + gram_matrix[j, j] - 2 * gram_matrix[i, j]
              if dis < 0:
                  if dis <= -1e-10:
                      raise ValueError('The distance is negative.')
                  dis = 0
              dis_mat[i, j] = np.sqrt(dis)
              dis_mat[j, i] = dis_mat[i, j]

      dis_max = np.max(dis_mat)
      # The diagonal is always zero, so the minimum is taken over non-zero
      # entries only; that selection is empty when all distances are zero.
      nonzero = dis_mat[dis_mat != 0]
      dis_min = np.min(nonzero) if nonzero.size > 0 else 0.0
      dis_mean = np.mean(dis_mat)

      return dis_mat, dis_max, dis_min, dis_mean
  340. def _compute_gram_matrix(self):
  341. start_time = time.time()
  342. if self.parallel == 'imap_unordered':
  343. gram_matrix = self._compute_gm_imap_unordered()
  344. elif self.parallel is None:
  345. gram_matrix = self._compute_gm_series()
  346. else:
  347. raise Exception('Parallel mode is not set correctly.')
  348. self._run_time = time.time() - start_time
  349. if self.verbose:
  350. print('Gram matrix of size %d built in %s seconds.'
  351. % (len(self._graphs), self._run_time))
  352. return gram_matrix
  353. def _compute_gm_series(self):
  354. pass
  355. def _compute_gm_imap_unordered(self):
  356. pass
  357. def _compute_kernel_list(self, g1, g_list):
  358. start_time = time.time()
  359. if self.parallel == 'imap_unordered':
  360. kernel_list = self._compute_kernel_list_imap_unordered(g1, g_list)
  361. elif self.parallel is None:
  362. kernel_list = self._compute_kernel_list_series(g1, g_list)
  363. else:
  364. raise Exception('Parallel mode is not set correctly.')
  365. self._run_time = time.time() - start_time
  366. if self.verbose:
  367. print('Graph kernel bewteen a graph and a list of %d graphs built in %s seconds.'
  368. % (len(g_list), self._run_time))
  369. return kernel_list
  370. def _compute_kernel_list_series(self, g1, g_list):
  371. pass
  372. def _compute_kernel_list_imap_unordered(self, g1, g_list):
  373. pass
  374. def _compute_single_kernel(self, g1, g2):
  375. start_time = time.time()
  376. kernel = self._compute_single_kernel_series(g1, g2)
  377. self._run_time = time.time() - start_time
  378. if self.verbose:
  379. print('Graph kernel bewteen two graphs built in %s seconds.' % (self._run_time))
  380. return kernel
  381. def _compute_single_kernel_series(self, g1, g2):
  382. pass
  def is_graph(self, graph):
      """Tell whether an object is a `NetworkX` graph.

      Parameters
      ----------
      graph : object
          The object to test.

      Returns
      -------
      bool
          True if `graph` is an instance of `nx.Graph`, `nx.DiGraph`,
          `nx.MultiGraph` or `nx.MultiDiGraph`; False otherwise.
      """
      graph_types = (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)
      return isinstance(graph, graph_types)
  @property
  def graphs(self):
      """The graphs currently stored on the instance (`self._graphs`)."""
      stored_graphs = self._graphs
      return stored_graphs
  396. # @property
  397. # def parallel(self):
  398. # return self.parallel
  399. # @property
  400. # def n_jobs(self):
  401. # return self.n_jobs
  402. # @property
  403. # def verbose(self):
  404. # return self.verbose
  405. # @property
  406. # def normalize(self):
  407. # return self.normalize
  @property
  def run_time(self):
      """The duration of the last timed computation (`self._run_time`)."""
      elapsed = self._run_time
      return elapsed
  @property
  def gram_matrix(self):
      """The stored Gram matrix (`self._gram_matrix`)."""
      matrix = self._gram_matrix
      return matrix

  @gram_matrix.setter
  def gram_matrix(self, value):
      # Plain assignment: the value is stored without any validation.
      self._gram_matrix = value
  @property
  def gram_matrix_unnorm(self):
      """The stored unnormalized Gram matrix (`self._gram_matrix_unnorm`)."""
      matrix = self._gram_matrix_unnorm
      return matrix

  @gram_matrix_unnorm.setter
  def gram_matrix_unnorm(self, value):
      # Plain assignment: the value is stored without any validation.
      self._gram_matrix_unnorm = value

A Python package for graph kernels, graph edit distances and graph pre-image problem.