#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Created on Mon Apr 13 18:02:46 2020

@author: ljia

@references:

    [1] Gaüzère B, Brun L, Villemin D. Two new graphs kernels in
    chemoinformatics. Pattern Recognition Letters. 2012 Nov 1;33(15):2038-47.
"""
import sys
from collections import Counter
from itertools import chain, combinations_with_replacement
from multiprocessing import Pool

import networkx as nx
import numpy as np

from gklearn.kernels import GraphKernel
from gklearn.utils import SpecialLabel, get_iters
from gklearn.utils.parallel import parallel_gm, parallel_me
from gklearn.utils.utils import find_all_paths, get_mlti_dim_node_attrs


class Treelet(GraphKernel):

    def __init__(self, **kwargs):
        GraphKernel.__init__(self)
        self._node_labels = kwargs.get('node_labels', [])
        self._edge_labels = kwargs.get('edge_labels', [])
        self._sub_kernel = kwargs.get('sub_kernel', None)
        self._ds_infos = kwargs.get('ds_infos', {})
        if self._sub_kernel is None:
            raise ValueError('Sub kernel not set.')

    def _compute_gm_series(self):
        self._add_dummy_labels(self._graphs)

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys = []
        iterator = get_iters(self._graphs, desc='getting canonkeys',
                             file=sys.stdout, verbose=(self._verbose >= 2))
        for g in iterator:
            canonkeys.append(self._get_canonkeys(g))

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))

        itr = combinations_with_replacement(range(0, len(self._graphs)), 2)
        len_itr = int(len(self._graphs) * (len(self._graphs) + 1) / 2)
        iterator = get_iters(itr, desc='Computing kernels', file=sys.stdout,
                             length=len_itr, verbose=(self._verbose >= 2))
        for i, j in iterator:
            kernel = self._kernel_do(canonkeys[i], canonkeys[j])
            gram_matrix[i][j] = kernel
            gram_matrix[j][i] = kernel  # @todo: no directed graph considered?

        return gram_matrix

    def _compute_gm_imap_unordered(self):
        self._add_dummy_labels(self._graphs)

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        pool = Pool(self._n_jobs)
        itr = zip(self._graphs, range(0, len(self._graphs)))
        if len(self._graphs) < 100 * self._n_jobs:
            chunksize = int(len(self._graphs) / self._n_jobs) + 1
        else:
            chunksize = 100
        canonkeys = [[] for _ in range(len(self._graphs))]
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(self._graphs),
                             verbose=(self._verbose >= 2))
        for i, ck in iterator:
            canonkeys[i] = ck
        pool.close()
        pool.join()

        # Compute the Gram matrix.
        gram_matrix = np.zeros((len(self._graphs), len(self._graphs)))

        def init_worker(canonkeys_toshare):
            global G_canonkeys
            G_canonkeys = canonkeys_toshare

        do_fun = self._wrapper_kernel_do
        parallel_gm(do_fun, gram_matrix, self._graphs,
                    init_worker=init_worker, glbv=(canonkeys,),
                    n_jobs=self._n_jobs, verbose=self._verbose)

        return gram_matrix

    def _compute_kernel_list_series(self, g1, g_list):
        self._add_dummy_labels(g_list + [g1])

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = []
        iterator = get_iters(g_list, desc='getting canonkeys',
                             file=sys.stdout, verbose=(self._verbose >= 2))
        for g in iterator:
            canonkeys_list.append(self._get_canonkeys(g))

        # Compute the kernel list.
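        # kernel_list[i] will hold the treelet kernel value between g1 and
        # g_list[i], in the original order of g_list.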
        kernel_list = [None] * len(g_list)
        iterator = get_iters(range(len(g_list)), desc='Computing kernels',
                             file=sys.stdout, length=len(g_list),
                             verbose=(self._verbose >= 2))
        for i in iterator:
            kernel = self._kernel_do(canonkeys_1, canonkeys_list[i])
            kernel_list[i] = kernel

        return kernel_list

    def _compute_kernel_list_imap_unordered(self, g1, g_list):
        self._add_dummy_labels(g_list + [g1])

        # Get all canonical keys of all graphs before computing kernels to
        # save time, but this may cost a lot of memory for large datasets.
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_list = [[] for _ in range(len(g_list))]
        pool = Pool(self._n_jobs)
        itr = zip(g_list, range(0, len(g_list)))
        if len(g_list) < 100 * self._n_jobs:
            chunksize = int(len(g_list) / self._n_jobs) + 1
        else:
            chunksize = 100
        get_fun = self._wrapper_get_canonkeys
        iterator = get_iters(pool.imap_unordered(get_fun, itr, chunksize),
                             desc='getting canonkeys', file=sys.stdout,
                             length=len(g_list), verbose=(self._verbose >= 2))
        for i, ck in iterator:
            canonkeys_list[i] = ck
        pool.close()
        pool.join()

        # Compute the kernel list.
        kernel_list = [None] * len(g_list)

        def init_worker(ck_1_toshare, ck_list_toshare):
            global G_ck_1, G_ck_list
            G_ck_1 = ck_1_toshare
            G_ck_list = ck_list_toshare

        do_fun = self._wrapper_kernel_list_do

        def func_assign(result, var_to_assign):
            var_to_assign[result[0]] = result[1]

        itr = range(len(g_list))
        len_itr = len(g_list)
        parallel_me(do_fun, func_assign, kernel_list, itr, len_itr=len_itr,
                    init_worker=init_worker,
                    glbv=(canonkeys_1, canonkeys_list),
                    method='imap_unordered', n_jobs=self._n_jobs,
                    itr_desc='Computing kernels', verbose=self._verbose)

        return kernel_list

    def _wrapper_kernel_list_do(self, itr):
        return itr, self._kernel_do(G_ck_1, G_ck_list[itr])

    def _compute_single_kernel_series(self, g1, g2):
        self._add_dummy_labels([g1] + [g2])
        canonkeys_1 = self._get_canonkeys(g1)
        canonkeys_2 = self._get_canonkeys(g2)
        kernel = self._kernel_do(canonkeys_1, canonkeys_2)
        return kernel

    def _kernel_do(self, canonkey1, canonkey2):
        """Compute the treelet graph kernel between two graphs.

        Parameters
        ----------
        canonkey1, canonkey2 : dict
            Dictionaries mapping canonical keys (each identifying a treelet)
            to the number of occurrences of that treelet in each graph.

        Return
        ------
        kernel : float
            Treelet kernel between the two graphs.
        """
        # Find the canonical keys shared by both graphs; the sub-kernel is
        # applied to the corresponding pair of count vectors.
        keys = set(canonkey1.keys()) & set(canonkey2.keys())
        vector1 = np.array([canonkey1[key] for key in keys])
        vector2 = np.array([canonkey2[key] for key in keys])
        kernel = self._sub_kernel(vector1, vector2)
        return kernel

    def _wrapper_kernel_do(self, itr):
        i = itr[0]
        j = itr[1]
        return i, j, self._kernel_do(G_canonkeys[i], G_canonkeys[j])

    def _get_canonkeys(self, G):
        """Generate canonical keys of all treelets in a graph.

        Parameters
        ----------
        G : NetworkX graph
            The graph in which keys are generated.

        Return
        ------
        canonkey/canonkey_l : dict
            For unlabeled graphs, canonkey is a dictionary which records the
            number of occurrences of every tree pattern. For labeled graphs,
            canonkey_l keeps track of the number of occurrences of every
            treelet.
        """
        patterns = {}  # a dictionary of lists of patterns for all graphlets.
        canonkey = {}  # canonical key, a dictionary which records the number of occurrences of every tree pattern.

        ### structural analysis ###
        ### In this section, a list of patterns is generated for each
        ### graphlet, where every pattern is represented by nodes ordered by
        ### Morgan's extended labeling.
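        # Illustrative sketch (not executed): for an unlabeled path graph
        # a-b-c, this section would yield
        #     canonkey == {'0': 3, '1': 2, '2': 1, '3': 0, '4': 0, '5': 0,
        #                  '6': 0, '7': 0, '8': 0, '9': 0, 'a': 0, 'b': 0,
        #                  'c': 0, 'd': 0}
        # i.e. three single nodes, two one-edge paths, one two-edge path, and
        # no longer paths or star-based treelets (assuming find_all_paths
        # counts each undirected path once).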
        # linear patterns
        patterns['0'] = list(G.nodes())
        canonkey['0'] = nx.number_of_nodes(G)

        for i in range(1, 6):  # paths of length 1 to 5
            patterns[str(i)] = find_all_paths(G, i, self._ds_infos['directed'])
            canonkey[str(i)] = len(patterns[str(i)])

        # n-star patterns
        patterns['3star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 3]
        patterns['4star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 4]
        patterns['5star'] = [[node] + [neighbor for neighbor in G[node]]
                             for node in G.nodes() if G.degree(node) == 5]
        # n-star counts
        canonkey['6'] = len(patterns['3star'])
        canonkey['8'] = len(patterns['4star'])
        canonkey['d'] = len(patterns['5star'])

        # pattern 7
        patterns['7'] = []  # the 1st line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for i in range(1, len(pattern)):  # for each neighbor of node 0
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    # set the node with degree >= 2 as the 4th node
                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['7'].append(new_pattern)
        canonkey['7'] = len(patterns['7'])

        # pattern 11
        patterns['11'] = []  # the 4th line of Table 1 in Ref [1]
        for pattern in patterns['4star']:
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    pattern_t = pattern[:]
                    pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0]:
                            new_pattern = pattern_t + [neighborx]
                            patterns['11'].append(new_pattern)
        canonkey['b'] = len(patterns['11'])

        # pattern 12
        patterns['12'] = []  # the 5th line of Table 1 in Ref [1]
        rootlist = []  # a list of root nodes, whose extended labels are 3
        for pattern in patterns['3star']:
            if pattern[0] not in rootlist:
                # prevent counting the same pattern twice, once from each of
                # the two root nodes
                rootlist.append(pattern[0])
                for i in range(1, len(pattern)):
                    if G.degree(pattern[i]) >= 3:
                        rootlist.append(pattern[i])
                        pattern_t = pattern[:]
                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                        for neighborx1 in G[pattern[i]]:
                            if neighborx1 != pattern[0]:
                                for neighborx2 in G[pattern[i]]:
                                    if (neighborx1 > neighborx2
                                            and neighborx2 != pattern[0]):
                                        new_pattern = (pattern_t + [neighborx1]
                                                       + [neighborx2])
                                        patterns['12'].append(new_pattern)
        canonkey['c'] = int(len(patterns['12']) / 2)

        # pattern 9
        patterns['9'] = []  # the 2nd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
            for pairs in [[neighbor1, neighbor2]
                          for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2
                          for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2
                          if neighbor1 > neighbor2]:
                pattern_t = pattern[:]
                # move nodes with extended labels 4 to specific positions to
                # correspond to their children
                pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]
                pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]
                for neighborx1 in G[pairs[0]]:
                    if neighborx1 != pattern[0]:
                        for neighborx2 in G[pairs[1]]:
                            if neighborx2 != pattern[0]:
                                new_pattern = (pattern_t + [neighborx1]
                                               + [neighborx2])
                                patterns['9'].append(new_pattern)
        canonkey['9'] = len(patterns['9'])

        # pattern 10
        patterns['10'] = []  # the 3rd line of Table 1 in Ref [1]
        for pattern in patterns['3star']:
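            # pattern 10 extends a 3-star by a two-edge branch hanging off
            # one of its leaves, giving a 6-node treelet.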
            for i in range(1, len(pattern)):
                if G.degree(pattern[i]) >= 2:
                    for neighborx in G[pattern[i]]:
                        if neighborx != pattern[0] and G.degree(neighborx) >= 2:
                            pattern_t = pattern[:]
                            pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]
                            new_patterns = [pattern_t + [neighborx] + [neighborxx]
                                            for neighborxx in G[neighborx]
                                            if neighborxx != pattern[i]]
                            patterns['10'].extend(new_patterns)
        canonkey['a'] = len(patterns['10'])

        ### labeling information ###
        ### In this section, a canonical key is generated for every pattern
        ### obtained in the structural analysis section above; each key is a
        ### tuple corresponding to a unique treelet. A dictionary is built to
        ### keep track of the number of occurrences of every treelet.
        if len(self._node_labels) > 0 or len(self._edge_labels) > 0:
            canonkey_l = {}  # canonical key, a dictionary which keeps track of the number of occurrences of every treelet.

            # linear patterns
            canonkey_t = Counter(get_mlti_dim_node_attrs(G, self._node_labels))
            for key in canonkey_t:
                canonkey_l[('0', key)] = canonkey_t[key]

            for i in range(1, 6):  # paths of length 1 to 5
                treelet = []
                for pattern in patterns[str(i)]:
                    canonlist = []
                    for idx, node in enumerate(pattern[:-1]):
                        canonlist.append(tuple(G.nodes[node][nl]
                                               for nl in self._node_labels))
                        canonlist.append(tuple(G[node][pattern[idx + 1]][el]
                                               for el in self._edge_labels))
                    canonlist.append(tuple(G.nodes[pattern[-1]][nl]
                                           for nl in self._node_labels))
                    canonkey_t = (canonlist if canonlist < canonlist[::-1]
                                  else canonlist[::-1])
                    treelet.append(tuple([str(i)] + canonkey_t))
                canonkey_l.update(Counter(treelet))

            # n-star patterns
            for i in range(3, 6):
                treelet = []
                for pattern in patterns[str(i) + 'star']:
                    canonlist = []
                    for leaf in pattern[1:]:
                        nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                        elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                        canonlist.append(tuple((nlabels, elabels)))
                    canonlist.sort()
                    canonlist = list(chain.from_iterable(canonlist))
                    canonkey_t = tuple(['d' if i == 5 else str(i * 2)]
                                       + [tuple(G.nodes[pattern[0]][nl]
                                                for nl in self._node_labels)]
                                       + canonlist)
                    treelet.append(canonkey_t)
                canonkey_l.update(Counter(treelet))

            # pattern 7
            treelet = []
            for pattern in patterns['7']:
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['7']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                   + canonlist
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self._edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 11
            treelet = []
            for pattern in patterns['11']:
                canonlist = []
                for leaf in pattern[1:4]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonlist = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['b']
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                   + canonlist
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[0]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[5]][pattern[4]][el]
                                            for el in self._edge_labels)])
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 10
            treelet = []
            for pattern in patterns['10']:
                canonkey4 = [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels),
                             tuple(G[pattern[5]][pattern[4]][el] for el in self._edge_labels)]
                canonlist = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist.append(tuple((nlabels, elabels)))
                canonlist.sort()
                canonkey0 = list(chain.from_iterable(canonlist))
                canonkey_t = tuple(['a']
                                   + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                   + [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[4]][pattern[3]][el] for el in self._edge_labels)]
                                   + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                   + [tuple(G[pattern[0]][pattern[3]][el] for el in self._edge_labels)]
                                   + canonkey4 + canonkey0)
                treelet.append(canonkey_t)
            canonkey_l.update(Counter(treelet))

            # pattern 12
            treelet = []
            for pattern in patterns['12']:
                canonlist0 = []
                for leaf in pattern[1:3]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[0]][el] for el in self._edge_labels)
                    canonlist0.append(tuple((nlabels, elabels)))
                canonlist0.sort()
                canonlist0 = list(chain.from_iterable(canonlist0))
                canonlist3 = []
                for leaf in pattern[4:6]:
                    nlabels = tuple(G.nodes[leaf][nl] for nl in self._node_labels)
                    elabels = tuple(G[leaf][pattern[3]][el] for el in self._edge_labels)
                    canonlist3.append(tuple((nlabels, elabels)))
                canonlist3.sort()
                canonlist3 = list(chain.from_iterable(canonlist3))

                # Two possible keys can be generated from the two nodes with
                # extended label 3; select the one with the lower
                # lexicographic order.
                canonkey_t1 = tuple(['c']
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                    + canonlist0
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                    + [tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                                    + canonlist3)
                canonkey_t2 = tuple(['c']
                                    + [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels)]
                                    + canonlist3
                                    + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                    + [tuple(G[pattern[0]][pattern[3]][el] for el in self._edge_labels)]
                                    + canonlist0)
                treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2
                               else canonkey_t2)
            canonkey_l.update(Counter(treelet))

            # pattern 9
            treelet = []
            for pattern in patterns['9']:
                canonkey2 = [tuple(G.nodes[pattern[4]][nl] for nl in self._node_labels),
                             tuple(G[pattern[4]][pattern[2]][el] for el in self._edge_labels)]
                canonkey3 = [tuple(G.nodes[pattern[5]][nl] for nl in self._node_labels),
                             tuple(G[pattern[5]][pattern[3]][el] for el in self._edge_labels)]
                prekey2 = [tuple(G.nodes[pattern[2]][nl] for nl in self._node_labels),
                           tuple(G[pattern[2]][pattern[0]][el] for el in self._edge_labels)]
                prekey3 = [tuple(G.nodes[pattern[3]][nl] for nl in self._node_labels),
                           tuple(G[pattern[3]][pattern[0]][el] for el in self._edge_labels)]
                if prekey2 + canonkey2 < prekey3 + canonkey3:
                    canonkey_t = ([tuple(G.nodes[pattern[1]][nl] for nl in self._node_labels)]
                                  + [tuple(G[pattern[1]][pattern[0]][el] for el in self._edge_labels)]
                                  + prekey2 + prekey3 + canonkey2 + canonkey3)
                else:
                    canonkey_t = ([tuple(G.nodes[pattern[1]][nl] for nl in self._node_labels)]
                                  + [tuple(G[pattern[1]][pattern[0]][el] for el in self._edge_labels)]
                                  + prekey3 + prekey2 + canonkey3 + canonkey2)
                treelet.append(tuple(['9']
                                     + [tuple(G.nodes[pattern[0]][nl] for nl in self._node_labels)]
                                     + canonkey_t))
            canonkey_l.update(Counter(treelet))
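            # At this point canonkey_l maps tuples of the form
            # (pattern_id, label tuples...) to occurrence counts; e.g. a
            # one-edge path between a 'C' node and an 'O' node joined by a
            # '1' edge would appear as ('1', ('C',), ('1',), ('O',)). These
            # labels are illustrative; actual keys depend on the dataset's
            # node_labels and edge_labels.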
            return canonkey_l

        return canonkey

    def _wrapper_get_canonkeys(self, itr_item):
        g = itr_item[0]
        i = itr_item[1]
        return i, self._get_canonkeys(g)

    def _add_dummy_labels(self, Gn):
        if len(self._node_labels) == 0 or (len(self._node_labels) == 1
                                           and self._node_labels[0] == SpecialLabel.DUMMY):
            for i in range(len(Gn)):
                nx.set_node_attributes(Gn[i], '0', SpecialLabel.DUMMY)
            self._node_labels = [SpecialLabel.DUMMY]
        if len(self._edge_labels) == 0 or (len(self._edge_labels) == 1
                                           and self._edge_labels[0] == SpecialLabel.DUMMY):
            for i in range(len(Gn)):
                nx.set_edge_attributes(Gn[i], '0', SpecialLabel.DUMMY)
            self._edge_labels = [SpecialLabel.DUMMY]
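
# A minimal usage sketch (illustrative, not part of the library API): it
# builds two small labeled graphs and compares them through the single-pair
# code path above. The Gaussian-style sub-kernel below is an assumption
# chosen for the demo; any kernel on treelet count vectors (see Ref [1]) can
# be plugged in instead, and real use would normally go through the public
# GraphKernel interface rather than the private method called here.
if __name__ == '__main__':
    g1 = nx.Graph()
    g1.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'C'}),
                       (2, {'atom': 'O'})])
    g1.add_edges_from([(0, 1, {'bond': '1'}), (1, 2, {'bond': '1'})])

    g2 = nx.Graph()
    g2.add_nodes_from([(0, {'atom': 'C'}), (1, {'atom': 'O'})])
    g2.add_edge(0, 1, bond='1')

    # Gaussian RBF between the shared treelet count vectors.
    sub_kernel = lambda x, y: np.exp(-np.sum((x - y) ** 2) / 2.)

    treelet_kernel = Treelet(node_labels=['atom'], edge_labels=['bond'],
                             sub_kernel=sub_kernel,
                             ds_infos={'directed': False})
    print(treelet_kernel._compute_single_kernel_series(g1, g2))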