Browse Source

[Enhancement] Allow deciding whether or not to make a copy of input graphs in GraphKernel class.

v0.2.x
jajupmochi 4 years ago
parent
commit
609c8c1518
1 changed files with 32 additions and 11 deletions
  1. +32
    -11
      gklearn/kernels/graph_kernel.py

+ 32
- 11
gklearn/kernels/graph_kernel.py View File

@@ -77,8 +77,6 @@ class GraphKernel(BaseEstimator): #, ABC):
# Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used;
self.clear_attributes()

# X = check_array(X, accept_sparse=True)

# Validate parameters for the transformer.
self.validate_parameters()

@@ -386,35 +384,58 @@ class GraphKernel(BaseEstimator): #, ABC):
self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count())
self.normalize = kwargs.get('normalize', True)
self.verbose = kwargs.get('verbose', 2)
self.copy_graphs = kwargs.get('copy_graphs', True)
self.save_unnormed = kwargs.get('save_unnormed', True)
self.validate_parameters()

# If the inputs is a list of graphs.
if len(graphs) == 1:
if not isinstance(graphs[0], list):
raise Exception('Cannot detect graphs.')
elif len(graphs[0]) == 0:
raise Exception('The graph list given is empty. No computation was performed.')
else:
self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow.
if self.copy_graphs:
self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow.
else:
self._graphs = graphs
self._gram_matrix = self._compute_gram_matrix()
self._gram_matrix_unnorm = np.copy(self._gram_matrix)

if self.save_unnormed:
self._gram_matrix_unnorm = np.copy(self._gram_matrix)
if self.normalize:
self._gram_matrix = normalize_gram_matrix(self._gram_matrix)
return self._gram_matrix, self._run_time

elif len(graphs) == 2:
# If the inputs are two graphs.
if self.is_graph(graphs[0]) and self.is_graph(graphs[1]):
kernel = self._compute_single_kernel(graphs[0].copy(), graphs[1].copy())
if self.copy_graphs:
G0, G1 = graphs[0].copy(), graphs[1].copy()
else:
G0, G1 = graphs[0], graphs[1]
kernel = self._compute_single_kernel(G0, G1)
return kernel, self._run_time

# If the inputs are a graph and a list of graphs.
elif self.is_graph(graphs[0]) and isinstance(graphs[1], list):
g1 = graphs[0].copy()
g_list = [g.copy() for g in graphs[1]]
kernel_list = self._compute_kernel_list(g1, g_list)
if self.copy_graphs:
g1 = graphs[0].copy()
g_list = [g.copy() for g in graphs[1]]
kernel_list = self._compute_kernel_list(g1, g_list)
else:
kernel_list = self._compute_kernel_list(graphs[0], graphs[1])
return kernel_list, self._run_time

elif isinstance(graphs[0], list) and self.is_graph(graphs[1]):
g1 = graphs[1].copy()
g_list = [g.copy() for g in graphs[0]]
kernel_list = self._compute_kernel_list(g1, g_list)
if self.copy_graphs:
g1 = graphs[1].copy()
g_list = [g.copy() for g in graphs[0]]
kernel_list = self._compute_kernel_list(g1, g_list)
else:
kernel_list = self._compute_kernel_list(graphs[1], graphs[0])
return kernel_list, self._run_time

else:
raise Exception('Cannot detect graphs.')



Loading…
Cancel
Save