diff --git a/lang/zh/gklearn/kernels/treeletKernel.py b/lang/zh/gklearn/kernels/treeletKernel.py
new file mode 100644
index 0000000..809a623
--- /dev/null
+++ b/lang/zh/gklearn/kernels/treeletKernel.py
@@ -0,0 +1,490 @@
+"""
+@author: linlin
+
+@references:
+
+    [1] Gaüzère B, Brun L, Villemin D. Two new graphs kernels in
+    chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
+"""
+
+import sys
+import time
+from collections import Counter
+from itertools import chain
+from functools import partial
+from multiprocessing import Pool, cpu_count
+from tqdm import tqdm
+
+import networkx as nx
+import numpy as np
+
+from gklearn.utils.graphdataset import get_dataset_attributes
+from gklearn.utils.parallel import parallel_gm
+
+
+def treeletkernel(*args,
+                  sub_kernel,
+                  node_label='atom',
+                  edge_label='bond_type',
+                  parallel='imap_unordered',
+                  n_jobs=None,
+                  chunksize=None,
+                  verbose=True):
+    """Calculate treelet graph kernels between graphs.
+
+    Parameters
+    ----------
+    Gn : List of NetworkX graph
+        List of graphs between which the kernels are calculated.
+
+    G1, G2 : NetworkX graphs
+        Two graphs between which the kernel is calculated.
+
+    sub_kernel : function
+        The sub-kernel between 2 real number vectors. Each vector counts the
+        numbers of isomorphic treelets in a graph.
+
+    node_label : string
+        Node attribute used as label. The default node label is 'atom'.
+
+    edge_label : string
+        Edge attribute used as label. The default edge label is 'bond_type'.
+
+    parallel : string/None
+        Which parallelization method is applied to compute the kernel. The
+        following choices are available:
+
+        'imap_unordered': use Python's multiprocessing.Pool.imap_unordered
+        method.
+
+        None: no parallelization is applied.
+
+    n_jobs : int
+        Number of jobs for parallelization. The default is to use all
+        computational cores. This argument is only valid when a
+        parallelization method is applied.
+
+    chunksize : int
+        Number of graphs sent to a worker at a time. If None, a suitable
+        value is chosen automatically.
+
+    verbose : boolean
+        Whether progress and timing information is printed.
+
+    Return
+    ------
+    Kmatrix : Numpy matrix
+        Kernel matrix, each element of which is the treelet kernel between
+        2 graphs.
+    run_time : float
+        Time spent building the kernel matrix, in seconds.
+    """
+    # pre-process
+    Gn = args[0] if len(args) == 1 else [args[0], args[1]]
+    Gn = [g.copy() for g in Gn]
+    Kmatrix = np.zeros((len(Gn), len(Gn)))
+    ds_attrs = get_dataset_attributes(
+        Gn, attr_names=['node_labeled', 'edge_labeled', 'is_directed'],
+        node_label=node_label, edge_label=edge_label)
+    labeled = False
+    if ds_attrs['node_labeled'] or ds_attrs['edge_labeled']:
+        labeled = True
+        if not ds_attrs['node_labeled']:
+            for G in Gn:
+                nx.set_node_attributes(G, '0', 'atom')
+        if not ds_attrs['edge_labeled']:
+            for G in Gn:
+                nx.set_edge_attributes(G, '0', 'bond_type')
+
+    start_time = time.time()
+
+    # ---- use pool.imap_unordered to parallelize and track progress. ----
+    if parallel == 'imap_unordered':
+        # get all canonical keys of all graphs before calculating kernels to
+        # save time, but this may cost a lot of memory for large datasets.
+        if n_jobs is None:
+            n_jobs = cpu_count()  # needed below to choose a chunksize.
+        pool = Pool(n_jobs)
+        itr = zip(Gn, range(0, len(Gn)))
+        if chunksize is None:
+            if len(Gn) < 100 * n_jobs:
+                chunksize = int(len(Gn) / n_jobs) + 1
+            else:
+                chunksize = 100
+        canonkeys = [[] for _ in range(len(Gn))]
+        get_partial = partial(wrapper_get_canonkeys, node_label, edge_label,
+                              labeled, ds_attrs['is_directed'])
+        if verbose:
+            iterator = tqdm(pool.imap_unordered(get_partial, itr, chunksize),
+                            desc='getting canonkeys', file=sys.stdout)
+        else:
+            iterator = pool.imap_unordered(get_partial, itr, chunksize)
+        for i, ck in iterator:
+            canonkeys[i] = ck
+        pool.close()
+        pool.join()
+
+        # compute kernels.
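+        # canonkeys[i] now maps each canonical key of graph i to its number
+        # of occurrences. The dicts are handed to the workers once through
+        # the pool initializer below (as a process-global), instead of being
+        # pickled along with every pair of graph indices.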
+        def init_worker(canonkeys_toshare):
+            global G_canonkeys
+            G_canonkeys = canonkeys_toshare
+        do_partial = partial(wrapper_treeletkernel_do, sub_kernel)
+        parallel_gm(do_partial, Kmatrix, Gn, init_worker=init_worker,
+                    glbv=(canonkeys,), n_jobs=n_jobs, chunksize=chunksize,
+                    verbose=verbose)
+
+    # ---- do not use parallelization. ----
+    elif parallel is None:
+        # get all canonical keys of all graphs before calculating kernels to
+        # save time, but this may cost a lot of memory for large datasets.
+        canonkeys = []
+        for g in (tqdm(Gn, desc='getting canonkeys', file=sys.stdout)
+                  if verbose else Gn):
+            canonkeys.append(get_canonkeys(g, node_label, edge_label, labeled,
+                                           ds_attrs['is_directed']))
+
+        # compute kernels.
+        from itertools import combinations_with_replacement
+        itr = combinations_with_replacement(range(0, len(Gn)), 2)
+        for i, j in (tqdm(itr, desc='calculating kernels', file=sys.stdout)
+                     if verbose else itr):
+            Kmatrix[i][j] = _treeletkernel_do(canonkeys[i], canonkeys[j], sub_kernel)
+            Kmatrix[j][i] = Kmatrix[i][j]  # @todo: no directed graph considered?
+
+    else:
+        raise Exception('No proper parallelization method designated.')
+
+    run_time = time.time() - start_time
+    if verbose:
+        print("\n --- treelet kernel matrix of size %d built in %s seconds ---"
+              % (len(Gn), run_time))
+
+    return Kmatrix, run_time
+
+
+def _treeletkernel_do(canonkey1, canonkey2, sub_kernel):
+    """Calculate the treelet graph kernel between 2 graphs.
+
+    Parameters
+    ----------
+    canonkey1, canonkey2 : dict
+        Dictionaries mapping the canonical keys of the treelets in the 2
+        graphs to their numbers of occurrences.
+
+    Return
+    ------
+    kernel : float
+        Treelet kernel between 2 graphs.
+    """
+    # find canonical keys present in both graphs.
+    keys = set(canonkey1.keys()) & set(canonkey2.keys())
+    vector1 = np.array([canonkey1[key] for key in keys])
+    vector2 = np.array([canonkey2[key] for key in keys])
+    kernel = sub_kernel(vector1, vector2)
+    return kernel
+
+
+def wrapper_treeletkernel_do(sub_kernel, itr):
+    i = itr[0]
+    j = itr[1]
+    return i, j, _treeletkernel_do(G_canonkeys[i], G_canonkeys[j], sub_kernel)
+
+
+def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
+    """Generate canonical keys of all treelets in a graph.
+
+    Parameters
+    ----------
+    G : NetworkX graphs
+        The graph in which keys are generated.
+    node_label : string
+        Node attribute used as label. The default node label is 'atom'.
+    edge_label : string
+        Edge attribute used as label. The default edge label is 'bond_type'.
+    labeled : boolean
+        Whether the graphs are labeled.
+    is_directed : boolean
+        Whether the graph is directed.
+
+    Return
+    ------
+    canonkey/canonkey_l : dict
+        For unlabeled graphs, canonkey is a dictionary which records the
+        number of occurrences of every tree pattern. For labeled graphs,
+        canonkey_l keeps track of the number of occurrences of every treelet.
+    """
+    patterns = {}  # a dictionary of lists of patterns, one list per treelet.
+    canonkey = {}  # canonical key, a dictionary which records the number of occurrences of every tree pattern.
+
+    ### structural analysis ###
+    ### In this section, a list of patterns is generated for each treelet,
+    ### where every pattern is represented by its nodes ordered by Morgan's
+    ### extended labeling.
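+    # Canonical key indices follow Table 1 of Ref [1] and are written in
+    # hexadecimal: '0'-'5' are the paths of 1 to 6 nodes, '6', '8' and 'd'
+    # the 3-, 4- and 5-star patterns, and '7', '9', 'a', 'b' and 'c' the
+    # remaining branched patterns, 14 treelets in total.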
+    # linear patterns
+    patterns['0'] = G.nodes()
+    canonkey['0'] = nx.number_of_nodes(G)
+    for i in range(1, 6):  # paths of length 1 to 5, i.e. of 2 to 6 nodes.
+        patterns[str(i)] = find_all_paths(G, i, is_directed)
+        canonkey[str(i)] = len(patterns[str(i)])
+
+    # n-star patterns
+    patterns['3star'] = [[node] + [neighbor for neighbor in G[node]]
+                         for node in G.nodes() if G.degree(node) == 3]
+    patterns['4star'] = [[node] + [neighbor for neighbor in G[node]]
+                         for node in G.nodes() if G.degree(node) == 4]
+    patterns['5star'] = [[node] + [neighbor for neighbor in G[node]]
+                         for node in G.nodes() if G.degree(node) == 5]
+    # n-star pattern counts
+    canonkey['6'] = len(patterns['3star'])
+    canonkey['8'] = len(patterns['4star'])
+    canonkey['d'] = len(patterns['5star'])
+
+    # pattern 7
+    patterns['7'] = []  # the 1st line of Table 1 in Ref [1]
+    for pattern in patterns['3star']:
+        for i in range(1, len(pattern)):  # for each neighbor of node 0
+            if G.degree(pattern[i]) >= 2:
+                pattern_t = pattern[:]
+                # set the node with degree >= 2 as the 4th node
+                pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
+                for neighborx in G[pattern[i]]:
+                    if neighborx != pattern[0]:
+                        new_pattern = pattern_t + [neighborx]
+                        patterns['7'].append(new_pattern)
+    canonkey['7'] = len(patterns['7'])
+
+    # pattern 11
+    patterns['11'] = []  # the 4th line of Table 1 in Ref [1]
+    for pattern in patterns['4star']:
+        for i in range(1, len(pattern)):
+            if G.degree(pattern[i]) >= 2:
+                pattern_t = pattern[:]
+                pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
+                for neighborx in G[pattern[i]]:
+                    if neighborx != pattern[0]:
+                        new_pattern = pattern_t + [neighborx]
+                        patterns['11'].append(new_pattern)
+    canonkey['b'] = len(patterns['11'])
+
+    # pattern 12
+    patterns['12'] = []  # the 5th line of Table 1 in Ref [1]
+    rootlist = []  # a list of root nodes whose extended labels are 3
+    for pattern in patterns['3star']:
+        # prevent counting the same pattern twice, once from each of its two
+        # root nodes
+        if pattern[0] not in rootlist:
+            rootlist.append(pattern[0])
+            for i in range(1, len(pattern)):
+                if G.degree(pattern[i]) >= 3:
+                    rootlist.append(pattern[i])
+                    pattern_t = pattern[:]
+                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
+                    for neighborx1 in G[pattern[i]]:
+                        if neighborx1 != pattern[0]:
+                            for neighborx2 in G[pattern[i]]:
+                                if neighborx1 > neighborx2 and neighborx2 != pattern[0]:
+                                    new_pattern = pattern_t + [neighborx1] + [neighborx2]
+                                    patterns['12'].append(new_pattern)
+    canonkey['c'] = int(len(patterns['12']) / 2)
+
+    # pattern 9
+    patterns['9'] = []  # the 2nd line of Table 1 in Ref [1]
+    for pattern in patterns['3star']:
+        for pairs in [[neighbor1, neighbor2]
+                      for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2
+                      for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2
+                      if neighbor1 > neighbor2]:
+            pattern_t = pattern[:]
+            # move nodes with extended label 4 to specific positions to
+            # correspond to their children
+            pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
+            pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
+            for neighborx1 in G[pairs[0]]:
+                if neighborx1 != pattern[0]:
+                    for neighborx2 in G[pairs[1]]:
+                        if neighborx2 != pattern[0]:
+                            new_pattern = pattern_t + [neighborx1] + [neighborx2]
+                            patterns['9'].append(new_pattern)
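+    # unlike pattern 12 above, each pattern-9 occurrence is generated only
+    # once, since the two extended neighbors are taken as an ordered pair
+    # (neighbor1 > neighbor2); hence the count below is not halved.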
+    canonkey['9'] = len(patterns['9'])
+
+    # pattern 10
+    patterns['10'] = []  # the 3rd line of Table 1 in Ref [1]
+    for pattern in patterns['3star']:
+        for i in range(1, len(pattern)):
+            if G.degree(pattern[i]) >= 2:
+                for neighborx in G[pattern[i]]:
+                    if neighborx != pattern[0] and G.degree(neighborx) >= 2:
+                        pattern_t = pattern[:]
+                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
+                        new_patterns = [pattern_t + [neighborx] + [neighborxx]
+                                        for neighborxx in G[neighborx]
+                                        if neighborxx != pattern[i]]
+                        patterns['10'].extend(new_patterns)
+    canonkey['a'] = len(patterns['10'])
+
+    ### labeling information ###
+    ### In this section, a canonical key is generated for every pattern
+    ### obtained in the structural analysis section above; each key is a
+    ### tuple of labels corresponding to a unique treelet. A dictionary is
+    ### built to keep track of the number of occurrences of every treelet.
+    if labeled:
+        canonkey_l = {}  # canonical key, a dictionary which keeps track of the number of occurrences of every treelet.
+
+        # linear patterns
+        canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))
+        for key in canonkey_t:
+            canonkey_l[('0', key)] = canonkey_t[key]
+
+        for i in range(1, 6):  # paths of 2 to 6 nodes
+            treelet = []
+            for pattern in patterns[str(i)]:
+                canonlist = list(chain.from_iterable(
+                    (G.nodes[node][node_label],
+                     G[node][pattern[idx+1]][edge_label])
+                    for idx, node in enumerate(pattern[:-1])))
+                canonlist.append(G.nodes[pattern[-1]][node_label])
+                # a path reads the same from both extremities; keep the
+                # lexicographically smaller direction.
+                canonkey_t = canonlist if canonlist < canonlist[::-1] else canonlist[::-1]
+                treelet.append(tuple([str(i)] + canonkey_t))
+            canonkey_l.update(Counter(treelet))
+
+        # n-star patterns
+        for i in range(3, 6):
+            treelet = []
+            for pattern in patterns[str(i) + 'star']:
+                canonlist = [tuple((G.nodes[leaf][node_label],
+                                    G[leaf][pattern[0]][edge_label]))
+                             for leaf in pattern[1:]]
+                canonlist.sort()
+                canonlist = list(chain.from_iterable(canonlist))
+                canonkey_t = tuple(['d' if i == 5 else str(i * 2)]
+                                   + [G.nodes[pattern[0]][node_label]]
+                                   + canonlist)
+                treelet.append(canonkey_t)
+            canonkey_l.update(Counter(treelet))
+
+        # pattern 7
+        treelet = []
+        for pattern in patterns['7']:
+            canonlist = [tuple((G.nodes[leaf][node_label],
+                                G[leaf][pattern[0]][edge_label]))
+                         for leaf in pattern[1:3]]
+            canonlist.sort()
+            canonlist = list(chain.from_iterable(canonlist))
+            canonkey_t = tuple(['7']
+                               + [G.nodes[pattern[0]][node_label]] + canonlist
+                               + [G.nodes[pattern[3]][node_label]]
+                               + [G[pattern[3]][pattern[0]][edge_label]]
+                               + [G.nodes[pattern[4]][node_label]]
+                               + [G[pattern[4]][pattern[3]][edge_label]])
+            treelet.append(canonkey_t)
+        canonkey_l.update(Counter(treelet))
+
+        # pattern 11
+        treelet = []
+        for pattern in patterns['11']:
+            canonlist = [tuple((G.nodes[leaf][node_label],
+                                G[leaf][pattern[0]][edge_label]))
+                         for leaf in pattern[1:4]]
+            canonlist.sort()
+            canonlist = list(chain.from_iterable(canonlist))
+            canonkey_t = tuple(['b']
+                               + [G.nodes[pattern[0]][node_label]] + canonlist
+                               + [G.nodes[pattern[4]][node_label]]
+                               + [G[pattern[4]][pattern[0]][edge_label]]
+                               + [G.nodes[pattern[5]][node_label]]
+                               + [G[pattern[5]][pattern[4]][edge_label]])
+            treelet.append(canonkey_t)
+        canonkey_l.update(Counter(treelet))
+
+        # pattern 10
+        treelet = []
+        for pattern in patterns['10']:
+            canonkey4 = [G.nodes[pattern[5]][node_label],
+                         G[pattern[5]][pattern[4]][edge_label]]
+            canonlist = [tuple((G.nodes[leaf][node_label],
+                                G[leaf][pattern[0]][edge_label]))
+                         for leaf in pattern[1:3]]
+            canonlist.sort()
+            canonkey0 = list(chain.from_iterable(canonlist))
+            canonkey_t = tuple(['a']
+                               + [G.nodes[pattern[3]][node_label]]
+                               + [G.nodes[pattern[4]][node_label]]
+                               + [G[pattern[4]][pattern[3]][edge_label]]
+                               + [G.nodes[pattern[0]][node_label]]
+                               + [G[pattern[0]][pattern[3]][edge_label]]
+                               + canonkey4 + canonkey0)
+            treelet.append(canonkey_t)
+        canonkey_l.update(Counter(treelet))
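+        # All keys above share one layout: the treelet index followed by
+        # alternating node and edge labels. E.g., with a toy labeling (an
+        # assumption for illustration), a C-O path of pattern '1' with bond
+        # label '1' yields the key ('1', 'C', '1', 'O'), the smaller of its
+        # two reading directions.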
+        # pattern 12
+        treelet = []
+        for pattern in patterns['12']:
+            canonlist0 = [tuple((G.nodes[leaf][node_label],
+                                 G[leaf][pattern[0]][edge_label]))
+                          for leaf in pattern[1:3]]
+            canonlist0.sort()
+            canonlist0 = list(chain.from_iterable(canonlist0))
+            canonlist3 = [tuple((G.nodes[leaf][node_label],
+                                 G[leaf][pattern[3]][edge_label]))
+                          for leaf in pattern[4:6]]
+            canonlist3.sort()
+            canonlist3 = list(chain.from_iterable(canonlist3))
+
+            # 2 possible keys can be generated from the 2 nodes with extended
+            # label 3; select the one with the lower lexicographic order.
+            canonkey_t1 = tuple(['c']
+                                + [G.nodes[pattern[0]][node_label]] + canonlist0
+                                + [G.nodes[pattern[3]][node_label]]
+                                + [G[pattern[3]][pattern[0]][edge_label]]
+                                + canonlist3)
+            canonkey_t2 = tuple(['c']
+                                + [G.nodes[pattern[3]][node_label]] + canonlist3
+                                + [G.nodes[pattern[0]][node_label]]
+                                + [G[pattern[0]][pattern[3]][edge_label]]
+                                + canonlist0)
+            treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
+        canonkey_l.update(Counter(treelet))
+
+        # pattern 9
+        treelet = []
+        for pattern in patterns['9']:
+            canonkey2 = [G.nodes[pattern[4]][node_label],
+                         G[pattern[4]][pattern[2]][edge_label]]
+            canonkey3 = [G.nodes[pattern[5]][node_label],
+                         G[pattern[5]][pattern[3]][edge_label]]
+            prekey2 = [G.nodes[pattern[2]][node_label],
+                       G[pattern[2]][pattern[0]][edge_label]]
+            prekey3 = [G.nodes[pattern[3]][node_label],
+                       G[pattern[3]][pattern[0]][edge_label]]
+            # order the two extended branches lexicographically so symmetric
+            # patterns share one key.
+            if prekey2 + canonkey2 < prekey3 + canonkey3:
+                canonkey_t = ([G.nodes[pattern[1]][node_label]]
+                              + [G[pattern[1]][pattern[0]][edge_label]]
+                              + prekey2 + prekey3 + canonkey2 + canonkey3)
+            else:
+                canonkey_t = ([G.nodes[pattern[1]][node_label]]
+                              + [G[pattern[1]][pattern[0]][edge_label]]
+                              + prekey3 + prekey2 + canonkey3 + canonkey2)
+            treelet.append(tuple(['9'] + [G.nodes[pattern[0]][node_label]] + canonkey_t))
+        canonkey_l.update(Counter(treelet))
+
+        return canonkey_l
+
+    return canonkey
+
+
+def wrapper_get_canonkeys(node_label, edge_label, labeled, is_directed, itr_item):
+    g = itr_item[0]
+    i = itr_item[1]
+    return i, get_canonkeys(g, node_label, edge_label, labeled, is_directed)
+
+
+def find_paths(G, source_node, length):
+    """Find all paths of a certain length that start from a source node.
+    A recursive depth-first search is applied.
+
+    Parameters
+    ----------
+    G : NetworkX graphs
+        The graph in which paths are searched.
+    source_node : integer
+        The node from which all paths start.
+    length : integer
+        The length of paths.
+
+    Return
+    ------
+    path : list of list
+        List of paths retrieved, where each path is represented by a list of nodes.
+    """
+    if length == 0:
+        return [[source_node]]
+    path = [[source_node] + path for neighbor in G[source_node]
+            for path in find_paths(G, neighbor, length - 1)
+            if source_node not in path]
+    return path
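+
+# Example (sketch): on the undirected path graph 0-1-2, find_paths(G, 0, 2)
+# returns [[0, 1, 2]]; ``length`` counts edges, so a path of length l has
+# l + 1 nodes.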
+
+
+def find_all_paths(G, length, is_directed):
+    """Find all paths of a certain length in a graph. A recursive depth-first
+    search is applied.
+
+    Parameters
+    ----------
+    G : NetworkX graphs
+        The graph in which paths are searched.
+    length : integer
+        The length of paths.
+    is_directed : boolean
+        Whether the graph is directed.
+
+    Return
+    ------
+    path : list of list
+        List of paths retrieved, where each path is represented by a list of nodes.
+    """
+    all_paths = []
+    for node in G:
+        all_paths.extend(find_paths(G, node, length))
+
+    if not is_directed:
+        # In an undirected graph, each path is found twice, once from each of
+        # its two extremities. Remove one of the two representations.
+        all_paths_r = [path[::-1] for path in all_paths]
+        for idx, path in enumerate(all_paths[:-1]):
+            for path2 in all_paths_r[idx+1::]:
+                if path == path2:
+                    all_paths[idx] = []
+                    break
+        all_paths = list(filter(lambda a: a != [], all_paths))
+
+    return all_paths
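+
+
+if __name__ == '__main__':
+    # Minimal usage sketch, not part of the original module: two toy labeled
+    # graphs and a simple Gaussian sub-kernel defined here for the demo (an
+    # illustrative choice; any kernel between count vectors works).
+    def gaussian_kernel(x, y, gamma=1.0):
+        # k(x, y) = exp(-gamma * ||x - y||^2) between treelet count vectors.
+        return np.exp(-gamma * np.sum((x - y) ** 2))
+
+    g1 = nx.Graph()
+    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'C'}),
+                       (2, {'atom': 'O'})])
+    g1.add_edges_from([(0, 1, {'bond_type': '1'}), (1, 2, {'bond_type': '1'})])
+    g2 = nx.Graph()
+    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
+    g2.add_edges_from([(0, 1, {'bond_type': '1'})])
+
+    Kmatrix, run_time = treeletkernel([g1, g2], sub_kernel=gaussian_kernel,
+                                      parallel=None, verbose=False)
+    print(Kmatrix)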