Browse Source

[[Enhancement] gklearn.utils.normalize_gram_matrix function now raises…

v0.2.x
jajupmochi 4 years ago
parent
commit
23a938482a
1 changed files with 15 additions and 3 deletions
  1. +15
    -3
      gklearn/kernels/graph_kernel.py

+ 15
- 3
gklearn/kernels/graph_kernel.py View File

@@ -124,7 +124,13 @@ class GraphKernel(BaseEstimator): #, ABC):
self._is_transformed = True self._is_transformed = True
if self.normalize: if self.normalize:
X_diag, Y_diag = self.diagonals() X_diag, Y_diag = self.diagonals()
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
try:
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag))
except:
raise
finally:
np.seterr(**old_settings)


return kernel_matrix return kernel_matrix


@@ -150,9 +156,15 @@ class GraphKernel(BaseEstimator): #, ABC):
gram_matrix = self.compute_kernel_matrix() gram_matrix = self.compute_kernel_matrix()


# Normalize. # Normalize.
self._X_diag = np.diagonal(gram_matrix).copy()
if self.normalize: if self.normalize:
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
self._X_diag = np.diagonal(gram_matrix).copy()
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt.
try:
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag))
except:
raise
finally:
np.seterr(**old_settings)


return gram_matrix return gram_matrix




Loading…
Cancel
Save