Browse Source

Update: Check if all graphs have edge(s) in ShorestPath.

v0.2.x
jajupmochi 4 years ago
parent
commit
6e1372e8fa
1 changed files with 13 additions and 1 deletions
  1. +13
    -1
      gklearn/kernels/shortest_path.py

+ 13
- 1
gklearn/kernels/shortest_path.py View File

@@ -17,6 +17,7 @@ from itertools import product
from multiprocessing import Pool
from tqdm import tqdm
import numpy as np
import networkx as nx
from gklearn.utils.parallel import parallel_gm, parallel_me
from gklearn.utils.utils import getSPGraph
from gklearn.kernels import GraphKernel
@@ -35,6 +36,7 @@ class ShortestPath(GraphKernel):


def _compute_gm_series(self):
self._all_graphs_have_edges(self._graphs)
# get shortest path graph of each graph.
if self._verbose >= 2:
iterator = tqdm(self._graphs, desc='getting sp graphs', file=sys.stdout)
@@ -60,6 +62,7 @@ class ShortestPath(GraphKernel):


def _compute_gm_imap_unordered(self):
self._all_graphs_have_edges(self._graphs)
# get shortest path graph of each graph.
pool = Pool(self._n_jobs)
get_sp_graphs_fun = self._wrapper_get_sp_graphs
@@ -92,6 +95,7 @@ class ShortestPath(GraphKernel):


def _compute_kernel_list_series(self, g1, g_list):
self._all_graphs_have_edges([g1] + g_list)
# get shortest path graphs of g1 and each graph in g_list.
g1 = getSPGraph(g1, edge_weight=self._edge_weight)
if self._verbose >= 2:
@@ -114,6 +118,7 @@ class ShortestPath(GraphKernel):


def _compute_kernel_list_imap_unordered(self, g1, g_list):
self._all_graphs_have_edges([g1] + g_list)
# get shortest path graphs of g1 and each graph in g_list.
g1 = getSPGraph(g1, edge_weight=self._edge_weight)
pool = Pool(self._n_jobs)
@@ -156,6 +161,7 @@ class ShortestPath(GraphKernel):


def _compute_single_kernel_series(self, g1, g2):
self._all_graphs_have_edges([g1] + [g2])
g1 = getSPGraph(g1, edge_weight=self._edge_weight)
g2 = getSPGraph(g2, edge_weight=self._edge_weight)
kernel = self._sp_do(g1, g2)
@@ -327,4 +333,10 @@ class ShortestPath(GraphKernel):
def _wrapper_sp_do(self, itr):
i = itr[0]
j = itr[1]
return i, j, self._sp_do(G_gs[i], G_gs[j])
return i, j, self._sp_do(G_gs[i], G_gs[j])


def _all_graphs_have_edges(self, graphs):
for G in graphs:
if nx.number_of_edges(G) == 0:
raise ValueError('Not all graphs have edges!!!')

Loading…
Cancel
Save