From 609c8c15183a9fd752cc2ccd36680520e4b22b52 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Wed, 9 Jun 2021 17:16:51 +0200 Subject: [PATCH] [Enhancement] Allow deciding whether or not to make a copy of input graphs in GraphKernel class. --- gklearn/kernels/graph_kernel.py | 43 ++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/gklearn/kernels/graph_kernel.py b/gklearn/kernels/graph_kernel.py index 6d9517f..1db38b3 100644 --- a/gklearn/kernels/graph_kernel.py +++ b/gklearn/kernels/graph_kernel.py @@ -77,8 +77,6 @@ class GraphKernel(BaseEstimator): #, ABC): # Clear any prior attributes stored on the estimator, # @todo: unless warm_start is used; self.clear_attributes() -# X = check_array(X, accept_sparse=True) - # Validate parameters for the transformer. self.validate_parameters() @@ -386,35 +384,58 @@ class GraphKernel(BaseEstimator): #, ABC): self.n_jobs = kwargs.get('n_jobs', multiprocessing.cpu_count()) self.normalize = kwargs.get('normalize', True) self.verbose = kwargs.get('verbose', 2) + self.copy_graphs = kwargs.get('copy_graphs', True) + self.save_unnormed = kwargs.get('save_unnormed', True) self.validate_parameters() + # If the input is a list of graphs. if len(graphs) == 1: if not isinstance(graphs[0], list): raise Exception('Cannot detect graphs.') elif len(graphs[0]) == 0: raise Exception('The graph list given is empty. No computation was performed.') else: - self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow. + if self.copy_graphs: + self._graphs = [g.copy() for g in graphs[0]] # @todo: might be very slow. 
+ else: + self._graphs = graphs[0] self._gram_matrix = self._compute_gram_matrix() - self._gram_matrix_unnorm = np.copy(self._gram_matrix) + + if self.save_unnormed: + self._gram_matrix_unnorm = np.copy(self._gram_matrix) if self.normalize: self._gram_matrix = normalize_gram_matrix(self._gram_matrix) return self._gram_matrix, self._run_time elif len(graphs) == 2: + # If the inputs are two graphs. if self.is_graph(graphs[0]) and self.is_graph(graphs[1]): - kernel = self._compute_single_kernel(graphs[0].copy(), graphs[1].copy()) + if self.copy_graphs: + G0, G1 = graphs[0].copy(), graphs[1].copy() + else: + G0, G1 = graphs[0], graphs[1] + kernel = self._compute_single_kernel(G0, G1) return kernel, self._run_time + + # If the inputs are a graph and a list of graphs. elif self.is_graph(graphs[0]) and isinstance(graphs[1], list): - g1 = graphs[0].copy() - g_list = [g.copy() for g in graphs[1]] - kernel_list = self._compute_kernel_list(g1, g_list) + if self.copy_graphs: + g1 = graphs[0].copy() + g_list = [g.copy() for g in graphs[1]] + kernel_list = self._compute_kernel_list(g1, g_list) + else: + kernel_list = self._compute_kernel_list(graphs[0], graphs[1]) return kernel_list, self._run_time + elif isinstance(graphs[0], list) and self.is_graph(graphs[1]): - g1 = graphs[1].copy() - g_list = [g.copy() for g in graphs[0]] - kernel_list = self._compute_kernel_list(g1, g_list) + if self.copy_graphs: + g1 = graphs[1].copy() + g_list = [g.copy() for g in graphs[0]] + kernel_list = self._compute_kernel_list(g1, g_list) + else: + kernel_list = self._compute_kernel_list(graphs[1], graphs[0]) return kernel_list, self._run_time + else: raise Exception('Cannot detect graphs.')