
data_fetcher.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 20 14:25:49 2020
@author:
Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr
Luc Brun luc.brun@ensicaen.fr
Sebastien Bougleux sebastien.bougleux@unicaen.fr
Benoit Gaüzère benoit.gauzere@insa-rouen.fr
Linlin Jia linlin.jia@insa-rouen.fr
"""
import os
import os.path as osp
import urllib.request
import urllib.error
import tarfile
from zipfile import ZipFile
# from gklearn.utils.graphfiles import loadDataset
# import torch.nn.functional as F
import networkx as nx
# import torch
import random
import sys
# from lxml import etree
import re
# from tqdm import tqdm
from gklearn.dataset import DATABASES, DATASET_META


class DataFetcher():
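	"""Fetch and extract graph datasets.

	A helper built around ``gklearn.dataset.DATASET_META``: it downloads the
	archive of the named dataset (or of every dataset when ``name`` is None)
	into ``root``, extracts it, and exposes the extraction path(s) through the
	``path`` property.
	"""
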
	def __init__(self, name=None, root='datasets', reload=False, verbose=False):
		self._name = name
		self._root = root
		if not osp.exists(self._root):
			os.makedirs(self._root)
		self._reload = reload
		self._verbose = verbose
		# self.has_train_valid_test = {
		# 	"Coil_Del" : ('COIL-DEL/data/test.cxl', 'COIL-DEL/data/train.cxl', 'COIL-DEL/data/valid.cxl'),
		# 	"Coil_Rag" : ('COIL-RAG/data/test.cxl', 'COIL-RAG/data/train.cxl', 'COIL-RAG/data/valid.cxl'),
		# 	"Fingerprint" : ('Fingerprint/data/test.cxl', 'Fingerprint/data/train.cxl', 'Fingerprint/data/valid.cxl'),
		# 	# "Grec" : ('GREC/data/test.cxl', 'GREC/data/train.cxl', 'GREC/data/valid.cxl'),
		# 	"Letter" : {'HIGH' : ('Letter/HIGH/test.cxl', 'Letter/HIGH/train.cxl', 'Letter/HIGH/validation.cxl'),
		# 				'MED' : ('Letter/MED/test.cxl', 'Letter/MED/train.cxl', 'Letter/MED/validation.cxl'),
		# 				'LOW' : ('Letter/LOW/test.cxl', 'Letter/LOW/train.cxl', 'Letter/LOW/validation.cxl')
		# 				},
		# 	"Mutagenicity" : ('Mutagenicity/data/test.cxl', 'Mutagenicity/data/train.cxl', 'Mutagenicity/data/validation.cxl'),
		# 	# "Pah" : ['PAH/testset_0.ds', 'PAH/trainset_0.ds'],
		# 	"Protein" : ('Protein/data/test.cxl', 'Protein/data/train.cxl', 'Protein/data/valid.cxl'),
		# 	# "Web" : ('Web/data/test.cxl', 'Web/data/train.cxl', 'Web/data/valid.cxl')
		# }

		if self._name is None:
			if self._verbose:
				print('No dataset name entered. All possible datasets will be loaded.')
			self._name, self._path = [], []
			for idx, ds_name in enumerate(DATASET_META):
				if self._verbose:
					print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ')
				self._name.append(ds_name)
				success = self.write_archive_file(ds_name)
				if success:
					self._path.append(self.open_files(ds_name))
				else:
					self._path.append(None)
				if self._verbose and self._path[-1] is not None and not self._reload:
					print('Fetched.')
			if self._verbose:
				print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.')
		elif self._name not in DATASET_META:
			message = 'Invalid dataset name "' + self._name + '".'
			message += '\nAvailable datasets are as follows: \n\n'
			message += '\n'.join(ds for ds in sorted(DATASET_META))
			message += '\n\nThe following special suffixes can be added to the name:'
			message += '\n\n' + '\n'.join(['_unlabeled'])
			raise ValueError(message)
		else:
			self.write_archive_file(self._name)
			self._path = self.open_files(self._name)

		# self.max_for_letter = 0
		# if mode == 'Pytorch':
		# 	if self._name in self.data_to_use_in_datasets:
		# 		Gs, y = self.dataset
		# 		inputs, adjs, y = self.from_networkx_to_pytorch(Gs, y)
		# 		# print(inputs, adjs)
		# 		self.pytorch_dataset = inputs, adjs, y
		# 	elif self._name == "Pah":
		# 		self.pytorch_dataset = []
		# 		test, train = self.dataset
		# 		Gs_test, y_test = test
		# 		Gs_train, y_train = train
		# 		self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
		# 		self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
		# 	elif self._name in self.has_train_valid_test:
		# 		self.pytorch_dataset = []
		# 		# [G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
		# 		test, train, valid = self.dataset
		# 		Gs_test, y_test = test
		#
		# 		Gs_train, y_train = train
		# 		Gs_valid, y_valid = valid
		# 		self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test, y_test))
		# 		self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train, y_train))
		# 		self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid, y_valid))
		# #############
		# """
		# for G in Gs:
		# 	for e in G.edges():
		# 		print(G[e[0]])
		# """
		# ##############
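
	# Example (a sketch; it assumes network access and that 'Acyclic' is a key
	# of DATASET_META):
	#
	#   fetcher = DataFetcher(name='Acyclic', root='datasets', verbose=True)
	#   print(fetcher.path)  # directory of the extracted 'Acyclic' dataset
	#
	# With name=None every dataset listed in DATASET_META is fetched, and
	# `path` becomes a list with one entry (or None on failure) per dataset.
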
	def download_file(self, url):
		try:
			response = urllib.request.urlopen(url)
		except urllib.error.HTTPError:
			print('"' + url.split('/')[-1] + '" is not available or the http link is incorrect.')
			return
		except urllib.error.URLError:
			print('Network is unreachable.')
			return
		return response

	def write_archive_file(self, ds_name):
		path = osp.join(self._root, ds_name)
		# filename_dir = osp.join(path, filename)
		if not osp.exists(path) or self._reload:
			url = DATASET_META[ds_name]['url']
			response = self.download_file(url)
			if response is None:
				return False
			os.makedirs(path, exist_ok=True)
			with open(os.path.join(path, url.split('/')[-1]), 'wb') as outfile:
				outfile.write(response.read())
		return True

	def open_files(self, ds_name=None):
		if ds_name is None:
			ds_name = (self._name if isinstance(self._name, str) else self._name[0])
		filename = DATASET_META[ds_name]['url'].split('/')[-1]
		path = osp.join(self._root, ds_name)
		filename_archive = osp.join(path, filename)

		if filename.endswith('gz'):
			if tarfile.is_tarfile(filename_archive):
				with tarfile.open(filename_archive, 'r:gz') as tar:
					if self._reload and self._verbose:
						print(filename + ' Downloaded.')
					subpath = os.path.join(path, tar.getnames()[0].split('/')[0])
					if not osp.exists(subpath) or self._reload:
						tar.extractall(path=path)
					return subpath
		elif filename.endswith('.tar'):
			if tarfile.is_tarfile(filename_archive):
				with tarfile.open(filename_archive, 'r:') as tar:
					if self._reload and self._verbose:
						print(filename + ' Downloaded.')
					subpath = os.path.join(path, tar.getnames()[0])
					if not osp.exists(subpath) or self._reload:
						tar.extractall(path=path)
					return subpath
		elif filename.endswith('.zip'):
			with ZipFile(filename_archive, 'r') as zip_ref:
				if self._reload and self._verbose:
					print(filename + ' Downloaded.')
				subpath = os.path.join(path, zip_ref.namelist()[0])
				if not osp.exists(subpath) or self._reload:
					zip_ref.extractall(path)
				return subpath
		else:
			raise ValueError('Unsupported archive file: ' + filename)
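
	# Sketch of the two steps __init__ performs for a single dataset ('MAO' is
	# only illustrative and is assumed to be a key of DATASET_META):
	#
	#   fetcher = DataFetcher(name='MAO')   # downloads and extracts on creation
	#   fetcher.write_archive_file('MAO')   # returns True; archive already present
	#   print(fetcher.open_files('MAO'))    # path of the extracted files
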
	def get_all_ds_infos(self, database):
		"""Get information of all datasets from a database.

		Parameters
		----------
		database : string
			Name of the database, e.g., 'tudataset'.

		Returns
		-------
		dict
			Information of each dataset, keyed by dataset name.
		"""
		if database.lower() == 'tudataset':
			infos = self.get_all_tud_ds_infos()
		elif database.lower() == 'iam':
			infos = {}  # Parsing of the IAM database page is not implemented yet.
		else:
			msg = 'Invalid Database name "' + database + '".'
			msg += '\nAvailable databases are as follows: \n\n'
			msg += '\n'.join(db for db in sorted(DATABASES))
			msg += '\n\nCheck "gklearn.dataset.DATASET_META" for more details.'
			raise ValueError(msg)

		return infos

	def get_all_tud_ds_infos(self):
		"""Get information of all datasets from the database TUDataset.

		Returns
		-------
		dict
			Information of each dataset, keyed by dataset name.
		"""
		from lxml import etree

		try:
			response = urllib.request.urlopen(DATABASES['tudataset'])
		except urllib.error.HTTPError:
			print('The URL of the database "TUDataset" is not available:\n' + DATABASES['tudataset'])
			return {}  # nothing to parse without a response.

		infos = {}

		# Get tables.
		h_str = response.read()
		tree = etree.HTML(h_str)
		tables = tree.xpath('//table')
		for table in tables:
			# Get the domain of the datasets.
			h2_nodes = table.getprevious()
			if h2_nodes is not None and h2_nodes.tag == 'h2':
				domain = h2_nodes.text.strip().lower()
			else:
				domain = ''

			# Get each line in the table.
			tr_nodes = table.xpath('tbody/tr')
			for tr in tr_nodes[1:]:
				# Get each element in the line.
				td_node = tr.xpath('td')

				# task type.
				cls_txt = td_node[3].text.strip()
				if not cls_txt.startswith('R'):
					class_number = int(cls_txt)
					task_type = 'classification'
				else:
					class_number = None
					task_type = 'regression'

				# node attrs.
				na_text = td_node[8].text.strip()
				if not na_text.startswith('+'):
					node_attr_dim = 0
				else:
					node_attr_dim = int(re.findall(r'\((.*)\)', na_text)[0])

				# edge attrs.
				ea_text = td_node[10].text.strip()
				if ea_text == 'temporal':
					edge_attr_dim = ea_text
				elif not ea_text.startswith('+'):
					edge_attr_dim = 0
				else:
					edge_attr_dim = int(re.findall(r'\((.*)\)', ea_text)[0])

				# geometry.
				geo_txt = td_node[9].text.strip()
				if geo_txt == '–':
					geometry = None
				else:
					geometry = geo_txt

				# url.
				url = td_node[11].xpath('a')[0].attrib['href'].strip()
				pos_zip = url.rfind('.zip')
				url = url[:pos_zip + 4]

				infos[td_node[0].xpath('strong')[0].text.strip()] = {
					'database': 'tudataset',
					'reference': td_node[1].text.strip(),
					'dataset_size': int(td_node[2].text.strip()),
					'class_number': class_number,
					'task_type': task_type,
					'ave_node_num': float(td_node[4].text.strip()),
					'ave_edge_num': float(td_node[5].text.strip()),
					'node_labeled': td_node[6].text.strip() == '+',
					'edge_labeled': td_node[7].text.strip() == '+',
					'node_attr_dim': node_attr_dim,
					'geometry': geometry,
					'edge_attr_dim': edge_attr_dim,
					'url': url,
					'domain': domain
				}

		return infos

	def pretty_ds_infos(self, infos):
		"""Get a string that pretty-prints the information of datasets.

		Parameters
		----------
		infos : dict
			The datasets' information.

		Returns
		-------
		p_str : string
			The pretty print of the datasets' information.
		"""
		p_str = '{\n'
		for key, val in infos.items():
			p_str += '\t\'' + str(key) + '\': {\n'
			for k, v in val.items():
				p_str += '\t\t\'' + str(k) + '\': '
				if isinstance(v, str):
					p_str += '\'' + str(v) + '\',\n'
				else:
					p_str += str(v) + ',\n'
			p_str += '\t},\n'
		p_str += '}'
		return p_str
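
	# Example (a sketch; scraping the TUDataset page needs network access and
	# lxml, and 'MAO' is only an illustrative dataset name):
	#
	#   fetcher = DataFetcher(name='MAO')
	#   infos = fetcher.get_all_ds_infos('tudataset')
	#   print(fetcher.pretty_ds_infos(infos))
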
	@property
	def path(self):
		return self._path

	def dataset(self):
		# Legacy accessor: 'mode' and 'pytorch_dataset' are only set by the
		# commented-out Pytorch/Tensorflow code path in __init__.
		if self.mode == "Tensorflow":
			return  # something
		if self.mode == "Pytorch":
			return self.pytorch_dataset
		return self.dataset

	def info(self):
		print(self.info_dataset[self._name])

	def iter_load_dataset(self, data):
		# Needs loadDataset from gklearn.utils.graphfiles (import commented out above).
		results = []
		for datasets in data:
			results.append(loadDataset(osp.join(self._root, self._name, datasets)))
		return results

	def load_dataset(self, list_files):
		# Legacy loader kept from the old DataLoader implementation; it relies on the
		# commented-out 'has_train_valid_test' / 'data_to_use_in_datasets' mappings and
		# on the 'option' and 'gender' attributes, which the current class does not set.
		if self._name == "Ptc":
			if type(self.option) != str or self.option.upper() not in ['FR', 'FM', 'MM', 'MR']:
				raise ValueError('option for Ptc dataset needs to be one of: \n fr fm mm mr')
			results = []
			results.append(loadDataset(osp.join(self._root, self._name, 'PTC/Test', self.gender + '.ds')))
			results.append(loadDataset(osp.join(self._root, self._name, 'PTC/Train', self.gender + '.ds')))
			return results
		if self._name == "Pah":
			maximum_sets = 0
			for file in list_files:
				if file.endswith('ds'):
					maximum_sets = max(maximum_sets, int(file.split('_')[1].split('.')[0]))
			self.max_for_letter = maximum_sets
			if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
				raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
			data = self.has_train_valid_test["Pah"]
			data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds'
			data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds'
			return self.iter_load_dataset(data)
		if self._name == "Letter":
			if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]:
				data = self.has_train_valid_test["Letter"][self.option.upper()]
			else:
				message = "The parameter for letter is incorrect; choose between: "
				message += "\nhigh med low"
				raise ValueError(message)
			return self.iter_load_dataset(data)
		if self._name in self.has_train_valid_test:  # common IAM dataset with train, valid and test
			data = self.has_train_valid_test[self._name]
			return self.iter_load_dataset(data)
		else:  # common dataset without train/valid/test, only a dataset.ds file
			data = self.data_to_use_in_datasets[self._name]
			if len(data) > 1 and data[0] in list_files and data[1] in list_files:  # case for Alkane
				return loadDataset(osp.join(self._root, self._name, data[0]), filename_y=osp.join(self._root, self._name, data[1]))
			if data in list_files:
				return loadDataset(osp.join(self._root, self._name, data))

	def build_dictionary(self, Gs):
		labels = set()
		# next line: from DeepGraphWithNNTorch
		# bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
		sizes = set()
		for G in Gs:
			for _, node in G.nodes(data=True):  # or: for node in nx.nodes(G)
				# print(_, node)
				labels.add(node["label"][0])  # labels.add(G.nodes[node]["label"][0])  # what do we use for IAM datasets (they don't have bond_type or even a label)?
				sizes.add(G.order())
		label_dict = {}
		# print("labels: ", labels, bond_type_number_maxi)
		for i, label in enumerate(labels):
			label_dict[label] = [0.] * len(labels)
			label_dict[label][i] = 1.
		return label_dict
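
	# Example of the mapping produced (a sketch, matching the MAO example noted
	# below): for graphs whose nodes carry a 'label' attribute with values
	# 'C', 'N' and 'O', build_dictionary(Gs) returns a one-hot encoding such as
	#   {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
	# (the assignment of positions depends on set iteration order).
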
	def from_networkx_to_pytorch(self, Gs, y):
		# example for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
		# code from https://github.com/bgauzere/pygnn/blob/master/utils.py
		# Requires the 'torch' and 'torch.nn.functional as F' imports commented out at the top of this file.
		atom_to_onehot = self.build_dictionary(Gs)
		max_size = 30
		adjs = []
		inputs = []
		for i, G in enumerate(Gs):
			I = torch.eye(G.order(), G.order())
			# A = torch.Tensor(nx.adjacency_matrix(G).todense())
			# A = torch.Tensor(nx.to_numpy_matrix(G))
			A = torch.tensor(nx.to_scipy_sparse_matrix(G, dtype=int, weight='bond_type').todense(), dtype=torch.int)  # what do we use for IAM datasets (they don't have bond_type or even a label)?
			adj = F.pad(A, pad=(0, max_size - G.order(), 0, max_size - G.order()))  # add I now? if yes: F.pad(A + I, pad=(...))
			adjs.append(adj)
			f_0 = []
			for _, label in G.nodes(data=True):
				# print(_, label)
				cur_label = atom_to_onehot[label['label'][0]].copy()
				f_0.append(cur_label)
			X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size - G.order()))
			inputs.append(X)
		return inputs, adjs, y
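
	# Example (a sketch; it needs the torch imports mentioned above, graphs of
	# at most max_size = 30 nodes, and a 'bond_type' edge attribute):
	#
	#   inputs, adjs, y = fetcher.from_networkx_to_pytorch(Gs, y)
	#   # inputs[i]: (30, n_labels) padded one-hot node-feature tensor of graph i
	#   # adjs[i]:   (30, 30) padded adjacency tensor of graph i
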
	def from_pytorch_to_tensorflow(self, batch_size):
		# Reseeds the RNG with the same seed so that inputs and targets are
		# sampled at the same indices; the sampled batch is not returned yet.
		seed = random.randrange(sys.maxsize)
		random.seed(seed)
		tf_inputs = random.sample(self.pytorch_dataset[0], batch_size)
		random.seed(seed)
		tf_y = random.sample(self.pytorch_dataset[2], batch_size)

	def from_networkx_to_tensor(self, G, label_dict):
		# Requires the 'torch' import commented out at the top of this file.
		A = nx.to_numpy_matrix(G)
		lab = [label_dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
		return (torch.tensor(A).view(1, A.shape[0] * A.shape[1]), torch.tensor(lab))

	# dataset = self.open_files()
	# print(build_dictionary(Gs))
	# dic = {'C': 0, 'N': 1, 'O': 2}
	# A, labels = from_networkx_to_tensor(Gs[13], dic)
	# print(nx.to_numpy_matrix(Gs[13]), labels)
	# print(A, labels)

	# @todo: from_networkx_to_tensorflow

# dataloader = DataLoader('Acyclic', root="database", option='high', mode="Pytorch")
# dataloader.info()
# inputs, adjs, y = dataloader.pytorch_dataset
# """
# test, train, valid = dataloader.dataset
# Gs, y = test
# Gs2, y2 = train
# Gs3, y3 = valid
# """
# # Gs, y = dataloader.
# # print(Gs, y)
# """
# Gs, y = dataloader.dataset
# for G in Gs:
# 	for e in G.edges():
# 		print(G[e[0]])
# """
# # for e in Gs[13].edges():
# # 	print(Gs[13][e[0]])
# # print(from_networkx_to_tensor(Gs[7], {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}))
# # dataset.open_files()
  414. # import os
  415. # import os.path as osp
  416. # import urllib
  417. # import tarfile
  418. # from zipfile import ZipFile
  419. # from gklearn.utils.graphfiles import loadDataset
  420. # import torch
  421. # import torch.nn.functional as F
  422. # import networkx as nx
  423. # import matplotlib.pyplot as plt
  424. # import numpy as np
  425. #
  426. # def DataLoader(name,root = 'data',mode = "Networkx",downloadAll = False,reload = False,letter = "High",number = 0,gender = "MM"):
  427. # dir_name = "_".join(name.split("-"))
  428. # if not osp.exists(root) :
  429. # os.makedirs(root)
  430. # url = "https://brunl01.users.greyc.fr/CHEMISTRY/"
  431. # urliam = "https://iapr-tc15.greyc.fr/IAM/"
  432. # list_database = {
  433. # "Ace" : (url,"ACEDataset.tar"),
  434. # "Acyclic" : (url,"Acyclic.tar.gz"),
  435. # "Aids" : (urliam,"AIDS.zip"),
  436. # "Alkane" : (url,"alkane_dataset.tar.gz"),
  437. # "Chiral" : (url,"DatasetAcyclicChiral.tar"),
  438. # "Coil_Del" : (urliam,"COIL-DEL.zip"),
  439. # "Coil_Rag" : (urliam,"COIL-RAG.zip"),
  440. # "Fingerprint" : (urliam,"Fingerprint.zip"),
  441. # "Grec" : (urliam,"GREC.zip"),
  442. # "Letter" : (urliam,"Letter.zip"),
  443. # "Mao" : (url,"mao.tgz"),
  444. # "Monoterpenoides" : (url,"monoterpenoides.tar.gz"),
  445. # "Mutagenicity" : (urliam,"Mutagenicity.zip"),
  446. # "Pah" : (url,"PAH.tar.gz"),
  447. # "Protein" : (urliam,"Protein.zip"),
  448. # "Ptc" : (url,"ptc.tgz"),
  449. # "Steroid" : (url,"SteroidDataset.tar"),
  450. # "Vitamin" : (url,"DatasetVitamin.tar"),
  451. # "Web" : (urliam,"Web.zip")
  452. # }
  453. #
  454. # data_to_use_in_datasets = {
  455. # "Acyclic" : ("Acyclic/dataset_bps.ds"),
  456. # "Aids" : ("AIDS_A.txt"),
  457. # "Alkane" : ("Alkane/dataset.ds","Alkane/dataset_boiling_point_names.txt"),
  458. # "Mao" : ("MAO/dataset.ds"),
  459. # "Monoterpenoides" : ("monoterpenoides/dataset_10+.ds"), #('monoterpenoides/dataset.ds'),('monoterpenoides/dataset_9.ds'),('monoterpenoides/trainset_9.ds')
  460. #
  461. # }
  462. # has_train_valid_test = {
  463. # "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'),
  464. # "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'),
  465. # "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'),
  466. # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'),
  467. # "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'),
  468. # 'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'),
  469. # 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl')
  470. # },
  471. # "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'),
  472. # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'],
  473. # "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'),
  474. # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl')
  475. # }
  476. #
  477. # if not name :
  478. # raise ValueError("No dataset entered")
  479. # if name not in list_database:
  480. # message = "Invalid Dataset name " + name
  481. # message += '\n Available datasets are as follows : \n\n'
  482. # message += '\n'.join(database for database in list_database)
  483. # raise ValueError(message)
  484. #
  485. # def download_file(url,filename):
  486. # try :
  487. # response = urllib.request.urlopen(url + filename)
  488. # except urllib.error.HTTPError:
  489. # print(filename + " not available or incorrect http link")
  490. # return
  491. # return response
  492. #
  493. # def write_archive_file(root,database):
  494. # path = osp.join(root,database)
  495. # url,filename = list_database[database]
  496. # filename_dir = osp.join(path,filename)
  497. # if not osp.exists(filename_dir) or reload:
  498. # response = download_file(url,filename)
  499. # if response is None :
  500. # return
  501. # if not osp.exists(path) :
  502. # os.makedirs(path)
  503. # with open(filename_dir,'wb') as outfile :
  504. # outfile.write(response.read())
  505. #
  506. # if downloadAll :
  507. # print('Waiting...')
  508. # for database in list_database :
  509. # write_archive_file(root,database)
  510. # print('Downloading finished')
  511. # else:
  512. # write_archive_file(root,name)
  513. #
  514. # def iter_load_dataset(data):
  515. # results = []
  516. # for datasets in data :
  517. # results.append(loadDataset(osp.join(root,name,datasets)))
  518. # return results
  519. #
  520. # def load_dataset(list_files):
  521. # if name == "Ptc":
  522. # if gender.upper() not in ['FR','FM','MM','MR']:
  523. # raise ValueError('gender chosen needs to be one of \n fr fm mm mr')
  524. # results = []
  525. # results.append(loadDataset(osp.join(root,name,'PTC/Test',gender.upper() + '.ds')))
  526. # results.append(loadDataset(osp.join(root,name,'PTC/Train',gender.upper() + '.ds')))
  527. # return results
  528. # if name == "Pah":
  529. # maximum_sets = 0
  530. # for file in list_files:
  531. # if file.endswith('ds'):
  532. # maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0]))
  533. # if number > maximum_sets :
  534. # raise ValueError("Please select a dataset with number less than " + str(maximum_sets + 1))
  535. # data = has_train_valid_test["Pah"]
  536. # data[0] = has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(number) + '.ds'
  537. # data[1] = has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(number) + '.ds'
  538. # #print(data)
  539. # return iter_load_dataset(data)
  540. # if name == "Letter":
  541. # if letter.upper() in has_train_valid_test["Letter"]:
  542. # data = has_train_valid_test["Letter"][letter.upper()]
  543. # else:
  544. # message = "The parameter for letter is incorrect choose between : "
  545. # message += "\nhigh med low"
  546. # raise ValueError(message)
  547. # results = []
  548. # for datasets in data:
  549. # results.append(loadDataset(osp.join(root,name,datasets)))
  550. # return results
  551. # if name in has_train_valid_test : #common IAM dataset with train, valid and test
  552. # data = has_train_valid_test[name]
  553. # results = []
  554. # for datasets in data :
  555. # results.append(loadDataset(osp.join(root,name,datasets)))
  556. # return results
  557. # else: #common dataset without train,valid and test, only dataset.ds file
  558. # data = data_to_use_in_datasets[name]
  559. # if len(data) > 1 and data[0] in list_files and data[1] in list_files:
  560. # return loadDataset(osp.join(root,name,data[0]),filename_y = osp.join(root,name,data[1]))
  561. # if data in list_files:
  562. # return loadDataset(osp.join(root,name,data))
  563. # def open_files():
  564. # filename = list_database[name][1]
  565. # path = osp.join(root,name)
  566. # filename_archive = osp.join(root,name,filename)
  567. #
  568. # if filename.endswith('gz'):
  569. # if tarfile.is_tarfile(filename_archive):
  570. # with tarfile.open(filename_archive,"r:gz") as tar:
  571. # if reload:
  572. # print(filename + " Downloaded")
  573. # tar.extractall(path = path)
  574. # return load_dataset(tar.getnames())
  575. # #raise ValueError("dataset not available")
  576. #
  577. #
  578. # elif filename.endswith('.tar'):
  579. # if tarfile.is_tarfile(filename_archive):
  580. # with tarfile.open(filename_archive,"r:") as tar:
  581. # if reload :
  582. # print(filename + " Downloaded")
  583. # tar.extractall(path = path)
  584. # return load_dataset(tar.getnames())
  585. # elif filename.endswith('.zip'):
  586. # with ZipFile(filename_archive,"r") as zip_ref:
  587. # if reload :
  588. # print(filename + " Downloaded")
  589. # zip_ref.extractall(path)
  590. # return load_dataset(zip_ref.namelist())
  591. # else:
  592. # print(filename + " Unsupported file")
  593. # """
  594. # with tarfile.open(osp.join(root,name,list_database[name][1]),"r:gz") as files:
  595. # for file in files.getnames():
  596. # print(file)
  597. # """
  598. #
  599. # def build_dictionary(Gs):
  600. # labels = set()
  601. # bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
  602. # print(bond_type_number_maxi)
  603. # sizes = set()
  604. # for G in Gs :
  605. # for _,node in G.nodes(data = True): # or for node in nx.nodes(G)
  606. # #print(node)
  607. # labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0])
  608. # sizes.add(G.order())
  609. # if len(labels) >= bond_type_number_maxi:
  610. # break
  611. # label_dict = {}
  612. # for i,label in enumerate(labels):
  613. # label_dict[label] = [0.]*bond_type_number_maxi
  614. # label_dict[label][i] = 1.
  615. # return label_dict
  616. #
  617. # def from_networkx_to_pytorch(Gs):
  618. # #exemple : atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
  619. # # code from https://github.com/bgauzere/pygnn/blob/master/utils.py
  620. # atom_to_onehot = build_dictionary(Gs)
  621. # max_size = 30
  622. # adjs = []
  623. # inputs = []
  624. # for i, G in enumerate(Gs):
  625. # I = torch.eye(G.order(), G.order())
  626. # A = torch.Tensor(nx.adjacency_matrix(G).todense())
  627. # A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int)
  628. # adj = F.pad(A+I, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ?
  629. # adjs.append(adj)
  630. # f_0 = []
  631. # for _, label in G.nodes(data=True):
  632. # #print(_,label)
  633. # cur_label = atom_to_onehot[label['label'][0]].copy()
  634. # f_0.append(cur_label)
  635. # X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order()))
  636. # inputs.append(X)
  637. # return inputs,adjs,y
  638. #
  639. # def from_networkx_to_tensor(G,dict):
  640. # A=nx.to_numpy_matrix(G)
  641. # lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
  642. # return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab))
  643. #
  644. # dataset= open_files()
  645. # #print(build_dictionary(Gs))
  646. # #dic={'C':0,'N':1,'O':2}
  647. # #A,labels=from_networkx_to_tensor(Gs[13],dic)
  648. # #print(nx.to_numpy_matrix(Gs[13]),labels)
  649. # #print(A,labels)
  650. #
  651. # """
  652. # for G in Gs :
  653. # for node in nx.nodes(G):
  654. # print(G.nodes[node])
  655. # """
  656. # if mode == "pytorch":
  657. # Gs,y = dataset
  658. # inputs,adjs,y = from_networkx_to_pytorch(Gs)
  659. # print(inputs,adjs)
  660. # return inputs,adjs,y
  661. #
  662. #
  663. # """
  664. # dic = dict()
  665. # for i,l in enumerate(label):
  666. # dic[l] = i
  667. # dic = {'C': 0, 'N': 1, 'O': 2}
  668. # A,labels=from_networkx_to_tensor(Gs[0],dic)
  669. # #print(A,labels)
  670. # return A,labels
  671. # """
  672. #
  673. # return dataset
  674. #
  675. # #open_files()
  676. #
  677. # def label_to_color(label):
  678. # if label == 'C':
  679. # return 0.1
  680. # elif label == 'O':
  681. # return 0.8
  682. #
  683. # def nodes_to_color_sequence(G):
  684. # return [label_to_color(c[1]['label'][0]) for c in G.nodes(data=True)]
  685. # ##############
  686. # """
  687. # dataset = DataLoader('Mao',root = "database")
  688. # print(dataset)
  689. # Gs,y = dataset
  690. # """
  691. # """
  692. # dataset = DataLoader('Alkane',root = "database") # Gs is empty here whereas y isn't -> not working
  693. # Gs,y = dataset
  694. # """
  695. # """
  696. # dataset = DataLoader('Acyclic', root = "database")
  697. # Gs,y = dataset
  698. # """
  699. # """
  700. # dataset = DataLoader('Monoterpenoides', root = "database")
  701. # Gs,y = dataset
  702. # """
  703. # """
  704. # dataset = DataLoader('Pah',root = 'database', number = 8)
  705. # test_set,train_set = dataset
  706. # Gs,y = test_set
  707. # Gs2,y2 = train_set
  708. # """
  709. # """
  710. # dataset = DataLoader('Coil_Del',root = "database")
  711. # test,train,valid = dataset
  712. # Gs,y = test
  713. # Gs2,y2 = train
  714. # Gs3, y3 = valid
  715. # """
  716. # """
  717. # dataset = DataLoader('Coil_Rag',root = "database")
  718. # test,train,valid = dataset
  719. # Gs,y = test
  720. # Gs2,y2 = train
  721. # Gs3, y3 = valid
  722. # """
  723. # """
  724. # dataset = DataLoader('Fingerprint',root = "database")
  725. # test,train,valid = dataset
  726. # Gs,y = test
  727. # Gs2,y2 = train
  728. # Gs3, y3 = valid
  729. # """
  730. # """
  731. # dataset = DataLoader('Grec',root = "database")
  732. # test,train,valid = dataset
  733. # Gs,y = test
  734. # Gs2,y2 = train
  735. # Gs3, y3 = valid
  736. # """
  737. # """
  738. # dataset = DataLoader('Letter',root = "database",letter = 'low') #high low med
  739. # test,train,valid = dataset
  740. # Gs,y = test
  741. # Gs2,y2 = train
  742. # Gs3, y3 = valid
  743. # """
  744. # """
  745. # dataset = DataLoader('Mutagenicity',root = "database")
  746. # test,train,valid = dataset
  747. # Gs,y = test
  748. # Gs2,y2 = train
  749. # Gs3, y3 = valid
  750. # """
  751. # """
  752. # dataset = DataLoader('Protein',root = "database")
  753. # test,train,valid = dataset
  754. # Gs,y = test
  755. # Gs2,y2 = train
  756. # Gs3, y3 = valid
  757. # """
  758. # """
  759. # dataset = DataLoader('Ptc', root = "database",gender = 'fm') # not working, Gs and y are empty perhaps issue coming from loadDataset
  760. # valid,train = dataset
  761. # Gs,y = valid
  762. # Gs2,y2 = train
  763. # """
  764. # """
  765. # dataset = DataLoader('Web', root = "database")
  766. # test,train,valid = dataset
  767. # Gs,y = test
  768. # Gs2,y2 = train
  769. # Gs3,y3 = valid
  770. # """
  771. # print(Gs,y)
  772. # print(len(dataset))
  773. # ##############
  774. # #print('edge max label',max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
  775. # G1 = Gs[13]
  776. # G2 = Gs[23]
  777. # """
  778. # nx.draw_networkx(G1,with_labels=True,node_color = nodes_to_color_sequence(G1),cmap='autumn')
  779. # plt.figure()
  780. # nx.draw_networkx(G2,with_labels=True,node_color = nodes_to_color_sequence(G2),cmap='autumn')
  781. # """
  782. # from pathlib import Path
  783. # DATA_PATH = Path("data")
  784. # def import_datasets():
  785. #
  786. # import urllib
  787. # import tarfile
  788. # from zipfile import ZipFile
  789. # URL = "https://brunl01.users.greyc.fr/CHEMISTRY/"
  790. # URLIAM = "https://iapr-tc15.greyc.fr/IAM/"
  791. #
  792. # LIST_DATABASE = {
  793. # "Pah" : (URL,"PAH.tar.gz"),
  794. # "Mao" : (URL,"mao.tgz"),
  795. # "Ptc" : (URL,"ptc.tgz"),
  796. # "Aids" : (URLIAM,"AIDS.zip"),
  797. # "Acyclic" : (URL,"Acyclic.tar.gz"),
  798. # "Alkane" : (URL,"alkane_dataset.tar.gz"),
  799. # "Chiral" : (URL,"DatasetAcyclicChiral.tar"),
  800. # "Vitamin" : (URL,"DatasetVitamin.tar"),
  801. # "Ace" : (URL,"ACEDataset.tar"),
  802. # "Steroid" : (URL,"SteroidDataset.tar"),
  803. # "Monoterpenoides" : (URL,"monoterpenoides.tar.gz"),
  804. # "Letter" : (URLIAM,"Letter.zip"),
  805. # "Grec" : (URLIAM,"GREC.zip"),
  806. # "Fingerprint" : (URLIAM,"Fingerprint.zip"),
  807. # "Coil_Rag" : (URLIAM,"COIL-RAG.zip"),
  808. # "Coil_Del" : (URLIAM,"COIL-DEL.zip"),
  809. # "Web" : (URLIAM,"Web.zip"),
  810. # "Mutagenicity" : (URLIAM,"Mutagenicity.zip"),
  811. # "Protein" : (URLIAM,"Protein.zip")
  812. # }
  813. # print("Select databases in the list. Select multiple, split by white spaces .\nWrite All to select all of them.\n")
  814. # print(', '.join(database for database in LIST_DATABASE))
  815. # print("Choice : ",end = ' ')
  816. # selected_databases = input().split()
  817. #
  818. # def download_file(url,filename):
  819. # try :
  820. # response = urllib.request.urlopen(url + filename)
  821. # except urllib.error.HTTPError:
  822. # print(filename + " not available or incorrect http link")
  823. # return
  824. # return response
  825. #
  826. # def write_archive_file(database):
  827. #
  828. # PATH = DATA_PATH / database
  829. # url,filename = LIST_DATABASE[database]
  830. # if not (PATH / filename).exists():
  831. # response = download_file(url,filename)
  832. # if response is None :
  833. # return
  834. # if not PATH.exists() :
  835. # PATH.mkdir(parents=True, exist_ok=True)
  836. # with open(PATH/filename,'wb') as outfile :
  837. # outfile.write(response.read())
  838. #
  839. # if filename[-2:] == 'gz':
  840. # if tarfile.is_tarfile(PATH/filename):
  841. # with tarfile.open(PATH/filename,"r:gz") as tar:
  842. # tar.extractall(path = PATH)
  843. # print(filename + ' Downloaded')
  844. # elif filename[-3:] == 'tar':
  845. # if tarfile.is_tarfile(PATH/filename):
  846. # with tarfile.open(PATH/filename,"r:") as tar:
  847. # tar.extractall(path = PATH)
  848. # print(filename + ' Downloaded')
  849. # elif filename[-3:] == 'zip':
  850. # with ZipFile(PATH/filename,"r") as zip_ref:
  851. # zip_ref.extractall(PATH)
  852. # print(filename + ' Downloaded')
  853. # else:
  854. # print("Unsupported file")
  855. # if 'All' in selected_databases:
  856. # print('Waiting...')
  857. # for database in LIST_DATABASE :
  858. # write_archive_file(database)
  859. # print('Finished')
  860. # else:
  861. # print('Waiting...')
  862. # for database in selected_databases :
  863. # if database in LIST_DATABASE :
  864. # write_archive_file(database)
  865. # print('Finished')
  866. # import_datasets()
  867. # class GraphFetcher(object):
  868. #
  869. #
  870. # def __init__(self, filename=None, filename_targets=None, **kwargs):
  871. # if filename is None:
  872. # self._graphs = None
  873. # self._targets = None
  874. # self._node_labels = None
  875. # self._edge_labels = None
  876. # self._node_attrs = None
  877. # self._edge_attrs = None
  878. # else:
  879. # self.load_dataset(filename, filename_targets=filename_targets, **kwargs)
  880. #
  881. # self._substructures = None
  882. # self._node_label_dim = None
  883. # self._edge_label_dim = None
  884. # self._directed = None
  885. # self._dataset_size = None
  886. # self._total_node_num = None
  887. # self._ave_node_num = None
  888. # self._min_node_num = None
  889. # self._max_node_num = None
  890. # self._total_edge_num = None
  891. # self._ave_edge_num = None
  892. # self._min_edge_num = None
  893. # self._max_edge_num = None
  894. # self._ave_node_degree = None
  895. # self._min_node_degree = None
  896. # self._max_node_degree = None
  897. # self._ave_fill_factor = None
  898. # self._min_fill_factor = None
  899. # self._max_fill_factor = None
  900. # self._node_label_nums = None
  901. # self._edge_label_nums = None
  902. # self._node_attr_dim = None
  903. # self._edge_attr_dim = None
  904. # self._class_number = None
  905. #
  906. #
  907. # def load_dataset(self, filename, filename_targets=None, **kwargs):
  908. # self._graphs, self._targets, label_names = load_dataset(filename, filename_targets=filename_targets, **kwargs)
  909. # self._node_labels = label_names['node_labels']
  910. # self._node_attrs = label_names['node_attrs']
  911. # self._edge_labels = label_names['edge_labels']
  912. # self._edge_attrs = label_names['edge_attrs']
  913. # self.clean_labels()
  914. #
  915. #
  916. # def load_graphs(self, graphs, targets=None):
  917. # # this has to be followed by set_labels().
  918. # self._graphs = graphs
  919. # self._targets = targets
  920. # # self.set_labels_attrs() # @todo
  921. #
  922. #
  923. # def load_predefined_dataset(self, ds_name):
  924. # current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
  925. # if ds_name == 'Acyclic':
  926. # ds_file = current_path + '../../datasets/Acyclic/dataset_bps.ds'
  927. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  928. # elif ds_name == 'AIDS':
  929. # ds_file = current_path + '../../datasets/AIDS/AIDS_A.txt'
  930. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  931. # elif ds_name == 'Alkane':
  932. # ds_file = current_path + '../../datasets/Alkane/dataset.ds'
  933. # fn_targets = current_path + '../../datasets/Alkane/dataset_boiling_point_names.txt'
  934. # self._graphs, self._targets, label_names = load_dataset(ds_file, filename_targets=fn_targets)
  935. # elif ds_name == 'COIL-DEL':
  936. # ds_file = current_path + '../../datasets/COIL-DEL/COIL-DEL_A.txt'
  937. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  938. # elif ds_name == 'COIL-RAG':
  939. # ds_file = current_path + '../../datasets/COIL-RAG/COIL-RAG_A.txt'
  940. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  941. # elif ds_name == 'COLORS-3':
  942. # ds_file = current_path + '../../datasets/COLORS-3/COLORS-3_A.txt'
  943. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  944. # elif ds_name == 'Cuneiform':
  945. # ds_file = current_path + '../../datasets/Cuneiform/Cuneiform_A.txt'
  946. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  947. # elif ds_name == 'DD':
  948. # ds_file = current_path + '../../datasets/DD/DD_A.txt'
  949. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  950. # elif ds_name == 'ENZYMES':
  951. # ds_file = current_path + '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
  952. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  953. # elif ds_name == 'Fingerprint':
  954. # ds_file = current_path + '../../datasets/Fingerprint/Fingerprint_A.txt'
  955. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  956. # elif ds_name == 'FRANKENSTEIN':
  957. # ds_file = current_path + '../../datasets/FRANKENSTEIN/FRANKENSTEIN_A.txt'
  958. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  959. # elif ds_name == 'Letter-high': # node non-symb
  960. # ds_file = current_path + '../../datasets/Letter-high/Letter-high_A.txt'
  961. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  962. # elif ds_name == 'Letter-low': # node non-symb
  963. # ds_file = current_path + '../../datasets/Letter-low/Letter-low_A.txt'
  964. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  965. # elif ds_name == 'Letter-med': # node non-symb
  966. # ds_file = current_path + '../../datasets/Letter-med/Letter-med_A.txt'
  967. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  968. # elif ds_name == 'MAO':
  969. # ds_file = current_path + '../../datasets/MAO/dataset.ds'
  970. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  971. # elif ds_name == 'Monoterpenoides':
  972. # ds_file = current_path + '../../datasets/Monoterpenoides/dataset_10+.ds'
  973. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  974. # elif ds_name == 'MUTAG':
  975. # ds_file = current_path + '../../datasets/MUTAG/MUTAG_A.txt'
  976. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  977. # elif ds_name == 'NCI1':
  978. # ds_file = current_path + '../../datasets/NCI1/NCI1_A.txt'
  979. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  980. # elif ds_name == 'NCI109':
  981. # ds_file = current_path + '../../datasets/NCI109/NCI109_A.txt'
  982. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  983. # elif ds_name == 'PAH':
  984. # ds_file = current_path + '../../datasets/PAH/dataset.ds'
  985. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  986. # elif ds_name == 'SYNTHETIC':
  987. # pass
  988. # elif ds_name == 'SYNTHETICnew':
  989. # ds_file = current_path + '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
  990. # self._graphs, self._targets, label_names = load_dataset(ds_file)
  991. # elif ds_name == 'Synthie':
  992. # pass
  993. # else:
  994. # raise Exception('The dataset name "', ds_name, '" is not pre-defined.')
  995. #
  996. # self._node_labels = label_names['node_labels']
  997. # self._node_attrs = label_names['node_attrs']
  998. # self._edge_labels = label_names['edge_labels']
  999. # self._edge_attrs = label_names['edge_attrs']
  1000. # self.clean_labels()
  1001. #
  1002. # def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
  1003. # self._node_labels = node_labels
  1004. # self._node_attrs = node_attrs
  1005. # self._edge_labels = edge_labels
  1006. # self._edge_attrs = edge_attrs
  1007. #
  1008. # def set_labels_attrs(self, node_labels=None, node_attrs=None, edge_labels=None, edge_attrs=None):
  1009. # # @todo: remove labels which have only one possible values.
  1010. # if node_labels is None:
  1011. # self._node_labels = self._graphs[0].graph['node_labels']
  1012. # # # graphs are considered node unlabeled if all nodes have the same label.
  1013. # # infos.update({'node_labeled': is_nl if node_label_num > 1 else False})
  1014. # if node_attrs is None:
  1015. # self._node_attrs = self._graphs[0].graph['node_attrs']
  1016. # # for G in Gn:
  1017. # # for n in G.nodes(data=True):
  1018. # # if 'attributes' in n[1]:
  1019. # # return len(n[1]['attributes'])
  1020. # # return 0
  1021. # if edge_labels is None:
  1022. # self._edge_labels = self._graphs[0].graph['edge_labels']
  1023. # # # graphs are considered edge unlabeled if all edges have the same label.
  1024. # # infos.update({'edge_labeled': is_el if edge_label_num > 1 else False})
  1025. # if edge_attrs is None:
  1026. # self._edge_attrs = self._graphs[0].graph['edge_attrs']
  1027. # # for G in Gn:
  1028. # # if nx.number_of_edges(G) > 0:
  1029. # # for e in G.edges(data=True):
  1030. # # if 'attributes' in e[2]:
  1031. # # return len(e[2]['attributes'])
  1032. # # return 0
  1033. #
  1034. #
  1035. # def get_dataset_infos(self, keys=None, params=None):
  1036. # """Computes and returns the structure and property information of the graph dataset.
  1037. #
  1038. # Parameters
  1039. # ----------
  1040. # keys : list, optional
  1041. # A list of strings which indicate which informations will be returned. The
  1042. # possible choices includes:
  1043. #
  1044. # 'substructures': sub-structures graphs contains, including 'linear', 'non
  1045. # linear' and 'cyclic'.
  1046. #
  1047. # 'node_label_dim': whether vertices have symbolic labels.
  1048. #
  1049. # 'edge_label_dim': whether egdes have symbolic labels.
  1050. #
  1051. # 'directed': whether graphs in dataset are directed.
  1052. #
  1053. # 'dataset_size': number of graphs in dataset.
  1054. #
  1055. # 'total_node_num': total number of vertices of all graphs in dataset.
  1056. #
  1057. # 'ave_node_num': average number of vertices of graphs in dataset.
  1058. #
  1059. # 'min_node_num': minimum number of vertices of graphs in dataset.
  1060. #
  1061. # 'max_node_num': maximum number of vertices of graphs in dataset.
  1062. #
  1063. # 'total_edge_num': total number of edges of all graphs in dataset.
  1064. #
  1065. # 'ave_edge_num': average number of edges of graphs in dataset.
  1066. #
  1067. # 'min_edge_num': minimum number of edges of graphs in dataset.
  1068. #
  1069. # 'max_edge_num': maximum number of edges of graphs in dataset.
  1070. #
  1071. # 'ave_node_degree': average vertex degree of graphs in dataset.
  1072. #
  1073. # 'min_node_degree': minimum vertex degree of graphs in dataset.
  1074. #
  1075. # 'max_node_degree': maximum vertex degree of graphs in dataset.
  1076. #
  1077. # 'ave_fill_factor': average fill factor (number_of_edges /
  1078. # (number_of_nodes ** 2)) of graphs in dataset.
  1079. #
  1080. # 'min_fill_factor': minimum fill factor of graphs in dataset.
  1081. #
  1082. # 'max_fill_factor': maximum fill factor of graphs in dataset.
  1083. #
  1084. # 'node_label_nums': list of numbers of symbolic vertex labels of graphs in dataset.
  1085. #
  1086. # 'edge_label_nums': list number of symbolic edge labels of graphs in dataset.
  1087. #
  1088. # 'node_attr_dim': number of dimensions of non-symbolic vertex labels.
  1089. # Extracted from the 'attributes' attribute of graph nodes.
  1090. #
  1091. # 'edge_attr_dim': number of dimensions of non-symbolic edge labels.
  1092. # Extracted from the 'attributes' attribute of graph edges.
  1093. #
  1094. # 'class_number': number of classes. Only available for classification problems.
  1095. #
  1096. # 'all_degree_entropy': the entropy of degree distribution of each graph.
  1097. #
  1098. # 'ave_degree_entropy': the average entropy of degree distribution of all graphs.
  1099. #
  1100. # All informations above will be returned if `keys` is not given.
  1101. #
  1102. # params: dict of dict, optional
  1103. # A dictinary which contains extra parameters for each possible
  1104. # element in ``keys``.
  1105. #
  1106. # Return
  1107. # ------
  1108. # dict
  1109. # Information of the graph dataset keyed by `keys`.
  1110. # """
  1111. # infos = {}
  1112. #
  1113. # if keys == None:
  1114. # keys = [
  1115. # 'substructures',
  1116. # 'node_label_dim',
  1117. # 'edge_label_dim',
  1118. # 'directed',
  1119. # 'dataset_size',
  1120. # 'total_node_num',
  1121. # 'ave_node_num',
  1122. # 'min_node_num',
  1123. # 'max_node_num',
  1124. # 'total_edge_num',
#             'ave_edge_num',
#             'min_edge_num',
#             'max_edge_num',
#             'ave_node_degree',
#             'min_node_degree',
#             'max_node_degree',
#             'ave_fill_factor',
#             'min_fill_factor',
#             'max_fill_factor',
#             'node_label_nums',
#             'edge_label_nums',
#             'node_attr_dim',
#             'edge_attr_dim',
#             'class_number',
#             'all_degree_entropy',
#             'ave_degree_entropy'
#         ]
#
#         # dataset size
#         if 'dataset_size' in keys:
#             if self._dataset_size is None:
#                 self._dataset_size = self._get_dataset_size()
#             infos['dataset_size'] = self._dataset_size
#
#         # graph node number
#         if any(i in keys for i in ['total_node_num', 'ave_node_num', 'min_node_num', 'max_node_num']):
#             all_node_nums = self._get_all_node_nums()
#         if 'total_node_num' in keys:
#             if self._total_node_num is None:
#                 self._total_node_num = self._get_total_node_num(all_node_nums)
#             infos['total_node_num'] = self._total_node_num
#
#         if 'ave_node_num' in keys:
#             if self._ave_node_num is None:
#                 self._ave_node_num = self._get_ave_node_num(all_node_nums)
#             infos['ave_node_num'] = self._ave_node_num
#
#         if 'min_node_num' in keys:
#             if self._min_node_num is None:
#                 self._min_node_num = self._get_min_node_num(all_node_nums)
#             infos['min_node_num'] = self._min_node_num
#
#         if 'max_node_num' in keys:
#             if self._max_node_num is None:
#                 self._max_node_num = self._get_max_node_num(all_node_nums)
#             infos['max_node_num'] = self._max_node_num
#
#         # graph edge number
#         if any(i in keys for i in ['total_edge_num', 'ave_edge_num', 'min_edge_num', 'max_edge_num']):
#             all_edge_nums = self._get_all_edge_nums()
#         if 'total_edge_num' in keys:
#             if self._total_edge_num is None:
#                 self._total_edge_num = self._get_total_edge_num(all_edge_nums)
#             infos['total_edge_num'] = self._total_edge_num
#
#         if 'ave_edge_num' in keys:
#             if self._ave_edge_num is None:
#                 self._ave_edge_num = self._get_ave_edge_num(all_edge_nums)
#             infos['ave_edge_num'] = self._ave_edge_num
#
#         if 'max_edge_num' in keys:
#             if self._max_edge_num is None:
#                 self._max_edge_num = self._get_max_edge_num(all_edge_nums)
#             infos['max_edge_num'] = self._max_edge_num
#         if 'min_edge_num' in keys:
#             if self._min_edge_num is None:
#                 self._min_edge_num = self._get_min_edge_num(all_edge_nums)
#             infos['min_edge_num'] = self._min_edge_num
#
#         # label number
#         if 'node_label_dim' in keys:
#             if self._node_label_dim is None:
#                 self._node_label_dim = self._get_node_label_dim()
#             infos['node_label_dim'] = self._node_label_dim
#
#         if 'node_label_nums' in keys:
#             if self._node_label_nums is None:
#                 self._node_label_nums = {}
#                 for node_label in self._node_labels:
#                     self._node_label_nums[node_label] = self._get_node_label_num(node_label)
#             infos['node_label_nums'] = self._node_label_nums
#
#         if 'edge_label_dim' in keys:
#             if self._edge_label_dim is None:
#                 self._edge_label_dim = self._get_edge_label_dim()
#             infos['edge_label_dim'] = self._edge_label_dim
#
#         if 'edge_label_nums' in keys:
#             if self._edge_label_nums is None:
#                 self._edge_label_nums = {}
#                 for edge_label in self._edge_labels:
#                     self._edge_label_nums[edge_label] = self._get_edge_label_num(edge_label)
#             infos['edge_label_nums'] = self._edge_label_nums
#
#         if 'directed' in keys or 'substructures' in keys:
#             if self._directed is None:
#                 self._directed = self._is_directed()
#             infos['directed'] = self._directed
#
#         # node degree
#         if any(i in keys for i in ['ave_node_degree', 'max_node_degree', 'min_node_degree']):
#             all_node_degrees = self._get_all_node_degrees()
#
#         if 'ave_node_degree' in keys:
#             if self._ave_node_degree is None:
#                 self._ave_node_degree = self._get_ave_node_degree(all_node_degrees)
#             infos['ave_node_degree'] = self._ave_node_degree
#
#         if 'max_node_degree' in keys:
#             if self._max_node_degree is None:
#                 self._max_node_degree = self._get_max_node_degree(all_node_degrees)
#             infos['max_node_degree'] = self._max_node_degree
#
#         if 'min_node_degree' in keys:
#             if self._min_node_degree is None:
#                 self._min_node_degree = self._get_min_node_degree(all_node_degrees)
#             infos['min_node_degree'] = self._min_node_degree
#
#         # fill factor
#         if any(i in keys for i in ['ave_fill_factor', 'max_fill_factor', 'min_fill_factor']):
#             all_fill_factors = self._get_all_fill_factors()
#
#         if 'ave_fill_factor' in keys:
#             if self._ave_fill_factor is None:
#                 self._ave_fill_factor = self._get_ave_fill_factor(all_fill_factors)
#             infos['ave_fill_factor'] = self._ave_fill_factor
#
#         if 'max_fill_factor' in keys:
#             if self._max_fill_factor is None:
#                 self._max_fill_factor = self._get_max_fill_factor(all_fill_factors)
#             infos['max_fill_factor'] = self._max_fill_factor
#
#         if 'min_fill_factor' in keys:
#             if self._min_fill_factor is None:
#                 self._min_fill_factor = self._get_min_fill_factor(all_fill_factors)
#             infos['min_fill_factor'] = self._min_fill_factor
#
#         if 'substructures' in keys:
#             if self._substructures is None:
#                 self._substructures = self._get_substructures()
#             infos['substructures'] = self._substructures
#
#         if 'class_number' in keys:
#             if self._class_number is None:
#                 self._class_number = self._get_class_number()
#             infos['class_number'] = self._class_number
#
#         if 'node_attr_dim' in keys:
#             if self._node_attr_dim is None:
#                 self._node_attr_dim = self._get_node_attr_dim()
#             infos['node_attr_dim'] = self._node_attr_dim
#
#         if 'edge_attr_dim' in keys:
#             if self._edge_attr_dim is None:
#                 self._edge_attr_dim = self._get_edge_attr_dim()
#             infos['edge_attr_dim'] = self._edge_attr_dim
#
#         # entropy of degree distribution.
#
#         if 'all_degree_entropy' in keys:
#             if params is not None and ('all_degree_entropy' in params) and ('base' in params['all_degree_entropy']):
#                 base = params['all_degree_entropy']['base']
#             else:
#                 base = None
#             infos['all_degree_entropy'] = self._compute_all_degree_entropy(base=base)
#
#         if 'ave_degree_entropy' in keys:
#             if params is not None and ('ave_degree_entropy' in params) and ('base' in params['ave_degree_entropy']):
#                 base = params['ave_degree_entropy']['base']
#             else:
#                 base = None
#             infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base))
#
#         return infos
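#
#     # NOTE (illustrative sketch, not part of the original code): a minimal,
#     # hypothetical call to the statistics-gathering method whose body ends
#     # above, assuming the old `Dataset` API with a `get_dataset_infos(keys,
#     # params)` entry point; `params` supplies the logarithm base used for the
#     # entropy keys.
#     #
#     # ds = Dataset()
#     # ds.load_graphs(graphs, targets)
#     # infos = ds.get_dataset_infos(
#     #     keys=['ave_node_num', 'max_node_degree', 'ave_degree_entropy'],
#     #     params={'ave_degree_entropy': {'base': 2}})
#     # ds.print_graph_infos(infos)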
#
#
#     def print_graph_infos(self, infos):
#         from collections import OrderedDict
#         keys = list(infos.keys())
#         print(OrderedDict(sorted(infos.items(), key=lambda i: keys.index(i[0]))))
#
#
#     def remove_labels(self, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
#         node_labels = [item for item in node_labels if item in self._node_labels]
#         edge_labels = [item for item in edge_labels if item in self._edge_labels]
#         node_attrs = [item for item in node_attrs if item in self._node_attrs]
#         edge_attrs = [item for item in edge_attrs if item in self._edge_attrs]
#         for g in self._graphs:
#             for nd in g.nodes():
#                 for nl in node_labels:
#                     del g.nodes[nd][nl]
#                 for na in node_attrs:
#                     del g.nodes[nd][na]
#             for ed in g.edges():
#                 for el in edge_labels:
#                     del g.edges[ed][el]
#                 for ea in edge_attrs:
#                     del g.edges[ed][ea]
#         if len(node_labels) > 0:
#             self._node_labels = [nl for nl in self._node_labels if nl not in node_labels]
#         if len(edge_labels) > 0:
#             self._edge_labels = [el for el in self._edge_labels if el not in edge_labels]
#         if len(node_attrs) > 0:
#             self._node_attrs = [na for na in self._node_attrs if na not in node_attrs]
#         if len(edge_attrs) > 0:
#             self._edge_attrs = [ea for ea in self._edge_attrs if ea not in edge_attrs]
#
#
#     def clean_labels(self):
#         labels = []
#         for name in self._node_labels:
#             label = set()
#             for G in self._graphs:
#                 label = label | set(nx.get_node_attributes(G, name).values())
#                 if len(label) > 1:
#                     labels.append(name)
#                     break
#             if len(label) < 2:
#                 for G in self._graphs:
#                     for nd in G.nodes():
#                         del G.nodes[nd][name]
#         self._node_labels = labels
#         labels = []
#         for name in self._edge_labels:
#             label = set()
#             for G in self._graphs:
#                 label = label | set(nx.get_edge_attributes(G, name).values())
#                 if len(label) > 1:
#                     labels.append(name)
#                     break
#             if len(label) < 2:
#                 for G in self._graphs:
#                     for ed in G.edges():
#                         del G.edges[ed][name]
#         self._edge_labels = labels
#         labels = []
#         for name in self._node_attrs:
#             label = set()
#             for G in self._graphs:
#                 label = label | set(nx.get_node_attributes(G, name).values())
#                 if len(label) > 1:
#                     labels.append(name)
#                     break
#             if len(label) < 2:
#                 for G in self._graphs:
#                     for nd in G.nodes():
#                         del G.nodes[nd][name]
#         self._node_attrs = labels
#         labels = []
#         for name in self._edge_attrs:
#             label = set()
#             for G in self._graphs:
#                 label = label | set(nx.get_edge_attributes(G, name).values())
#                 if len(label) > 1:
#                     labels.append(name)
#                     break
#             if len(label) < 2:
#                 for G in self._graphs:
#                     for ed in G.edges():
#                         del G.edges[ed][name]
#         self._edge_attrs = labels
#
#
#     def cut_graphs(self, range_):
#         self._graphs = [self._graphs[i] for i in range_]
#         if self._targets is not None:
#             self._targets = [self._targets[i] for i in range_]
#         self.clean_labels()
#
#
#     def trim_dataset(self, edge_required=False):
#         if edge_required:
#             trimmed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)]
#         else:
#             trimmed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0]
#         idx = [p[0] for p in trimmed_pairs]
#         self._graphs = [p[1] for p in trimmed_pairs]
#         self._targets = [self._targets[i] for i in idx]
#         self.clean_labels()
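#
#     # NOTE (illustrative sketch, not part of the original code): `cut_graphs`
#     # keeps only the graphs at the given indices, while `trim_dataset` drops
#     # empty graphs; assuming the old `Dataset` API, a typical sequence is:
#     #
#     # ds.cut_graphs(range(0, 100))          # keep the first 100 graphs
#     # ds.trim_dataset(edge_required=True)   # drop graphs without nodes or edges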
#
#
#     def copy(self):
#         dataset = Dataset()
#         graphs = [g.copy() for g in self._graphs] if self._graphs is not None else None
#         target = self._targets.copy() if self._targets is not None else None
#         node_labels = self._node_labels.copy() if self._node_labels is not None else None
#         node_attrs = self._node_attrs.copy() if self._node_attrs is not None else None
#         edge_labels = self._edge_labels.copy() if self._edge_labels is not None else None
#         edge_attrs = self._edge_attrs.copy() if self._edge_attrs is not None else None
#         dataset.load_graphs(graphs, target)
#         dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
#         # @todo: clean_labels and add other class members?
#         return dataset
#
#
#     def get_all_node_labels(self):
#         node_labels = []
#         for g in self._graphs:
#             for n in g.nodes():
#                 nl = tuple(g.nodes[n].items())
#                 if nl not in node_labels:
#                     node_labels.append(nl)
#         return node_labels
#
#
#     def get_all_edge_labels(self):
#         edge_labels = []
#         for g in self._graphs:
#             for e in g.edges():
#                 el = tuple(g.edges[e].items())
#                 if el not in edge_labels:
#                     edge_labels.append(el)
#         return edge_labels
#
#
#     def _get_dataset_size(self):
#         return len(self._graphs)
#
#
#     def _get_all_node_nums(self):
#         return [nx.number_of_nodes(G) for G in self._graphs]
#
#
#     def _get_total_node_num(self, all_node_nums):
#         return np.sum(all_node_nums)
#
#
#     def _get_ave_node_num(self, all_node_nums):
#         return np.mean(all_node_nums)
#
#
#     def _get_min_node_num(self, all_node_nums):
#         return np.amin(all_node_nums)
#
#
#     def _get_max_node_num(self, all_node_nums):
#         return np.amax(all_node_nums)
#
#
#     def _get_all_edge_nums(self):
#         return [nx.number_of_edges(G) for G in self._graphs]
#
#
#     def _get_total_edge_num(self, all_edge_nums):
#         return np.sum(all_edge_nums)
#
#
#     def _get_ave_edge_num(self, all_edge_nums):
#         return np.mean(all_edge_nums)
#
#
#     def _get_min_edge_num(self, all_edge_nums):
#         return np.amin(all_edge_nums)
#
#
#     def _get_max_edge_num(self, all_edge_nums):
#         return np.amax(all_edge_nums)
#
#
#     def _get_node_label_dim(self):
#         return len(self._node_labels)
#
#
#     def _get_node_label_num(self, node_label):
#         nl = set()
#         for G in self._graphs:
#             nl = nl | set(nx.get_node_attributes(G, node_label).values())
#         return len(nl)
#
#
#     def _get_edge_label_dim(self):
#         return len(self._edge_labels)
#
#
#     def _get_edge_label_num(self, edge_label):
#         el = set()
#         for G in self._graphs:
#             el = el | set(nx.get_edge_attributes(G, edge_label).values())
#         return len(el)
#
#
#     def _is_directed(self):
#         return nx.is_directed(self._graphs[0])
#
#
#     def _get_all_node_degrees(self):
#         return [np.mean(list(dict(G.degree()).values())) for G in self._graphs]
#
#
#     def _get_ave_node_degree(self, all_node_degrees):
#         return np.mean(all_node_degrees)
#
#
#     def _get_max_node_degree(self, all_node_degrees):
#         return np.amax(all_node_degrees)
#
#
#     def _get_min_node_degree(self, all_node_degrees):
#         return np.amin(all_node_degrees)
#
#
#     def _get_all_fill_factors(self):
#         """Get the fill factor of each graph, i.e. the fraction of non-zero
#         entries in its adjacency matrix: number_of_edges / number_of_nodes ** 2.
#
#         Returns
#         -------
#         list[float]
#             List of fill factors for all graphs.
#         """
#         return [nx.number_of_edges(G) / (nx.number_of_nodes(G) ** 2) for G in self._graphs]
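#
#     # NOTE (illustrative worked example, not part of the original code): for an
#     # undirected path graph on 3 nodes (2 edges), the fill factor as defined
#     # above is 2 / 3**2 ~= 0.222. For undirected graphs each edge occupies two
#     # symmetric entries of the adjacency matrix, which this definition does not
#     # double-count.
#     #
#     # import networkx as nx
#     # G = nx.path_graph(3)
#     # fill_factor = nx.number_of_edges(G) / (nx.number_of_nodes(G) ** 2)
#     # print(fill_factor)  # 0.2222...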
#
#
#     def _get_ave_fill_factor(self, all_fill_factors):
#         return np.mean(all_fill_factors)
#
#
#     def _get_max_fill_factor(self, all_fill_factors):
#         return np.amax(all_fill_factors)
#
#
#     def _get_min_fill_factor(self, all_fill_factors):
#         return np.amin(all_fill_factors)
#
#
#     def _get_substructures(self):
#         subs = set()
#         for G in self._graphs:
#             degrees = list(dict(G.degree()).values())
#             if any(i == 2 for i in degrees):
#                 subs.add('linear')
#             if np.amax(degrees) >= 3:
#                 subs.add('non linear')
#             if 'linear' in subs and 'non linear' in subs:
#                 break
#         if self._directed:
#             for G in self._graphs:
#                 # nx.find_cycle raises NetworkXNoCycle when no cycle exists,
#                 # so guard the call instead of testing its length directly.
#                 try:
#                     nx.find_cycle(G)
#                     subs.add('cyclic')
#                     break
#                 except nx.NetworkXNoCycle:
#                     pass
#         # else:
#         #     # @todo: this method does not work for big graph with large amount of edges like D&D, try a better way.
#         #     upper = np.amin([nx.number_of_edges(G) for G in Gn]) * 2 + 10
#         #     for G in Gn:
#         #         if (nx.number_of_edges(G) < upper):
#         #             cyc = list(nx.simple_cycles(G.to_directed()))
#         #             if any(len(i) > 2 for i in cyc):
#         #                 subs.add('cyclic')
#         #                 break
#         #     if 'cyclic' not in subs:
#         #         for G in Gn:
#         #             cyc = list(nx.simple_cycles(G.to_directed()))
#         #             if any(len(i) > 2 for i in cyc):
#         #                 subs.add('cyclic')
#         #                 break
#
#         return subs
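#
#     # NOTE (illustrative sketch, not part of the original code): the
#     # classification above marks the dataset as 'linear' if some graph has a
#     # degree-2 node, 'non linear' if some graph has a node of degree >= 3, and
#     # 'cyclic' if a directed graph contains a cycle. For example:
#     #
#     # import networkx as nx
#     # path = nx.path_graph(4)    # degrees [1, 2, 2, 1]   -> 'linear'
#     # star = nx.star_graph(3)    # degrees [3, 1, 1, 1]   -> 'non linear'
#     # loop = nx.cycle_graph(3).to_directed()              # -> 'cyclic'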
#
#
#     def _get_class_number(self):
#         return len(set(self._targets))
#
#
#     def _get_node_attr_dim(self):
#         return len(self._node_attrs)
#
#
#     def _get_edge_attr_dim(self):
#         return len(self._edge_attrs)
#
#
#     def _compute_all_degree_entropy(self, base=None):
#         """Compute the entropy of the degree distribution of each graph.
#
#         Parameters
#         ----------
#         base : float, optional
#             The logarithmic base to use. The default is ``e`` (natural logarithm).
#
#         Returns
#         -------
#         degree_entropy : list[float]
#             The calculated entropy for each graph.
#         """
#         from gklearn.utils.stats import entropy
#
#         degree_entropy = []
#         for g in self._graphs:
#             degrees = list(dict(g.degree()).values())
#             en = entropy(degrees, base=base)
#             degree_entropy.append(en)
#         return degree_entropy
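#
#     # NOTE (illustrative sketch, not part of the original code): assuming
#     # `gklearn.utils.stats.entropy` computes the Shannon entropy of the value
#     # counts of its input, an equivalent stand-alone computation for a single
#     # graph would look roughly like this (base=None meaning natural logarithm):
#     #
#     # import numpy as np
#     # def degree_entropy(G, base=None):
#     #     degrees = [d for _, d in G.degree()]
#     #     _, counts = np.unique(degrees, return_counts=True)
#     #     p = counts / counts.sum()
#     #     h = -np.sum(p * np.log(p))
#     #     return h if base is None else h / np.log(base)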
#
#
#     @property
#     def graphs(self):
#         return self._graphs
#
#
#     @property
#     def targets(self):
#         return self._targets
#
#
#     @property
#     def node_labels(self):
#         return self._node_labels
#
#
#     @property
#     def edge_labels(self):
#         return self._edge_labels
#
#
#     @property
#     def node_attrs(self):
#         return self._node_attrs
#
#
#     @property
#     def edge_attrs(self):
#         return self._edge_attrs
#
#
# def split_dataset_by_target(dataset):
#     from gklearn.preimage.utils import get_same_item_indices
#
#     graphs = dataset.graphs
#     targets = dataset.targets
#     datasets = []
#     idx_targets = get_same_item_indices(targets)
#     for key, val in idx_targets.items():
#         sub_graphs = [graphs[i] for i in val]
#         sub_dataset = Dataset()
#         sub_dataset.load_graphs(sub_graphs, [key] * len(val))
#         node_labels = dataset.node_labels.copy() if dataset.node_labels is not None else None
#         node_attrs = dataset.node_attrs.copy() if dataset.node_attrs is not None else None
#         edge_labels = dataset.edge_labels.copy() if dataset.edge_labels is not None else None
#         edge_attrs = dataset.edge_attrs.copy() if dataset.edge_attrs is not None else None
#         sub_dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
#         datasets.append(sub_dataset)
#         # @todo: clean_labels?
#     return datasets
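#
#
# # NOTE (illustrative sketch, not part of the original code): assuming the old
# # `Dataset` API above, `split_dataset_by_target` partitions a classification
# # dataset into one sub-dataset per class label:
# #
# # sub_datasets = split_dataset_by_target(ds)
# # for sub in sub_datasets:
# #     print(len(sub.graphs), set(sub.targets))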
