Browse Source

[[Enhancement] gklearn.utils.normalize_gram_matrix function now raises…

v0.2.x
jajupmochi 4 years ago
parent
commit
23a938482a
1 changed files with 15 additions and 3 deletions
  1. +15
    -3
      gklearn/kernels/graph_kernel.py

+ 15
- 3
gklearn/kernels/graph_kernel.py View File

@@ -124,7 +124,13 @@ class GraphKernel(BaseEstimator): #, ABC):
self._is_transformed = True
if self.normalize:
X_diag, Y_diag = self.diagonals()
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
try:
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
except:
raise
finally:
np.seterr(**old_settings)

return kernel_matrix

@@ -150,9 +156,15 @@ class GraphKernel(BaseEstimator): #, ABC):
gram_matrix = self.compute_kernel_matrix()

# Normalize.
self._X_diag = np.diagonal(gram_matrix).copy()
if self.normalize:
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
self._X_diag = np.diagonal(gram_matrix).copy()
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
try:
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
except:
raise
finally:
np.seterr(**old_settings)

return gram_matrix



Loading…
Cancel
Save