You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_treeletkernel_acyclic-checkpoint.ipynb 57 kB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stdout",
  10. "output_type": "stream",
  11. "text": [
  12. "\n",
  13. " --- This is a regression problem ---\n",
  14. "\n",
  15. "\n",
  16. " Loading dataset from file...\n",
  17. "\n",
  18. " Calculating kernel matrix, this could take a while...\n",
  19. "\n",
  20. " --- treelet kernel matrix of size 185 built in 0.47543811798095703 seconds ---\n",
  21. "[[4.00000000e+00 2.60653066e+00 1.00000000e+00 ... 1.26641655e-14\n",
  22. " 1.26641655e-14 1.26641655e-14]\n",
  23. " [2.60653066e+00 6.00000000e+00 1.00000000e+00 ... 1.26641655e-14\n",
  24. " 1.26641655e-14 1.26641655e-14]\n",
  25. " [1.00000000e+00 1.00000000e+00 4.00000000e+00 ... 3.00000000e+00\n",
  26. " 3.00000000e+00 3.00000000e+00]\n",
  27. " ...\n",
  28. " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 1.80000000e+01\n",
  29. " 1.30548713e+01 8.19020657e+00]\n",
  30. " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 1.30548713e+01\n",
  31. " 2.20000000e+01 9.71901120e+00]\n",
  32. " [1.26641655e-14 1.26641655e-14 3.00000000e+00 ... 8.19020657e+00\n",
  33. " 9.71901120e+00 1.60000000e+01]]\n",
  34. "\n",
  35. " Starting calculate accuracy/rmse...\n",
  36. "calculate performance: 98%|█████████▊| 983/1000 [00:01<00:00, 796.45it/s]\n",
  37. " Mean performance on train set: 2.688029\n",
  38. "With standard deviation: 1.541623\n",
  39. "\n",
  40. " Mean performance on test set: 10.099738\n",
  41. "With standard deviation: 5.035844\n",
  42. "calculate performance: 100%|██████████| 1000/1000 [00:01<00:00, 745.11it/s]\n",
  43. "\n",
  44. "\n",
  45. " rmse_test std_test rmse_train std_train k_time\n",
  46. "----------- ---------- ------------ ----------- --------\n",
  47. " 10.0997 5.03584 2.68803 1.54162 0.475438\n"
  48. ]
  49. }
  50. ],
  51. "source": [
  52. "%load_ext line_profiler\n",
  53. "\n",
  54. "import sys\n",
  55. "sys.path.insert(0, \"../\")\n",
  56. "from pygraph.utils.utils import kernel_train_test\n",
  57. "from pygraph.kernels.treeletKernel import treeletkernel\n",
  58. "\n",
  59. "datafile = '../../../../datasets/acyclic/Acyclic/dataset_bps.ds'\n",
  60. "kernel_file_path = 'kernelmatrices_path_acyclic/'\n",
  61. "\n",
  62. "kernel_para = dict(node_label = 'atom', edge_label = 'bond_type', labeled = True)\n",
  63. "\n",
  64. "kernel_train_test(datafile, kernel_file_path, treeletkernel, kernel_para, normalize = False)\n",
  65. "\n",
  66. "# %lprun -f treeletkernel \\\n",
  67. "# kernel_train_test(datafile, kernel_file_path, treeletkernel, kernel_para, normalize = False)"
  68. ]
  69. },
  70. {
  71. "cell_type": "code",
  72. "execution_count": null,
  73. "metadata": {},
  74. "outputs": [],
  75. "source": [
  76. "# results\n",
  77. "\n",
  78. "# with y normalization\n",
  79. " RMSE_test std_test RMSE_train std_train k_time\n",
  80. "----------- ---------- ------------ ----------- --------\n",
  81. " 8.3079 3.37838 2.90887 1.2679 0.500302\n",
  82. "\n",
  83. "# without y normalization\n",
  84. " RMSE_test std_test RMSE_train std_train k_time\n",
  85. "----------- ---------- ------------ ----------- --------\n",
  86. " 10.0997 5.03584 2.68803 1.54162 0.484171\n",
  87. "\n",
  88. " \n",
  89. "\n",
  90. "# G0 -> WL subtree h = 0\n",
  91. " rmse_test std_test rmse_train std_train k_time\n",
  92. "----------- ---------- ------------ ----------- --------\n",
  93. " 13.9223 2.88611 13.373 0.653301 0.186731\n",
  94. "\n",
  95. "# G0 U G1 U G6 U G8 U G13 -> WL subtree h = 1\n",
  96. " rmse_test std_test rmse_train std_train k_time\n",
  97. "----------- ---------- ------------ ----------- --------\n",
  98. " 8.97706 2.90771 6.7343 1.17505 0.223171\n",
  99. " \n",
  100. "# all patterns \\ { G3 U G4 U G5 U G10 } -> WL subtree h = 2 \n",
  101. " rmse_test std_test rmse_train std_train k_time\n",
  102. "----------- ---------- ------------ ----------- --------\n",
  103. " 7.31274 1.96289 3.73909 0.406267 0.294902\n",
  104. "\n",
  105. "# all patterns \\ { G4 U G5 } -> WL subtree h = 3\n",
  106. " rmse_test std_test rmse_train std_train k_time\n",
  107. "----------- ---------- ------------ ----------- --------\n",
  108. " 8.39977 2.78309 3.8606 1.58686 0.348912\n",
  109. "\n",
  110. "# all patterns \\ { G5 } \n",
  111. " rmse_test std_test rmse_train std_train k_time\n",
  112. "----------- ---------- ------------ ----------- --------\n",
  113. " 9.47647 4.22113 3.18029 1.5669 0.423638\n",
  114. " \n",
  115. " \n",
  116. " \n",
  117. "# G0 -> WL subtree h = 0\n",
  118. " rmse_test std_test rmse_train std_train k_time\n",
  119. "----------- ---------- ------------ ----------- --------\n",
  120. " 13.9223 2.88611 13.373 0.653301 0.186731 \n",
  121. " \n",
  122. "# G0 U G1 U G2 U G6 U G8 U G13 -> WL subtree h = 1\n",
  123. " rmse_test std_test rmse_train std_train k_time\n",
  124. "----------- ---------- ------------ ----------- --------\n",
  125. " 8.62431 2.54327 5.63422 0.255002 0.290797\n",
  126. " \n",
  127. "# all patterns \\ { G5 U G10 } -> WL subtree h = 2\n",
  128. " rmse_test std_test rmse_train std_train k_time\n",
  129. "----------- ---------- ------------ ----------- --------\n",
  130. " 10.1294 3.50275 3.69664 1.55116 0.418498"
  131. ]
  132. },
  133. {
  134. "cell_type": "code",
  135. "execution_count": 3,
  136. "metadata": {
  137. "scrolled": true
  138. },
  139. "outputs": [
  140. {
  141. "name": "stdout",
  142. "output_type": "stream",
  143. "text": [
  144. "{0: 'C', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: 'O', 6: 'O'}\n"
  145. ]
  146. },
  147. {
  148. "data": {
  149. "image/png": "\n",
  150. "text/plain": [
  151. "<matplotlib.figure.Figure at 0x7f23a4d68ef0>"
  152. ]
  153. },
  154. "metadata": {},
  155. "output_type": "display_data"
  156. },
  157. {
  158. "name": "stdout",
  159. "output_type": "stream",
  160. "text": [
  161. "{0: 'C', 1: 'C', 2: 'C', 3: 'C', 4: 'C', 5: 'C', 6: 'O', 7: 'O'}\n"
  162. ]
  163. },
  164. {
  165. "data": {
  166. "image/png": "\n",
  167. "text/plain": [
  168. "<matplotlib.figure.Figure at 0x7f23a02cfac8>"
  169. ]
  170. },
  171. "metadata": {},
  172. "output_type": "display_data"
  173. },
  174. {
  175. "name": "stdout",
  176. "output_type": "stream",
  177. "text": [
  178. "\n",
  179. " pattern 0: [0, 1, 2, 3, 4, 5, 6, 7]\n",
  180. " treelet 0: ['C', 'C', 'C', 'C', 'C', 'C', 'O', 'O']\n",
  181. "\n",
  182. " pattern 1 : [[4, 0], [4, 1], [5, 4], [6, 2], [6, 5], [7, 3], [7, 5]]\n",
  183. " treelet 1 : ['1C1C', '1C1C', '1C1C', '1C1O', '1C1O', '1C1O', '1C1O']\n",
  184. "\n",
  185. " pattern 2 : [[1, 4, 0], [5, 4, 0], [5, 4, 1], [5, 6, 2], [5, 7, 3], [6, 5, 4], [7, 5, 4], [7, 5, 6]]\n",
  186. " treelet 2 : ['2C1C1C', '2C1C1C', '2C1C1C', '2C1O1C', '2C1O1C', '2C1C1O', '2C1C1O', '2O1C1O']\n",
  187. "\n",
  188. " pattern 3 : [[4, 5, 6, 2], [4, 5, 7, 3], [6, 5, 4, 0], [6, 5, 4, 1], [6, 5, 7, 3], [7, 5, 4, 0], [7, 5, 4, 1], [7, 5, 6, 2]]\n",
  189. " treelet 3 : ['3C1C1O1C', '3C1C1O1C', '3C1C1C1O', '3C1C1C1O', '3C1O1C1O', '3C1C1C1O', '3C1C1C1O', '3C1O1C1O']\n",
  190. "\n",
  191. " pattern 4 : [[2, 6, 5, 4, 0], [2, 6, 5, 4, 1], [3, 7, 5, 4, 0], [3, 7, 5, 4, 1], [3, 7, 5, 6, 2]]\n",
  192. " treelet 4 : ['4C1C1C1O1C', '4C1C1C1O1C', '4C1C1C1O1C', '4C1C1C1O1C', '4C1O1C1O1C']\n",
  193. "\n",
  194. " pattern 5 : []\n",
  195. " treelet 5 : []\n",
  196. "\n",
  197. " pattern 3 star: [[4, 0, 1, 5], [5, 4, 6, 7]]\n",
  198. " treelet 3 star: ['6CC1C1C1', '6CC1O1O1']\n",
  199. "\n",
  200. " pattern 4 star: []\n",
  201. " treelet 4 star: []\n",
  202. "\n",
  203. " pattern 5 star: []\n",
  204. " treelet 5 star: []\n",
  205. "\n",
  206. " pattern 7: [[4, 0, 1, 5, 6], [4, 0, 1, 5, 7], [5, 7, 6, 4, 0], [5, 7, 6, 4, 1], [5, 4, 7, 6, 2], [5, 4, 6, 7, 3]]\n",
  207. " treelet 7: ['7CC1C1C1O1', '7CC1C1C1O1', '7CO1O1C1C1', '7CO1O1C1C1', '7CC1O1O1C1', '7CC1O1O1C1']\n",
  208. "\n",
  209. " pattern 11: []\n",
  210. " treelet 11: []\n",
  211. "\n",
  212. " pattern 10: [[4, 0, 1, 5, 6, 2], [4, 0, 1, 5, 7, 3]]\n",
  213. " treelet 10: ['aCO1C1C1C1C1', 'aCO1C1C1C1C1']\n",
  214. "\n",
  215. " pattern 12: [[4, 0, 1, 5, 7, 6]]\n",
  216. " treelet 12: ['cCC1C1C1O1O1']\n",
  217. "\n",
  218. " pattern 9: [[5, 7, 6, 4, 2, 0], [5, 7, 6, 4, 2, 1], [5, 6, 7, 4, 3, 0], [5, 6, 7, 4, 3, 1], [5, 4, 7, 6, 3, 2]]\n",
  219. " treelet 9: ['9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CO1C1O1C1C1', '9CC1O1O1C1C1']\n",
  220. "\n",
  221. " numbers of canonical keys: {'2O1C1O': 1, '7CC1C1C1O1': 2, '7CC1O1O1C1': 2, 'aCO1C1C1C1C1': 2, '2C1C1C': 3, '6CC1C1C1': 1, '9CO1C1O1C1C1': 4, '1C1C': 3, '3C1C1C1O': 4, '4C1C1C1O1C': 4, '7CO1O1C1C1': 2, '2C1C1O': 2, '1C1O': 4, '9CC1O1O1C1C1': 1, '3C1C1O1C': 2, '6CC1O1O1': 1, '2C1O1C': 2, '0O': 2, '4C1O1C1O1C': 1, 'cCC1C1C1O1O1': 1, '0C': 6, '3C1O1C1O': 2}\n",
  222. "\n",
  223. " pattern 0: [0, 1, 2, 3, 4, 5, 6]\n",
  224. " treelet 0: ['C', 'C', 'C', 'C', 'C', 'O', 'O']\n",
  225. "\n",
  226. " pattern 1 : [[2, 0], [3, 1], [5, 2], [5, 4], [6, 3], [6, 4]]\n",
  227. " treelet 1 : ['1C1C', '1C1C', '1C1O', '1C1O', '1C1O', '1C1O']\n",
  228. "\n",
  229. " pattern 2 : [[4, 5, 2], [4, 6, 3], [5, 2, 0], [6, 3, 1], [6, 4, 5]]\n",
  230. " treelet 2 : ['2C1O1C', '2C1O1C', '2C1C1O', '2C1C1O', '2O1C1O']\n",
  231. "\n",
  232. " pattern 3 : [[4, 5, 2, 0], [4, 6, 3, 1], [5, 4, 6, 3], [6, 4, 5, 2]]\n",
  233. " treelet 3 : ['3C1C1O1C', '3C1C1O1C', '3C1O1C1O', '3C1O1C1O']\n",
  234. "\n",
  235. " pattern 4 : [[3, 6, 4, 5, 2], [5, 4, 6, 3, 1], [6, 4, 5, 2, 0]]\n",
  236. " treelet 4 : ['4C1O1C1O1C', '4C1C1O1C1O', '4C1C1O1C1O']\n",
  237. "\n",
  238. " pattern 5 : [[2, 5, 4, 6, 3, 1], [3, 6, 4, 5, 2, 0]]\n",
  239. " treelet 5 : ['5C1C1O1C1O1C', '5C1C1O1C1O1C']\n",
  240. "\n",
  241. " pattern 3 star: []\n",
  242. " treelet 3 star: []\n",
  243. "\n",
  244. " pattern 4 star: []\n",
  245. " treelet 4 star: []\n",
  246. "\n",
  247. " pattern 5 star: []\n",
  248. " treelet 5 star: []\n",
  249. "\n",
  250. " pattern 7: []\n",
  251. " treelet 7: []\n",
  252. "\n",
  253. " pattern 11: []\n",
  254. " treelet 11: []\n",
  255. "\n",
  256. " pattern 10: []\n",
  257. " treelet 10: []\n",
  258. "\n",
  259. " pattern 12: []\n",
  260. " treelet 12: []\n",
  261. "\n",
  262. " pattern 9: []\n",
  263. " treelet 9: []\n",
  264. "\n",
  265. " numbers of canonical keys: {'3C1C1O1C': 2, '2O1C1O': 1, '1C1O': 4, '2C1O1C': 2, '0O': 2, '5C1C1O1C1O1C': 2, '1C1C': 2, '4C1O1C1O1C': 1, '0C': 5, '3C1O1C1O': 2, '4C1C1O1C1O': 2, '2C1C1O': 2}\n"
  266. ]
  267. }
  268. ],
  269. "source": [
  270. "import sys\n",
  271. "import pathlib\n",
  272. "from collections import Counter\n",
  273. "from itertools import chain\n",
  274. "sys.path.insert(0, \"../\")\n",
  275. "\n",
  276. "import networkx as nx\n",
  277. "import numpy as np\n",
  278. "import time\n",
  279. "\n",
  280. "from sklearn.metrics.pairwise import rbf_kernel, paired_distances\n",
  281. "import matplotlib.pyplot as plt\n",
  282. "\n",
  283. "# main\n",
  284. "import sys\n",
  285. "from collections import Counter\n",
  286. "import networkx as nx\n",
  287. "sys.path.insert(0, \"../\")\n",
  288. "from pygraph.utils.graphfiles import loadDataset\n",
  289. "\n",
  290. "\n",
  291. "def main(): \n",
  292. " dataset, y = loadDataset(\"../../../../datasets/acyclic/Acyclic/dataset_bps.ds\")\n",
  293. " G1 = dataset[15]\n",
  294. " print(nx.get_node_attributes(G1, 'label'))\n",
  295. " nx.draw_networkx(G1)\n",
  296. " plt.show()\n",
  297. " G2 = dataset[57] # 180 double 4, 57, 3, double 3\n",
  298. " print(nx.get_node_attributes(G2, 'label'))\n",
  299. " nx.draw_networkx(G2)\n",
  300. " plt.show()\n",
  301. "\n",
  302. " treeletkernel(G1, G2, labeled = True)\n",
  303. " # Kmatrix = weisfeilerlehmankernel(G1, G2)\n",
  304. " \n",
  305. "def find_paths(G, source_node, length):\n",
  306. " if length == 0:\n",
  307. " return [[source_node]]\n",
  308. " path = [ [source_node] + path for neighbor in G[source_node] \\\n",
  309. " for path in find_paths(G, neighbor, length - 1) if source_node not in path ]\n",
  310. " return path\n",
  311. "\n",
  312. "def find_all_paths(G, length):\n",
  313. " all_paths = []\n",
  314. " for node in G:\n",
  315. " all_paths.extend(find_paths(G, node, length))\n",
  316. " all_paths_r = [ path[::-1] for path in all_paths ]\n",
  317. " \n",
  318. " # remove double direction\n",
  319. " for idx, path in enumerate(all_paths[:-1]):\n",
  320. " for path2 in all_paths_r[idx+1::]:\n",
  321. " if path == path2:\n",
  322. " all_paths[idx] = []\n",
  323. " break\n",
  324. " \n",
  325. " return list(filter(lambda a: a != [], all_paths))\n",
  326. "\n",
  327. "def get_canonkey(G, node_label = 'atom', edge_label = 'bond_type', labeled = True):\n",
  328. " \n",
  329. " patterns = {}\n",
  330. " canonkey = {} # canonical key\n",
  331. " \n",
  332. " ### structural analysis ###\n",
  333. " # linear patterns\n",
  334. " patterns['0'] = G.nodes()\n",
  335. " canonkey['0'] = nx.number_of_nodes(G)\n",
  336. " for i in range(1, 6):\n",
  337. " patterns[str(i)] = find_all_paths(G, i)\n",
  338. " canonkey[str(i)] = len(patterns[str(i)])\n",
  339. " \n",
  340. " # n-star patterns\n",
  341. " patterns['3star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 3 ]\n",
  342. " patterns['4star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 4 ]\n",
  343. " patterns['5star'] = [ [node] + [neighbor for neighbor in G[node]] for node in G.nodes() if G.degree(node) == 5 ] \n",
  344. " # n-star patterns\n",
  345. " canonkey['6'] = len(patterns['3star'])\n",
  346. " canonkey['8'] = len(patterns['4star'])\n",
  347. " canonkey['d'] = len(patterns['5star'])\n",
  348. " \n",
  349. " # pattern 7\n",
  350. " patterns['7'] = []\n",
  351. " for pattern in patterns['3star']:\n",
  352. " for i in range(1, len(pattern)):\n",
  353. " if G.degree(pattern[i]) >= 2:\n",
  354. " pattern_t = pattern[:]\n",
  355. " pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
  356. " for neighborx in G[pattern[i]]:\n",
  357. " if neighborx != pattern[0]:\n",
  358. " new_pattern = pattern_t + [ neighborx ]\n",
  359. "# new_patterns = [ pattern + [neighbor] for neighbor in G[pattern[i]] if neighbor != pattern[0] ]\n",
  360. " patterns['7'].append(new_pattern)\n",
  361. " canonkey['7'] = len(patterns['7'])\n",
  362. " \n",
  363. " # pattern 11\n",
  364. " patterns['11'] = []\n",
  365. " for pattern in patterns['4star']:\n",
  366. " for i in range(1, len(pattern)):\n",
  367. " if G.degree(pattern[i]) >= 2:\n",
  368. " pattern_t = pattern[:]\n",
  369. " pattern_t[i], pattern_t[4] = pattern_t[4], pattern_t[i]\n",
  370. " for neighborx in G[pattern[i]]:\n",
  371. " if neighborx != pattern[0]:\n",
  372. " new_pattern = pattern_t + [ neighborx ]\n",
  373. "# new_patterns = [ pattern + [neighborx] for neighborx in G[pattern[i]] if neighborx != pattern[0] ]\n",
  374. " patterns['11'].append(new_pattern)\n",
  375. " canonkey['b'] = len(patterns['11'])\n",
  376. " \n",
  377. " # pattern 12\n",
  378. " patterns['12'] = []\n",
  379. " rootlist = []\n",
  380. " for pattern in patterns['3star']:\n",
  381. "# print(pattern)\n",
  382. " if pattern[0] not in rootlist:\n",
  383. " rootlist.append(pattern[0])\n",
  384. " for i in range(1, len(pattern)):\n",
  385. " if G.degree(pattern[i]) >= 3:\n",
  386. " rootlist.append(pattern[i])\n",
  387. " pattern_t = pattern[:]\n",
  388. " pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
  389. " for neighborx1 in G[pattern[i]]:\n",
  390. " if neighborx1 != pattern[0]:\n",
  391. " for neighborx2 in G[pattern[i]]:\n",
  392. " if neighborx1 > neighborx2 and neighborx2 != pattern[0]:\n",
  393. " new_pattern = pattern_t + [neighborx1] + [neighborx2]\n",
  394. "# new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pattern[i]] if neighborx1 != pattern[0] for neighborx2 in G[pattern[i]] if (neighborx1 > neighborx2 and neighborx2 != pattern[0]) ]\n",
  395. " patterns['12'].append(new_pattern)\n",
  396. " canonkey['c'] = int(len(patterns['12']) / 2)\n",
  397. " \n",
  398. " # pattern 9\n",
  399. " patterns['9'] = []\n",
  400. " for pattern in patterns['3star']:\n",
  401. "# print('pattern: ', pattern)\n",
  402. " for pairs in [ [neighbor1, neighbor2] for neighbor1 in G[pattern[0]] if G.degree(neighbor1) >= 2 \\\n",
  403. " for neighbor2 in G[pattern[0]] if G.degree(neighbor2) >= 2 if neighbor1 > neighbor2 ]:\n",
  404. "# print('pairs: ', pairs)\n",
  405. " pattern_t = pattern[:]\n",
  406. "# print('pattern_t: ', pattern_t)\n",
  407. " pattern_t[pattern_t.index(pairs[0])], pattern_t[2] = pattern_t[2], pattern_t[pattern_t.index(pairs[0])]\n",
  408. "# print('pattern_t: ', pattern_t)\n",
  409. " pattern_t[pattern_t.index(pairs[1])], pattern_t[3] = pattern_t[3], pattern_t[pattern_t.index(pairs[1])]\n",
  410. "# print('pattern_t: ', pattern_t)\n",
  411. " for neighborx1 in G[pairs[0]]:\n",
  412. " if neighborx1 != pattern[0]:\n",
  413. " for neighborx2 in G[pairs[1]]:\n",
  414. " if neighborx2 != pattern[0]:\n",
  415. " new_pattern = pattern_t + [neighborx1] + [neighborx2]\n",
  416. "# new_patterns = [ pattern + [neighborx1] + [neighborx2] for neighborx1 in G[pairs[0]] if neighborx1 != pattern[0] for neighborx2 in G[pairs[1]] if neighborx2 != pattern[0] ]\n",
  417. " patterns['9'].append(new_pattern)\n",
  418. " canonkey['9'] = len(patterns['9'])\n",
  419. " \n",
  420. " # pattern 10\n",
  421. " patterns['10'] = []\n",
  422. " for pattern in patterns['3star']: \n",
  423. " for i in range(1, len(pattern)):\n",
  424. " if G.degree(pattern[i]) >= 2:\n",
  425. " for neighborx in G[pattern[i]]:\n",
  426. " if neighborx != pattern[0] and G.degree(neighborx) >= 2:\n",
  427. " pattern_t = pattern[:]\n",
  428. " pattern_t[i], pattern_t[3] = pattern_t[3], pattern_t[i]\n",
  429. " new_patterns = [ pattern_t + [neighborx] + [neighborxx] for neighborxx in G[neighborx] if neighborxx != pattern[i] ]\n",
  430. " patterns['10'].extend(new_patterns)\n",
  431. " canonkey['a'] = len(patterns['10'])\n",
  432. " \n",
  433. " ### labeling information ###\n",
  434. " if labeled == True:\n",
  435. " canonkey_l = {}\n",
  436. " \n",
  437. " # linear patterns\n",
  438. " canonkey_t = Counter(list(nx.get_node_attributes(G, node_label).values()))\n",
  439. " for key in canonkey_t:\n",
  440. " canonkey_l['0' + key] = canonkey_t[key]\n",
  441. " print('\\n pattern 0: ', patterns['0'])\n",
  442. " print(' treelet 0: ', list(nx.get_node_attributes(G, node_label).values()))\n",
  443. " \n",
  444. " for i in range(1, 6):\n",
  445. " treelet = []\n",
  446. " for pattern in patterns[str(i)]:\n",
  447. " canonlist = list(chain.from_iterable((G.node[node][node_label], \\\n",
  448. " G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))\n",
  449. " canonlist.append(G.node[pattern[-1]][node_label])\n",
  450. " canonkey_t = ''.join(canonlist)\n",
  451. " canonkey_t = canonkey_t if canonkey_t < canonkey_t[::-1] else canonkey_t[::-1]\n",
  452. " treelet.append(str(i) + canonkey_t)\n",
  453. " canonkey_l.update(Counter(treelet))\n",
  454. " print('\\n pattern', i, ': ', patterns[str(i)])\n",
  455. " print(' treelet', i, ': ', treelet)\n",
  456. " \n",
  457. "# print(canonkey_l)\n",
  458. " \n",
  459. " # n-star patterns\n",
  460. " for i in range(3, 6):\n",
  461. " treelet = []\n",
  462. " for pattern in patterns[str(i) + 'star']:\n",
  463. " canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:] ]\n",
  464. " canonlist.sort()\n",
  465. " canonkey_t = ('d' if i == 5 else str(i * 2)) + G.node[pattern[0]][node_label] + ''.join(canonlist)\n",
  466. " treelet.append(canonkey_t)\n",
  467. " canonkey_l.update(Counter(treelet))\n",
  468. " print('\\n pattern', i, 'star: ', patterns[str(i) + 'star'])\n",
  469. " print(' treelet', i, 'star: ', treelet)\n",
  470. " \n",
  471. " # pattern 7\n",
  472. " treelet = []\n",
  473. " for pattern in patterns['7']:\n",
  474. " canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
  475. " canonlist.sort()\n",
  476. " canonkey_t = '7' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
  477. " + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
  478. " + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]\n",
  479. " treelet.append(canonkey_t)\n",
  480. " canonkey_l.update(Counter(treelet))\n",
  481. " print('\\n pattern 7: ', patterns['7'])\n",
  482. " print(' treelet 7: ', treelet)\n",
  483. " \n",
  484. " # pattern 11\n",
  485. " treelet = []\n",
  486. " for pattern in patterns['11']:\n",
  487. " canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:4] ]\n",
  488. " canonlist.sort()\n",
  489. " canonkey_t = 'b' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
  490. " + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[0]][edge_label] \\\n",
  491. " + G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]\n",
  492. " treelet.append(canonkey_t)\n",
  493. " canonkey_l.update(Counter(treelet))\n",
  494. " print('\\n pattern 11: ', patterns['11'])\n",
  495. " print(' treelet 11: ', treelet)\n",
  496. "\n",
  497. " # pattern 10\n",
  498. " treelet = []\n",
  499. " for pattern in patterns['10']:\n",
  500. " canonkey4 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[4]][edge_label]\n",
  501. " canonlist = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
  502. " canonlist.sort()\n",
  503. " canonkey0 = ''.join(canonlist)\n",
  504. " canonkey_t = 'a' + G.node[pattern[3]][node_label] \\\n",
  505. " + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label] \\\n",
  506. " + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \\\n",
  507. " + canonkey4 + canonkey0\n",
  508. "# canonkey_t = 'a' + G.node[pattern[0]][node_label] + ''.join(canonlist) \\\n",
  509. "# + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
  510. "# + G.node[pattern[4]][node_label] + G[pattern[4]][pattern[3]][edge_label]\n",
  511. " treelet.append(canonkey_t)\n",
  512. " canonkey_l.update(Counter(treelet))\n",
  513. " print('\\n pattern 10: ', patterns['10'])\n",
  514. " print(' treelet 10: ', treelet)\n",
  515. " \n",
  516. " # pattern 12\n",
  517. " treelet = []\n",
  518. " for pattern in patterns['12']:\n",
  519. " canonlist0 = [ G.node[leaf][node_label] + G[leaf][pattern[0]][edge_label] for leaf in pattern[1:3] ]\n",
  520. " canonlist0.sort()\n",
  521. " canonlist3 = [ G.node[leaf][node_label] + G[leaf][pattern[3]][edge_label] for leaf in pattern[4:6] ]\n",
  522. " canonlist3.sort()\n",
  523. " canonkey_t1 = 'c' + G.node[pattern[0]][node_label] \\\n",
  524. " + ''.join(canonlist0) \\\n",
  525. " + G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label] \\\n",
  526. " + ''.join(canonlist3)\n",
  527. " \n",
  528. " canonkey_t2 = 'c' + G.node[pattern[3]][node_label] \\\n",
  529. " + ''.join(canonlist3) \\\n",
  530. " + G.node[pattern[0]][node_label] + G[pattern[0]][pattern[3]][edge_label] \\\n",
  531. " + ''.join(canonlist0)\n",
  532. " \n",
  533. " treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)\n",
  534. " canonkey_l.update(Counter(treelet))\n",
  535. " print('\\n pattern 12: ', patterns['12'])\n",
  536. " print(' treelet 12: ', treelet)\n",
  537. " \n",
  538. " # pattern 9\n",
  539. " treelet = []\n",
  540. " for pattern in patterns['9']:\n",
  541. " canonkey2 = G.node[pattern[4]][node_label] + G[pattern[4]][pattern[2]][edge_label]\n",
  542. " canonkey3 = G.node[pattern[5]][node_label] + G[pattern[5]][pattern[3]][edge_label]\n",
  543. " prekey2 = G.node[pattern[2]][node_label] + G[pattern[2]][pattern[0]][edge_label]\n",
  544. " prekey3 = G.node[pattern[3]][node_label] + G[pattern[3]][pattern[0]][edge_label]\n",
  545. " if prekey2 + canonkey2 < prekey3 + canonkey3:\n",
  546. " canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \\\n",
  547. " + prekey2 + prekey3 + canonkey2 + canonkey3\n",
  548. " else:\n",
  549. " canonkey_t = G.node[pattern[1]][node_label] + G[pattern[1]][pattern[0]][edge_label] \\\n",
  550. " + prekey3 + prekey2 + canonkey3 + canonkey2\n",
  551. " treelet.append('9' + G.node[pattern[0]][node_label] + canonkey_t)\n",
  552. " canonkey_l.update(Counter(treelet))\n",
  553. " print('\\n pattern 9: ', patterns['9'])\n",
  554. " print(' treelet 9: ', treelet)\n",
  555. " \n",
  556. "\n",
  557. " \n",
  558. " \n",
  559. " print('\\n numbers of canonical keys: ', canonkey_l)\n",
  560. " \n",
  561. " \n",
  562. " return canonkey_l\n",
  563. " \n",
  564. " return canonkey\n",
  565. " \n",
  566. "\n",
  567. "def treeletkernel(*args, node_label = 'atom', edge_label = 'bond_type', labeled = True):\n",
  568. " if len(args) == 1: # for a list of graphs\n",
  569. " Gn = args[0]\n",
  570. " Kmatrix = np.zeros((len(Gn), len(Gn)))\n",
  571. "\n",
  572. " start_time = time.time()\n",
  573. " \n",
  574. " for i in range(0, len(Gn)):\n",
  575. " print(i)\n",
  576. " for j in range(i, len(Gn)):\n",
  577. " Kmatrix[i][j] = treeletkernel(Gn[i], Gn[j], labeled = labeled, node_label = node_label, edge_label = edge_label)\n",
  578. " Kmatrix[j][i] = Kmatrix[i][j]\n",
  579. "\n",
  580. " run_time = time.time() - start_time\n",
  581. " print(\"\\n --- treelet kernel matrix of size %d built in %s seconds ---\" % (len(Gn), run_time))\n",
  582. " \n",
  583. " return Kmatrix, run_time\n",
  584. " \n",
  585. " else: # for only 2 graphs\n",
  586. " \n",
  587. " G1 = args[0]\n",
  588. " G = args[1]\n",
  589. " kernel = 0\n",
  590. " \n",
  591. "# start_time = time.time()\n",
  592. " \n",
  593. " \n",
  594. " canonkey2 = get_canonkey(G, node_label = node_label, edge_label = edge_label, labeled = labeled)\n",
  595. " canonkey1 = get_canonkey(G1, node_label = node_label, edge_label = edge_label, labeled = labeled)\n",
  596. " \n",
  597. " keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs\n",
  598. " vector1 = np.matrix([ (canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys ])\n",
  599. "# print(vector1)\n",
  600. " vector2 = np.matrix([ (canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys ]) \n",
  601. " kernel = np.sum(np.exp(- np.square(vector1 - vector2) / 2))\n",
  602. "# print(vector2)\n",
  603. " \n",
  604. " # labeling information\n",
  605. " \n",
  606. " # equal keys and graph isomorphism\n",
  607. " \n",
  608. "\n",
  609. "# run_time = time.time() - start_time\n",
  610. "# print(\"\\n --- treelet kernel built in %s seconds ---\" % (run_time))\n",
  611. " \n",
  612. "# print(kernel)\n",
  613. " return kernel#, run_time\n",
  614. " \n",
  615. "if __name__ == '__main__':\n",
  616. " main()"
  617. ]
  618. }
  619. ],
  620. "metadata": {
  621. "kernelspec": {
  622. "display_name": "Python 3",
  623. "language": "python",
  624. "name": "python3"
  625. },
  626. "language_info": {
  627. "codemirror_mode": {
  628. "name": "ipython",
  629. "version": 3
  630. },
  631. "file_extension": ".py",
  632. "mimetype": "text/x-python",
  633. "name": "python",
  634. "nbconvert_exporter": "python",
  635. "pygments_lexer": "ipython3",
  636. "version": "3.5.2"
  637. }
  638. },
  639. "nbformat": 4,
  640. "nbformat_minor": 2
  641. }

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.