From 23a938482abdf4412d94c3f4ea560d8083704906 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Tue, 25 May 2021 14:19:37 +0200 Subject: [PATCH] =?UTF-8?q?[[Enhancement]=20gklearn.utils.normalize=5Fgram?= =?UTF-8?q?=5Fmatrix=20function=20now=20raises=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gklearn/kernels/graph_kernel.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/gklearn/kernels/graph_kernel.py b/gklearn/kernels/graph_kernel.py index 90a0906..6d9517f 100644 --- a/gklearn/kernels/graph_kernel.py +++ b/gklearn/kernels/graph_kernel.py @@ -124,7 +124,13 @@ class GraphKernel(BaseEstimator): #, ABC): self._is_transformed = True if self.normalize: X_diag, Y_diag = self.diagonals() - kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) + old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. + try: + kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) + except: + raise + finally: + np.seterr(**old_settings) return kernel_matrix @@ -150,9 +156,15 @@ class GraphKernel(BaseEstimator): #, ABC): gram_matrix = self.compute_kernel_matrix() # Normalize. - self._X_diag = np.diagonal(gram_matrix).copy() if self.normalize: - gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) + self._X_diag = np.diagonal(gram_matrix).copy() + old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. + try: + gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) + except: + raise + finally: + np.seterr(**old_settings) return gram_matrix