diff --git a/gklearn/kernels/graph_kernel.py b/gklearn/kernels/graph_kernel.py index 90a0906..6d9517f 100644 --- a/gklearn/kernels/graph_kernel.py +++ b/gklearn/kernels/graph_kernel.py @@ -124,7 +124,13 @@ class GraphKernel(BaseEstimator): #, ABC): self._is_transformed = True if self.normalize: X_diag, Y_diag = self.diagonals() - kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) + old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. + try: + kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) + except: + raise + finally: + np.seterr(**old_settings) return kernel_matrix @@ -150,9 +156,15 @@ class GraphKernel(BaseEstimator): #, ABC): gram_matrix = self.compute_kernel_matrix() # Normalize. - self._X_diag = np.diagonal(gram_matrix).copy() if self.normalize: - gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) + self._X_diag = np.diagonal(gram_matrix).copy() + old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. + try: + gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) + except: + raise + finally: + np.seterr(**old_settings) return gram_matrix