diff --git a/gklearn/kernels/common_walk.py b/gklearn/kernels/common_walk.py index f6ee71d..505dc18 100644 --- a/gklearn/kernels/common_walk.py +++ b/gklearn/kernels/common_walk.py @@ -75,9 +75,9 @@ class CommonWalk(GraphKernel): # compute Gram matrix. gram_matrix = np.zeros((len(self._graphs), len(self._graphs))) - def init_worker(gn_toshare): - global G_gn - G_gn = gn_toshare +# def init_worker(gn_toshare): +# global G_gn +# G_gn = gn_toshare # direct product graph method - exponential if self.__compute_method == 'exp': @@ -86,12 +86,17 @@ class CommonWalk(GraphKernel): elif self.__compute_method == 'geo': do_fun = self._wrapper_kernel_do_geo - parallel_gm(do_fun, gram_matrix, self._graphs, init_worker=init_worker, - glbv=(self._graphs,), n_jobs=self._n_jobs, verbose=self._verbose) + parallel_gm(do_fun, gram_matrix, self._graphs, init_worker=self._init_worker_gm, + glbv=(self._graphs,), n_jobs=self._n_jobs, verbose=self._verbose) return gram_matrix + def _init_worker_gm(gn_toshare): + global G_gn + G_gn = gn_toshare + + def _compute_kernel_list_series(self, g1, g_list): self.__check_graphs(g_list + [g1]) self.__add_dummy_labels(g_list + [g1]) @@ -130,10 +135,10 @@ class CommonWalk(GraphKernel): # compute kernel list. kernel_list = [None] * len(g_list) - def init_worker(g1_toshare, g_list_toshare): - global G_g1, G_g_list - G_g1 = g1_toshare - G_g_list = g_list_toshare +# def init_worker(g1_toshare, g_list_toshare): +# global G_g1, G_g_list +# G_g1 = g1_toshare +# G_g_list = g_list_toshare # direct product graph method - exponential if self.__compute_method == 'exp': @@ -147,12 +152,19 @@ class CommonWalk(GraphKernel): itr = range(len(g_list)) len_itr = len(g_list) parallel_me(do_fun, func_assign, kernel_list, itr, len_itr=len_itr, - init_worker=init_worker, glbv=(g1, g_list), method='imap_unordered', + init_worker=self._init_worker_list, glbv=(g1, g_list), method='imap_unordered', n_jobs=self._n_jobs, itr_desc='calculating kernels', verbose=self._verbose) return kernel_list + def _init_worker_list(g1_toshare, g_list_toshare): + global G_g1, G_g_list + G_g1 = g1_toshare + G_g_list = g_list_toshare + + + def _wrapper_kernel_list_do_exp(self, itr): return itr, self.__kernel_do_exp(G_g1, G_g_list[itr], self.__weight)