|
|
@@ -124,7 +124,13 @@ class GraphKernel(BaseEstimator): #, ABC): |
|
|
|
self._is_transformed = True |
|
|
|
if self.normalize: |
|
|
|
X_diag, Y_diag = self.diagonals() |
|
|
|
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) |
|
|
|
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. |
|
|
|
try: |
|
|
|
kernel_matrix /= np.sqrt(np.outer(Y_diag, X_diag)) |
|
|
|
except: |
|
|
|
raise |
|
|
|
finally: |
|
|
|
np.seterr(**old_settings) |
|
|
|
|
|
|
|
return kernel_matrix |
|
|
|
|
|
|
@@ -150,9 +156,15 @@ class GraphKernel(BaseEstimator): #, ABC): |
|
|
|
gram_matrix = self.compute_kernel_matrix() |
|
|
|
|
|
|
|
# Normalize. |
|
|
|
self._X_diag = np.diagonal(gram_matrix).copy() |
|
|
|
if self.normalize: |
|
|
|
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) |
|
|
|
self._X_diag = np.diagonal(gram_matrix).copy() |
|
|
|
old_settings = np.seterr(invalid='raise') # Catch FloatingPointError: invalid value encountered in sqrt. |
|
|
|
try: |
|
|
|
gram_matrix /= np.sqrt(np.outer(self._X_diag, self._X_diag)) |
|
|
|
except: |
|
|
|
raise |
|
|
|
finally: |
|
|
|
np.seterr(**old_settings) |
|
|
|
|
|
|
|
return gram_matrix |
|
|
|
|
|
|
|