|
|
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- " --- This is a regression problem ---\n",
- "\n",
- "\n",
- " Loading dataset from file...\n",
- "\n",
- " Calculating kernel matrix, this could take a while...\n",
- "\n",
- " --- treelet kernel matrix of size 185 built in 0.47543811798095703 seconds ---\n",
- "[[4.00000000e+00 2.60653066e+00 1.00000000e+00 ... 1.26641655e-14\n",
- " 1.26641655e-14 1.26641655e-14]\n",
- " [2.60653066e+00 6.00000000e+00 1.00000000e+00 ... 1.26641655e-14\n",
- " 1.26641655e-14 1.26641655e-14]\n",
- " [1.00000000e+00 1.00000000e+00 4.00000000e+00 ... 3.00000000e+00\n",
- " 3.00000000e+00 3.00000000e+00]\n",
- " ...\n",
- " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 1.80000000e+01\n",
- " 1.30548713e+01 8.19020657e+00]\n",
- " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 1.30548713e+01\n",
- " 2.20000000e+01 9.71901120e+00]\n",
- " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 8.19020657e+00\n",
- " 9.71901120e+00 1.60000000e+01]]\n",
- "\n",
- " Starting calculate accuracy/rmse...\n",
- "calculate performance: 98%|█████████▊| 983/1000 [00:01<00:00, 796.45it/s]\n",
- " Mean performance on train set: 2.688029\n",
- "With standard deviation: 1.541623\n",
- "\n",
- " Mean performance on test set: 10.099738\n",
- "With standard deviation: 5.035844\n",
- "calculate performance: 100%|██████████| 1000/1000 [00:01<00:00, 745.11it/s]\n",
- "\n",
- "\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 10.0997 5.03584 2.68803 1.54162 0.475438\n"
- ]
- }
- ],
- "source": [
- "# line_profiler provides the %lprun magic used by the commented-out profiling call at the bottom of this cell.\n",
- "%load_ext line_profiler\n",
- "\n",
- "import sys\n",
- "# Put the parent directory first on the import path; the pygraph imports below are presumably resolved from there.\n",
- "sys.path.insert(0, \"../\")\n",
- "from pygraph.utils.utils import kernel_train_test\n",
- "from pygraph.kernels.treeletKernel import treeletkernel\n",
- "\n",
- "# Inputs handed to kernel_train_test. Both paths are relative, so they depend on the directory the kernel runs in.\n",
- "# NOTE(review): kernel_file_path looks like the folder used for kernel matrices - confirm in kernel_train_test.\n",
- "datafile = '../../../../datasets/acyclic/Acyclic/dataset_bps.ds'\n",
- "kernel_file_path = 'kernelmatrices_path_acyclic/'\n",
- "\n",
- "# Passed to kernel_train_test together with treeletkernel; presumably forwarded to it as keyword arguments - confirm.\n",
- "kernel_para = dict(node_label = 'atom', edge_label = 'bond_type', labeled = True)\n",
- "\n",
- "# Run the experiment. The recorded output of this cell shows the kernel matrix and a table with\n",
- "# test/train RMSE, their standard deviations and the kernel computation time.\n",
- "kernel_train_test(datafile, kernel_file_path, treeletkernel, kernel_para, normalize = False)\n",
- "\n",
- "# Same call under the line profiler; uncomment both lines to profile treeletkernel line by line.\n",
- "# %lprun -f treeletkernel \\\n",
- "# kernel_train_test(datafile, kernel_file_path, treeletkernel, kernel_para, normalize = False)"
- ]
- },
- {
- "cell_type": "raw",
- "metadata": {},
- "source": [
- "# results\n",
- "\n",
- "# with y normalization\n",
- " RMSE_test std_test RMSE_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 8.3079 3.37838 2.90887 1.2679 0.500302\n",
- "\n",
- "# without y normalization\n",
- " RMSE_test std_test RMSE_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 10.0997 5.03584 2.68803 1.54162 0.484171\n",
- "\n",
- " \n",
- "\n",
- "# G0 -> WL subtree h = 0\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 13.9223 2.88611 13.373 0.653301 0.186731\n",
- "\n",
- "# G0 U G1 U G6 U G8 U G13 -> WL subtree h = 1\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 8.97706 2.90771 6.7343 1.17505 0.223171\n",
- " \n",
- "# all patterns \\ { G3 U G4 U G5 U G10 } -> WL subtree h = 2 \n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 7.31274 1.96289 3.73909 0.406267 0.294902\n",
- "\n",
- "# all patterns \\ { G4 U G5 } -> WL subtree h = 3\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 8.39977 2.78309 3.8606 1.58686 0.348912\n",
- "\n",
- "# all patterns \\ { G5 } \n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 9.47647 4.22113 3.18029 1.5669 0.423638\n",
- " \n",
- " \n",
- " \n",
- "# G0 -> WL subtree h = 0\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 13.9223 2.88611 13.373 0.653301 0.186731 \n",
- " \n",
- "# G0 U G1 U G2 U G6 U G8 U G13 -> WL subtree h = 1\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 8.62431 2.54327 5.63422 0.255002 0.290797\n",
- " \n",
- "# all patterns \\ { G5 U G10 } -> WL subtree h = 2\n",
- " rmse_test std_test rmse_train std_train k_time\n",
- "----------- ---------- ------------ ----------- --------\n",
- " 10.1294 3.50275 3.69664 1.55116 0.418498"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{0: 'C', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: 'O', 6: 'O'}\n"
- ]
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<matplotlib.figure.Figure at 0x7f23a4d68ef0>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{0: 'C', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: 'C', 6: 'O', 7: 'O'}\n"
- ]
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<matplotlib.figure.Figure at 0x7f23a02cfac8>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- " pattern 0: [0, 1, 2, 3, 4, 5, 6, 7]\n",
- " treelet 0: ['C', 'C', 'C', 'C', 'C', 'C', 'O', 'O']\n",
- "\n",
- " pattern 1 : [[4, 0], [4, 1], [5, 4], [6, 2], [6, 5], [7, 3], [7, 5]]\n",
- " treelet 1 : ['1C1C', '1C1C', '1C1C', '1C1O', '1C1O', '1C1O', '1C1O']\n",
- "\n",
- " pattern 2 : [[1, 4, 0], [5, 4, 0], [5, 4, 1], [5, 6, 2], [5, 7, 3], [6, 5, 4], [7, 5, 4], [7, 5, 6]]\n",
- " treelet 2 : ['2C1C1C', '2C1C1C', '2C1C1C', '2C1O1C', '2C1O1C', '2C1C1O', '2C1C1O', '2O1C1O']\n",
- "\n",
- " pattern 3 : [[4, 5, 6, 2], [4, 5, 7, 3], [6, 5, 4, 0], [6, 5, 4, 1], [6, 5, 7, 3], [7, 5, 4, 0], [7, 5, 4, 1], [7, 5, 6, 2]]\n",
- " treelet 3 : ['3C1C1O1C', '3C1C1O1C', '3C1C1C1O', '3C1C1C1O', '3C1O1C1O', '3C1C1C1O', '3C1C1C1O', '3C1O1C1O']\n",
- "\n",
- " pattern 4 : [[2, 6, 5, 4, 0], [2, 6, 5, 4, 1], [3, 7, 5, 4, 0], [3, 7, 5, 4, 1], [3, 7, 5, 6, 2]]\n",
- " treelet 4 : ['4C1C1C1O1C', '4C1C1C1O1C', '4C1C1C1O1C', '4C1C1C1O1C', '4C1O1C1O1C']\n",
- "\n",
- " pattern 5 : []\n",
- " treelet 5 : []\n",
- "\n",
- " pattern 3 star: [[4, 0, 1, 5], [5, 4, 6, 7]]\n",
- " treelet 3 star: ['6CC1C1C1', '6CC1O1O1']\n",
- "\n",
- " pattern 4 star: []\n",
- " treelet 4 star: []\n",
- "\n",
- " pattern 5 star: []\n",
- " treelet 5 star: []\n",
- "\n",
- " pattern 7: [[4, 0, 1, 5, 6], [4, 0, 1, 5, 7], [5, 7, 6, 4, 0], [5, 7, 6, 4, 1], [5, 4, 7, 6, 2], [5, 4, 6, 7, 3]]\n",
- " treelet 7: ['7CC1C1C1O1', '7CC1C1C1O1', '7CO1O1C1C1', '7CO1O1C1C1', '7CC1O1O1C1', '7CC1O1O1C1']\n",
- "\n",
- " pattern 11: []\n",
- " treelet 11: []\n",
- "\n",
- " pattern 10: [[4, 0, 1, 5, 6, 2], [4, 0, 1, 5, 7, 3]]\n",
- " treelet 10: ['aCO1C1C1C1C1', 'aCO1C1C1C1C1']\n",
- "\n",
- " pattern 12: [[4, 0, 1, 5, 7, 6]]\n",
- " treelet 12: ['cCC1C1C1O1O1']\n",
- "\n",
- " pattern 9: [[5, 7, 6, 4, 2, 0], [5, 7, 6, 4, 2, 1], [5, 6, 7, 4, 3, 0], [5, 6, 7, 4, 3, 1], [5, 4, 7, 6, 3, 2]]\n",
- " treelet 9: ['9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CC1O1O1C1C1']\n",
- "\n",
- " numbers of canonical keys: {'2O1C1O': 1, '7CC1C1C1O1': 2, '7CC1O1O1C1': 2, 'aCO1C1C1C1C1': 2, '2C1C1C': 3, '6CC1C1C1': 1, '9CO1C1O1C1C1': 4, '1C1C': 3, '3C1C1C1O': 4, '4C1C1C1O1C': 4, '7CO1O1C1C1': 2, '2C1C1O': 2, '1C1O': 4, '9CC1O1O1C1C1': 1, '3C1C1O1C': 2, '6CC1O1O1': 1, '2C1O1C': 2, '0O': 2, '4C1O1C1O1C': 1, 'cCC1C1C1O1O1': 1, '0C': 6, '3C1O1C1O': 2}\n",
- "\n",
- " pattern 0: [0, 1, 2, 3, 4, 5, 6]\n",
- " treelet 0: ['C', 'C', 'C', 'C', 'C', 'O', 'O']\n",
- "\n",
- " pattern 1 : [[2, 0], [3, 1], [5, 2], [5, 4], [6, 3], [6, 4]]\n",
- " treelet 1 : ['1C1C', '1C1C', '1C1O', '1C1O', '1C1O', '1C1O']\n",
- "\n",
- " pattern 2 : [[4, 5, 2], [4, 6, 3], [5, 2, 0], [6, 3, 1], [6, 4, 5]]\n",
- " treelet 2 : ['2C1O1C', '2C1O1C', '2C1C1O', '2C1C1O', '2O1C1O']\n",
- "\n",
- " pattern 3 : [[4, 5, 2, 0], [4, 6, 3, 1], [5, 4, 6, 3], [6, 4, 5, 2]]\n",
- " treelet 3 : ['3C1C1O1C', '3C1C1O1C', '3C1O1C1O', '3C1O1C1O']\n",
- "\n",
- " pattern 4 : [[3, 6, 4, 5, 2], [5, 4, 6, 3, 1], [6, 4, 5, 2, 0]]\n",
- " treelet 4 : ['4C1O1C1O1C', '4C1C1O1C1O', '4C1C1O1C1O']\n",
- "\n",
- " pattern 5 : [[2, 5, 4, 6, 3, 1], [3, 6, 4, 5, 2, 0]]\n",
- " treelet 5 : ['5C1C1O1C1O1C', '5C1C1O1C1O1C']\n",
- "\n",
- " pattern 3 star: []\n",
- " treelet 3 star: []\n",
- "\n",
- " pattern 4 star: []\n",
- " treelet 4 star: []\n",
- "\n",
- " pattern 5 star: []\n",
- " treelet 5 star: []\n",
- "\n",
- " pattern 7: []\n",
- " treelet 7: []\n",
- "\n",
- " pattern 11: []\n",
- " treelet 11: []\n",
- "\n",
- " pattern 10: []\n",
- " treelet 10: []\n",
- "\n",
- " pattern 12: []\n",
- " treelet 12: []\n",
- "\n",
- " pattern 9: []\n",
- " treelet 9: []\n",
- "\n",
- " numbers of canonical keys: {'3C1C1O1C': 2, '2O1C1O': 1, '1C1O': 4, '2C1O1C': 2, '0O': 2, '5C1C1O1C1O1C': 2, '1C1C': 2, '4C1O1C1O1C': 1, '0C': 5, '3C1O1C1O': 2, '4C1C1O1C1O': 2, '2C1C1O': 2}\n"
- ]
- }
- ],
- "source": [
- "import sys\n",
- "import pathlib\n",
- "from collections import Counter\n",
- "from itertools import chain\n",
- "sys.path.insert(0, \"../\")\n",
- "\n",
- "import networkx as nx\n",
- "import numpy as np\n",
- "import time\n",
- "\n",
- "from sklearn.metrics.pairwise import rbf_kernel, paired_distances\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# main\n",
- "import sys\n",
- "from collections import Counter\n",
- "import networkx as nx\n",
- "sys.path.insert(0, \"../\")\n",
- "from pygraph.utils.graphfiles import loadDataset\n",
- "\n",
- "\n",
- "def main():\n",
- "    \"\"\"Print and draw two sample graphs of the acyclic dataset, then run the treelet kernel on the pair.\"\"\"\n",
- "    graphs, _ = loadDataset(\"../../../../datasets/acyclic/Acyclic/dataset_bps.ds\")\n",
- "\n",
- "    # author's note kept from the second index: 180 double 4, 57, 3, double 3\n",
- "    shown = []\n",
- "    for position in (15, 57):\n",
- "        graph = graphs[position]\n",
- "        print(nx.get_node_attributes(graph, 'label'))\n",
- "        nx.draw_networkx(graph)\n",
- "        plt.show()\n",
- "        shown.append(graph)\n",
- "\n",
- "    first, second = shown\n",
- "    treeletkernel(first, second, labeled = True)\n",
- " \n",
- "def find_paths(G, source_node, length):\n",
- "    \"\"\"List every simple path that starts at source_node and has `length` edges.\n",
- "\n",
- "    G only needs to support G[node] yielding the neighbours of node.\n",
- "    Each path is a list of nodes beginning with source_node; a length of 0\n",
- "    gives the single one-node path.\n",
- "    \"\"\"\n",
- "    if length == 0:\n",
- "        return [[source_node]]\n",
- "\n",
- "    found = []\n",
- "    for neighbor in G[source_node]:\n",
- "        for tail in find_paths(G, neighbor, length - 1):\n",
- "            # a tail that comes back to the start would visit a node twice\n",
- "            if source_node in tail:\n",
- "                continue\n",
- "            found.append([source_node] + tail)\n",
- "    return found\n",
- "\n",
- "def find_all_paths(G, length):\n",
- "    \"\"\"Collect every simple path with `length` edges in G, one orientation per path.\n",
- "\n",
- "    find_paths reports each path once from either end. Of such a mirrored\n",
- "    pair the occurrence generated later is kept, and the kept paths stay in\n",
- "    generation order.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    G : graph\n",
- "        Iterable over its nodes, with G[node] yielding neighbours. Nodes must be hashable.\n",
- "    length : int\n",
- "        Number of edges in each path.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    list of list\n",
- "        Node lists, each with length + 1 entries.\n",
- "    \"\"\"\n",
- "    all_paths = []\n",
- "    for node in G:\n",
- "        all_paths.extend(find_paths(G, node, length))\n",
- "\n",
- "    # remove double direction\n",
- "    # Walking backwards, `later` holds exactly the paths generated after the current one,\n",
- "    # so a path is dropped only when its mirror image follows it. One set lookup per path\n",
- "    # replaces the former comparison against every later reversed path.\n",
- "    later = set()\n",
- "    kept = []\n",
- "    for path in reversed(all_paths):\n",
- "        if tuple(reversed(path)) not in later:\n",
- "            kept.append(path)\n",
- "        later.add(tuple(path))\n",
- "    kept.reverse()\n",
- "\n",
- "    return kept\n",
- "\n",
- "def get_canonkey(G, node_label = 'atom', edge_label = 'bond_type', labeled = True):\n",
- "    \"\"\"Count the treelet patterns that occur in graph G.\n",
- "\n",
- "    The structural pass enumerates linear paths with 0-5 edges, 3/4/5-stars and\n",
- "    the branched patterns stored under '7', '9', '10', '11' and '12'. With\n",
- "    labeled=True a second pass turns every pattern occurrence into a canonical\n",
- "    string built from node and edge labels and counts those strings instead.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    G : graph\n",
- "        Accessed networkx-style: G[node], G.degree(node), G.nodes() and G.node[node].\n",
- "    node_label : str\n",
- "        Node attribute read when labeled is True.\n",
- "    edge_label : str\n",
- "        Edge attribute read when labeled is True.\n",
- "    labeled : bool\n",
- "        Selects which dictionary is returned.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    dict\n",
- "        labeled=False: pattern code ('0'-'9', 'a'-'d') -> number of occurrences.\n",
- "        labeled=True: canonical key string -> number of occurrences.\n",
- "\n",
- "    Notes\n",
- "    -----\n",
- "    The labeled branch prints every pattern list and the keys derived from it.\n",
- "    Label values are concatenated with +, so they have to be strings.\n",
- "    NOTE(review): G.node is the attribute accessor of older networkx releases;\n",
- "    newer releases spell it G.nodes - confirm against the installed version.\n",
- "    \"\"\"\n",
- "\n",
- "    patterns = {} # pattern name -> list of occurrences (lists of nodes)\n",
- "    canonkey = {} # canonical key\n",
- "\n",
- "    ### structural analysis ###\n",
- "    # linear patterns\n",
- "    patterns['0'] = G.nodes()\n",
- "    canonkey['0'] = nx.number_of_nodes(G)\n",
- "    # find_all_paths(G, i) supplies the paths with i edges, so codes '1'-'5' are paths of 2-6 nodes\n",
- "    for i in range(1, 6):\n",
- "        patterns[str(i)] = find_all_paths(G, i)\n",
- "        canonkey[str(i)] = len(patterns[str(i)])\n",
- "\n",
- "    # n-star patterns\n",
- "    # an occurrence is [centre, neighbour, neighbour, ...] for a node of exactly that degree\n",
- "    patterns['3star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3 ]\n",
- "    patterns['4star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4 ]\n",
- "    patterns['5star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5 ]\n",
- "    # n-star patterns\n",
- "    # the 3-, 4- and 5-star are counted under the codes '6', '8' and 'd'\n",
- "    canonkey['6'] = len(patterns['3star'])\n",
- "    canonkey['8'] = len(patterns['4star'])\n",
- "    canonkey['d'] = len(patterns['5star'])\n",
- "\n",
- "    # pattern 7\n",
- "    # a 3-star with one extra node attached to one of its leaves;\n",
- "    # occurrence layout: [centre, leaf, leaf, extended leaf, extra node]\n",
- "    patterns['7'] = []\n",
- "    for pattern in patterns['3star']:\n",
- "        for i in range(1, len(pattern)):\n",
- "            if G.degree(pattern[i]) >= 2:\n",
- "                pattern_t = pattern[:]\n",
- "                # move the leaf that gets extended to index 3\n",
- "                pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
- "                for neighborx in G[pattern[i]]:\n",
- "                    if neighborx != pattern[0]:\n",
- "                        new_pattern = pattern_t + [ neighborx ]\n",
- "# new_patterns = [ pattern + [neighbor] for neighbor in G[pattern[i]] if neighbor != pattern[0] ]\n",
- "                        patterns['7'].append(new_pattern)\n",
- "    canonkey['7'] = len(patterns['7'])\n",
- "\n",
- "    # pattern 11\n",
- "    # same construction on a 4-star: the extended leaf sits at index 4, the extra node at index 5;\n",
- "    # counted under the code 'b'\n",
- "    patterns['11'] = []\n",
- "    for pattern in patterns['4star']:\n",
- "        for i in range(1, len(pattern)):\n",
- "            if G.degree(pattern[i]) >= 2:\n",
- "                pattern_t = pattern[:]\n",
- "                pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]\n",
- "                for neighborx in G[pattern[i]]:\n",
- "                    if neighborx != pattern[0]:\n",
- "                        new_pattern = pattern_t + [ neighborx ]\n",
- "# new_patterns = [ pattern + [neighborx] for neighborx in G[pattern[i]] if neighborx != pattern[0] ]\n",
- "                        patterns['11'].append(new_pattern)\n",
- "    canonkey['b'] = len(patterns['11'])\n",
- "\n",
- "    # pattern 12\n",
- "    # a 3-star whose leaf at index 3 has degree >= 3, plus two further neighbours of that leaf\n",
- "    # (indices 4 and 5, taken as unordered pairs via neighborx1 > neighborx2)\n",
- "    patterns['12'] = []\n",
- "    # centres already handled; a 3-star whose centre is listed here is skipped entirely\n",
- "    rootlist = []\n",
- "    for pattern in patterns['3star']:\n",
- "# print(pattern)\n",
- "        if pattern[0] not in rootlist:\n",
- "            rootlist.append(pattern[0])\n",
- "            for i in range(1, len(pattern)):\n",
- "                if G.degree(pattern[i]) >= 3:\n",
- "                    rootlist.append(pattern[i])\n",
- "                    pattern_t = pattern[:]\n",
- "                    pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
- "                    for neighborx1 in G[pattern[i]]:\n",
- "                        if neighborx1 != pattern[0]:\n",
- "                            for neighborx2 in G[pattern[i]]:\n",
- "                                if neighborx1 > neighborx2 and neighborx2 != pattern[0]:\n",
- "                                    new_pattern = pattern_t + [neighborx1] + [neighborx2]\n",
- "# new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0]) ]\n",
- "                                    patterns['12'].append(new_pattern)\n",
- "    # NOTE(review): rootlist already suppresses the visit from the second centre, so halving here may\n",
- "    # undercount (the recorded run lists one pattern 12, which would give 0) - confirm.\n",
- "    # NOTE(review): skipping a centre that was only added as a neighbour may also hide its other\n",
- "    # degree >= 3 neighbours - verify on a chain of three or more such nodes.\n",
- "    canonkey['c'] = int(len(patterns['12']) / 2)\n",
- "\n",
- "    # pattern 9\n",
- "    # a 3-star with two leaves of degree >= 2, each extended by one node;\n",
- "    # occurrence layout: [centre, plain leaf, leaf A, leaf B, neighbour of A, neighbour of B]\n",
- "    patterns['9'] = []\n",
- "    for pattern in patterns['3star']:\n",
- "# print('pattern: ', pattern)\n",
- "        for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \\\n",
- "            for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2 ]:\n",
- "# print('pairs: ', pairs)\n",
- "            pattern_t = pattern[:]\n",
- "# print('pattern_t: ', pattern_t)\n",
- "            # swap pairs[0] into index 2 and pairs[1] into index 3\n",
- "            pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]\n",
- "# print('pattern_t: ', pattern_t)\n",
- "            pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]\n",
- "# print('pattern_t: ', pattern_t)\n",
- "            for neighborx1 in G[pairs[0]]:\n",
- "                if neighborx1 != pattern[0]:\n",
- "                    for neighborx2 in G[pairs[1]]:\n",
- "                        if neighborx2 != pattern[0]:\n",
- "                            new_pattern = pattern_t + [neighborx1] + [neighborx2]\n",
- "# new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pairs[0]] if neighborx1 != pattern[0] for neighborx2 in G[pairs[1]] if neighborx2 != pattern[0] ]\n",
- "                            patterns['9'].append(new_pattern)\n",
- "    canonkey['9'] = len(patterns['9'])\n",
- "\n",
- "    # pattern 10\n",
- "    # a 3-star with a two-node chain attached to one leaf;\n",
- "    # occurrence layout: [centre, leaf, leaf, extended leaf, chain node, chain end]; code 'a'\n",
- "    patterns['10'] = []\n",
- "    for pattern in patterns['3star']:\n",
- "        for i in range(1, len(pattern)):\n",
- "            if G.degree(pattern[i]) >= 2:\n",
- "                for neighborx in G[pattern[i]]:\n",
- "                    if neighborx != pattern[0] and G.degree(neighborx) >= 2:\n",
- "                        pattern_t = pattern[:]\n",
- "                        pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
- "                        new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]\n",
- "                        patterns['10'].extend(new_patterns)\n",
- "    canonkey['a'] = len(patterns['10'])\n",
- "\n",
- "    ### labeling information ###\n",
- "    if labeled == True:\n",
- "        # canonical key string -> count; every key starts with its pattern code\n",
- "        canonkey_l = {}\n",
- "\n",
- "        # linear patterns\n",
- "        canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))\n",
- "        for key in canonkey_t:\n",
- "            canonkey_l['0' + key] = canonkey_t[key]\n",
- "        print('\\n pattern 0: ', patterns['0'])\n",
- "        print(' treelet 0: ', list(nx.get_node_attributes(G, node_label).values()))\n",
- "\n",
- "        for i in range(1, 6):\n",
- "            treelet = []\n",
- "            for pattern in patterns[str(i)]:\n",
- "                # alternate node label, edge label, node label, ... along the path\n",
- "                canonlist = list(chain.from_iterable((G.node[node][node_label], \\\n",
- "                    G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))\n",
- "                canonlist.append(G.node[pattern[-1]][node_label])\n",
- "                canonkey_t = ''.join(canonlist)\n",
- "                # keep the smaller of the string and its reverse so both orientations share one key\n",
- "                # NOTE(review): the reversal is character-wise, so this presumably relies on\n",
- "                # single-character labels - confirm for the datasets in use.\n",
- "                canonkey_t = canonkey_t if canonkey_t < canonkey_t[::-1] else canonkey_t[::-1]\n",
- "                treelet.append(str(i) + canonkey_t)\n",
- "            canonkey_l.update(Counter(treelet))\n",
- "            print('\\n pattern', i, ': ', patterns[str(i)])\n",
- "            print(' treelet', i, ': ', treelet)\n",
- "\n",
- "# print(canonkey_l)\n",
- "\n",
- "        # n-star patterns\n",
- "        for i in range(3, 6):\n",
- "            treelet = []\n",
- "            for pattern in patterns[str(i) + 'star']:\n",
- "                # sorting the leaf strings makes the key independent of neighbour order\n",
- "                canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:] ]\n",
- "                canonlist.sort()\n",
- "                # prefix is '6' or '8' (i * 2) for the 3- and 4-star and 'd' for the 5-star\n",
- "                canonkey_t = ('d' if i == 5 else str(i * 2)) + G.node[pattern[0]][node_label] + ''.join(canonlist)\n",
- "                treelet.append(canonkey_t)\n",
- "            canonkey_l.update(Counter(treelet))\n",
- "            print('\\n pattern', i, 'star: ', patterns[str(i) + 'star'])\n",
- "            print(' treelet', i, 'star: ', treelet)\n",
- "\n",
- "        # pattern 7\n",
- "        # key: '7' + centre + sorted plain leaves + extended leaf (index 3) + extra node (index 4)\n",
- "        treelet = []\n",
- "        for pattern in patterns['7']:\n",
- "            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
- "            canonlist.sort()\n",
- "            canonkey_t = '7' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
- "                + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
- "                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]\n",
- "            treelet.append(canonkey_t)\n",
- "        canonkey_l.update(Counter(treelet))\n",
- "        print('\\n pattern 7: ', patterns['7'])\n",
- "        print(' treelet 7: ', treelet)\n",
- "\n",
- "        # pattern 11\n",
- "        # key: 'b' + centre + sorted plain leaves (indices 1-3) + extended leaf (index 4) + extra node (index 5)\n",
- "        treelet = []\n",
- "        for pattern in patterns['11']:\n",
- "            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:4] ]\n",
- "            canonlist.sort()\n",
- "            canonkey_t = 'b' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
- "                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[0]][edge_label] \\\n",
- "                + G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]\n",
- "            treelet.append(canonkey_t)\n",
- "        canonkey_l.update(Counter(treelet))\n",
- "        print('\\n pattern 11: ', patterns['11'])\n",
- "        print(' treelet 11: ', treelet)\n",
- "\n",
- "        # pattern 10\n",
- "        # key starts at the extended leaf (index 3), not at the star centre\n",
- "        treelet = []\n",
- "        for pattern in patterns['10']:\n",
- "            canonkey4 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]\n",
- "            canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
- "            canonlist.sort()\n",
- "            canonkey0 = ''.join(canonlist)\n",
- "            canonkey_t = 'a' + G.node[pattern[3]][node_label] \\\n",
- "                + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label] \\\n",
- "                + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \\\n",
- "                + canonkey4 + canonkey0\n",
- "# canonkey_t = 'a' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
- "# + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
- "# + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]\n",
- "            treelet.append(canonkey_t)\n",
- "        canonkey_l.update(Counter(treelet))\n",
- "        print('\\n pattern 10: ', patterns['10'])\n",
- "        print(' treelet 10: ', treelet)\n",
- "\n",
- "        # pattern 12\n",
- "        # the key is built once from each of the two centres (indices 0 and 3); the smaller string is kept\n",
- "        treelet = []\n",
- "        for pattern in patterns['12']:\n",
- "            canonlist0 = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
- "            canonlist0.sort()\n",
- "            canonlist3 = [ G.node[leaf][node_label] + G[leaf][pattern[3]][edge_label] for leaf in pattern[4:6] ]\n",
- "            canonlist3.sort()\n",
- "            canonkey_t1 = 'c' + G.node[pattern[0]][node_label] \\\n",
- "                + ''.join(canonlist0) \\\n",
- "                + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
- "                + ''.join(canonlist3)\n",
- "\n",
- "            canonkey_t2 = 'c' + G.node[pattern[3]][node_label] \\\n",
- "                + ''.join(canonlist3) \\\n",
- "                + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \\\n",
- "                + ''.join(canonlist0)\n",
- "\n",
- "            treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)\n",
- "        canonkey_l.update(Counter(treelet))\n",
- "        print('\\n pattern 12: ', patterns['12'])\n",
- "        print(' treelet 12: ', treelet)\n",
- "\n",
- "        # pattern 9\n",
- "        # the two extended branches (indices 2/4 and 3/5) are ordered by their label strings\n",
- "        # so that swapping them yields the same key\n",
- "        treelet = []\n",
- "        for pattern in patterns['9']:\n",
- "            canonkey2 = G.node[pattern[4]][node_label] + G[pattern[4]][pattern[2]][edge_label]\n",
- "            canonkey3 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[3]][edge_label]\n",
- "            prekey2 = G.node[pattern[2]][node_label] + G[pattern[2]][pattern[0]][edge_label]\n",
- "            prekey3 = G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label]\n",
- "            if prekey2 + canonkey2 < prekey3 + canonkey3:\n",
- "                canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \\\n",
- "                    + prekey2 + prekey3 + canonkey2 + canonkey3\n",
- "            else:\n",
- "                canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \\\n",
- "                    + prekey3 + prekey2 + canonkey3 + canonkey2\n",
- "            treelet.append('9' + G.node[pattern[0]][node_label] + canonkey_t)\n",
- "        canonkey_l.update(Counter(treelet))\n",
- "        print('\\n pattern 9: ', patterns['9'])\n",
- "        print(' treelet 9: ', treelet)\n",
- "\n",
- "        print('\\n numbers of canonical keys: ', canonkey_l)\n",
- "\n",
- "        return canonkey_l\n",
- "\n",
- "    # unlabeled mode: structural counts only\n",
- "    return canonkey\n",
- " \n",
- "\n",
- "def _kernel_from_canonkeys(canonkey1, canonkey2):\n",
- "    \"\"\"Sum a Gaussian of the count differences over the canonical keys both dictionaries share.\"\"\"\n",
- "    keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs\n",
- "    # keys is an intersection, so every key is present in both dictionaries\n",
- "    vector1 = np.array([ canonkey1[key] for key in keys ])\n",
- "    vector2 = np.array([ canonkey2[key] for key in keys ])\n",
- "    return np.sum(np.exp(- np.square(vector1 - vector2) / 2))\n",
- "\n",
- "\n",
- "def treeletkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True):\n",
- "    \"\"\"Treelet kernel between two graphs, or the kernel matrix of a list of graphs.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    *args\n",
- "        Either one list of graphs, or two graphs.\n",
- "    node_label, edge_label : str\n",
- "        Attribute names handed to get_canonkey.\n",
- "    labeled : bool\n",
- "        Handed to get_canonkey; selects label-aware or purely structural keys.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    (Kmatrix, run_time) for a list of graphs: the symmetric matrix of pairwise\n",
- "    kernels and the seconds spent building it.\n",
- "    A single kernel value for two graphs.\n",
- "\n",
- "    Notes\n",
- "    -----\n",
- "    For a list, canonical keys are extracted once per graph instead of once\n",
- "    per graph pair, so get_canonkey (and its debug printing) runs len(Gn)\n",
- "    times rather than for every matrix entry.\n",
- "    \"\"\"\n",
- "    if len(args) == 1: # for a list of graphs\n",
- "        Gn = args[0]\n",
- "        Kmatrix = np.zeros((len(Gn), len(Gn)))\n",
- "\n",
- "        start_time = time.time()\n",
- "\n",
- "        # the keys of a graph do not depend on its partner, so compute them a single time\n",
- "        canonkeys = [ get_canonkey(graph, node_label = node_label, edge_label = edge_label, labeled = labeled) for graph in Gn ]\n",
- "\n",
- "        for i in range(0, len(Gn)):\n",
- "            print(i)\n",
- "            # only the upper triangle is computed; the matrix is symmetric\n",
- "            for j in range(i, len(Gn)):\n",
- "                Kmatrix[i][j] = _kernel_from_canonkeys(canonkeys[i], canonkeys[j])\n",
- "                Kmatrix[j][i] = Kmatrix[i][j]\n",
- "\n",
- "        run_time = time.time() - start_time\n",
- "        print(\"\\n --- treelet kernel matrix of size %d built in %s seconds ---\" % (len(Gn), run_time))\n",
- "\n",
- "        return Kmatrix, run_time\n",
- "\n",
- "    else: # for only 2 graphs\n",
- "        G1 = args[0]\n",
- "        G2 = args[1]\n",
- "\n",
- "        # the second graph is processed first, which keeps the order of get_canonkey's printed output\n",
- "        canonkey2 = get_canonkey(G2, node_label = node_label, edge_label = edge_label, labeled = labeled)\n",
- "        canonkey1 = get_canonkey(G1, node_label = node_label, edge_label = edge_label, labeled = labeled)\n",
- "\n",
- "        return _kernel_from_canonkeys(canonkey1, canonkey2)\n",
- " \n",
- "# A notebook cell runs with __name__ == '__main__', so executing this cell also runs the demo in main().\n",
- "if __name__ == '__main__':\n",
- "    main()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|