
dataset.py 27 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 26 18:48:27 2020

@author: ljia
"""
import numpy as np
import networkx as nx
import os

from gklearn.dataset import DATASET_META, DataFetcher, DataLoader

class Dataset(object):

    def __init__(self, inputs=None, root='datasets', filename_targets=None, targets=None, mode='networkx', remove_null_graphs=True, clean_labels=True, reload=False, verbose=False, **kwargs):
        self._substructures = None
        self._node_label_dim = None
        self._edge_label_dim = None
        self._directed = None
        self._dataset_size = None
        self._total_node_num = None
        self._ave_node_num = None
        self._min_node_num = None
        self._max_node_num = None
        self._total_edge_num = None
        self._ave_edge_num = None
        self._min_edge_num = None
        self._max_edge_num = None
        self._ave_node_degree = None
        self._min_node_degree = None
        self._max_node_degree = None
        self._ave_fill_factor = None
        self._min_fill_factor = None
        self._max_fill_factor = None
        self._node_label_nums = None
        self._edge_label_nums = None
        self._node_attr_dim = None
        self._edge_attr_dim = None
        self._class_number = None
        self._ds_name = None
        self._task_type = None

        if inputs is None:
            self._graphs = None
            self._targets = None
            self._node_labels = None
            self._edge_labels = None
            self._node_attrs = None
            self._edge_attrs = None

        # If inputs is a list of graphs.
        elif isinstance(inputs, list):
            node_labels = kwargs.get('node_labels', None)
            node_attrs = kwargs.get('node_attrs', None)
            edge_labels = kwargs.get('edge_labels', None)
            edge_attrs = kwargs.get('edge_attrs', None)
            self.load_graphs(inputs, targets=targets)
            self.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
            if clean_labels:
                self.clean_labels()

        elif isinstance(inputs, str):
            # If inputs is a predefined dataset name.
            if inputs in DATASET_META:
                self.load_predefined_dataset(inputs, root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
                self._ds_name = inputs

            # If the dataset is specially defined, e.g., Alkane_unlabeled, MAO_lite.
            elif self.is_special_dataset(inputs):
                self.load_special_dataset(inputs, root, clean_labels, reload, verbose)
                self._ds_name = inputs

            # If inputs is a file name.
            elif os.path.isfile(inputs):
                self.load_dataset(inputs, filename_targets=filename_targets, clean_labels=clean_labels, **kwargs)

            # Otherwise, the string cannot be resolved.
            else:
                raise ValueError('The "inputs" argument "' + inputs + '" is not a valid dataset name or file name.')

        else:
            raise TypeError('The "inputs" argument cannot be recognized. "inputs" can be a list of graphs, a predefined dataset name, or a file name of a dataset.')

        # Remove empty graphs once graphs have actually been loaded.
        if remove_null_graphs and self._graphs is not None:
            self.trim_dataset(edge_required=False)

    def load_dataset(self, filename, filename_targets=None, clean_labels=True, **kwargs):
        self._graphs, self._targets, label_names = DataLoader(filename, filename_targets=filename_targets, **kwargs).data
        self._node_labels = label_names['node_labels']
        self._node_attrs = label_names['node_attrs']
        self._edge_labels = label_names['edge_labels']
        self._edge_attrs = label_names['edge_attrs']
        if clean_labels:
            self.clean_labels()

    def load_graphs(self, graphs, targets=None):
        # This has to be followed by set_labels().
        self._graphs = graphs
        self._targets = targets
        # self.set_labels_attrs() # @todo

    def load_predefined_dataset(self, ds_name, root='datasets', clean_labels=True, reload=False, verbose=False):
        path = DataFetcher(name=ds_name, root=root, reload=reload, verbose=verbose).path

        if DATASET_META[ds_name]['database'] == 'tudataset':
            ds_file = os.path.join(path, ds_name + '_A.txt')
            fn_targets = None
        else:
            load_files = DATASET_META[ds_name]['load_files']
            if isinstance(load_files[0], str):
                ds_file = os.path.join(path, load_files[0])
            else:  # load_files[0] is a list of files.
                ds_file = [os.path.join(path, fn) for fn in load_files[0]]
            fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None

        # Get extra_params.
        if 'extra_params' in DATASET_META[ds_name]:
            kwargs = DATASET_META[ds_name]['extra_params']
        else:
            kwargs = {}

        # Get the task type that is associated with the dataset. If it is classification, get the number of classes.
        self._get_task_type(ds_name)

        self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data

        self._node_labels = label_names['node_labels']
        self._node_attrs = label_names['node_attrs']
        self._edge_labels = label_names['edge_labels']
        self._edge_attrs = label_names['edge_attrs']
        if clean_labels:
            self.clean_labels()

        # Deal with specific datasets.
        if ds_name == 'Alkane':
            self.trim_dataset(edge_required=True)
            self.remove_labels(node_labels=['atom_symbol'])

    def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
        self._node_labels = node_labels
        self._node_attrs = node_attrs
        self._edge_labels = edge_labels
        self._edge_attrs = edge_attrs

    def set_labels_attrs(self, node_labels=None, node_attrs=None, edge_labels=None, edge_attrs=None):
        # @todo: remove labels which have only one possible value.
        if node_labels is None:
            self._node_labels = self._graphs[0].graph['node_labels']
            # # Graphs are considered node-unlabeled if all nodes have the same label.
            # infos.update({'node_labeled': is_nl if node_label_num > 1 else False})
        if node_attrs is None:
            self._node_attrs = self._graphs[0].graph['node_attrs']
            # for G in Gn:
            #     for n in G.nodes(data=True):
            #         if 'attributes' in n[1]:
            #             return len(n[1]['attributes'])
            # return 0
        if edge_labels is None:
            self._edge_labels = self._graphs[0].graph['edge_labels']
            # # Graphs are considered edge-unlabeled if all edges have the same label.
            # infos.update({'edge_labeled': is_el if edge_label_num > 1 else False})
        if edge_attrs is None:
            self._edge_attrs = self._graphs[0].graph['edge_attrs']
            # for G in Gn:
            #     if nx.number_of_edges(G) > 0:
            #         for e in G.edges(data=True):
            #             if 'attributes' in e[2]:
            #                 return len(e[2]['attributes'])
            # return 0

    def get_dataset_infos(self, keys=None, params=None):
        """Compute and return the structure and property information of the graph dataset.

        Parameters
        ----------
        keys : list, optional
            A list of strings indicating which pieces of information will be
            returned. The possible choices include:

            'substructures': sub-structures the graphs contain, including
                'linear', 'non linear' and 'cyclic'.
            'node_label_dim': whether vertices have symbolic labels.
            'edge_label_dim': whether edges have symbolic labels.
            'directed': whether graphs in the dataset are directed.
            'dataset_size': number of graphs in the dataset.
            'total_node_num': total number of vertices of all graphs in the dataset.
            'ave_node_num': average number of vertices of graphs in the dataset.
            'min_node_num': minimum number of vertices of graphs in the dataset.
            'max_node_num': maximum number of vertices of graphs in the dataset.
            'total_edge_num': total number of edges of all graphs in the dataset.
            'ave_edge_num': average number of edges of graphs in the dataset.
            'min_edge_num': minimum number of edges of graphs in the dataset.
            'max_edge_num': maximum number of edges of graphs in the dataset.
            'ave_node_degree': average vertex degree of graphs in the dataset.
            'min_node_degree': minimum vertex degree of graphs in the dataset.
            'max_node_degree': maximum vertex degree of graphs in the dataset.
            'ave_fill_factor': average fill factor (number_of_edges /
                (number_of_nodes ** 2)) of graphs in the dataset.
            'min_fill_factor': minimum fill factor of graphs in the dataset.
            'max_fill_factor': maximum fill factor of graphs in the dataset.
            'node_label_nums': list of numbers of symbolic vertex labels of graphs in the dataset.
            'edge_label_nums': list of numbers of symbolic edge labels of graphs in the dataset.
            'node_attr_dim': number of dimensions of non-symbolic vertex labels.
                Extracted from the 'attributes' attribute of graph nodes.
            'edge_attr_dim': number of dimensions of non-symbolic edge labels.
                Extracted from the 'attributes' attribute of graph edges.
            'class_number': number of classes. Only available for classification problems.
            'all_degree_entropy': the entropy of the degree distribution of each graph.
            'ave_degree_entropy': the average entropy of the degree distributions of all graphs.

            All of the information above is returned if `keys` is not given.

        params : dict of dict, optional
            A dictionary which contains extra parameters for each possible
            element in ``keys``.

        Returns
        -------
        dict
            Information of the graph dataset keyed by `keys`.
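
        Examples
        --------
        A minimal, illustrative call (not taken from the library's own
        documentation); ``ds`` is assumed to be an already-loaded ``Dataset``
        instance, and the key and parameter names are those documented above::

            infos = ds.get_dataset_infos(
                keys=['dataset_size', 'ave_node_num', 'all_degree_entropy'],
                params={'all_degree_entropy': {'base': 2}})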
  193. """
        infos = {}

        if keys is None:
            keys = [
                'substructures',
                'node_label_dim',
                'edge_label_dim',
                'directed',
                'dataset_size',
                'total_node_num',
                'ave_node_num',
                'min_node_num',
                'max_node_num',
                'total_edge_num',
                'ave_edge_num',
                'min_edge_num',
                'max_edge_num',
                'ave_node_degree',
                'min_node_degree',
                'max_node_degree',
                'ave_fill_factor',
                'min_fill_factor',
                'max_fill_factor',
                'node_label_nums',
                'edge_label_nums',
                'node_attr_dim',
                'edge_attr_dim',
                'class_number',
                'all_degree_entropy',
                'ave_degree_entropy',
                'class_type'
            ]

        # dataset size
        if 'dataset_size' in keys:
            if self._dataset_size is None:
                self._dataset_size = self._get_dataset_size()
            infos['dataset_size'] = self._dataset_size

        # graph node number
        if any(i in keys for i in ['total_node_num', 'ave_node_num', 'min_node_num', 'max_node_num']):
            all_node_nums = self._get_all_node_nums()

            if 'total_node_num' in keys:
                if self._total_node_num is None:
                    self._total_node_num = self._get_total_node_num(all_node_nums)
                infos['total_node_num'] = self._total_node_num

            if 'ave_node_num' in keys:
                if self._ave_node_num is None:
                    self._ave_node_num = self._get_ave_node_num(all_node_nums)
                infos['ave_node_num'] = self._ave_node_num

            if 'min_node_num' in keys:
                if self._min_node_num is None:
                    self._min_node_num = self._get_min_node_num(all_node_nums)
                infos['min_node_num'] = self._min_node_num

            if 'max_node_num' in keys:
                if self._max_node_num is None:
                    self._max_node_num = self._get_max_node_num(all_node_nums)
                infos['max_node_num'] = self._max_node_num

        # graph edge number
        if any(i in keys for i in ['total_edge_num', 'ave_edge_num', 'min_edge_num', 'max_edge_num']):
            all_edge_nums = self._get_all_edge_nums()

            if 'total_edge_num' in keys:
                if self._total_edge_num is None:
                    self._total_edge_num = self._get_total_edge_num(all_edge_nums)
                infos['total_edge_num'] = self._total_edge_num

            if 'ave_edge_num' in keys:
                if self._ave_edge_num is None:
                    self._ave_edge_num = self._get_ave_edge_num(all_edge_nums)
                infos['ave_edge_num'] = self._ave_edge_num

            if 'max_edge_num' in keys:
                if self._max_edge_num is None:
                    self._max_edge_num = self._get_max_edge_num(all_edge_nums)
                infos['max_edge_num'] = self._max_edge_num

            if 'min_edge_num' in keys:
                if self._min_edge_num is None:
                    self._min_edge_num = self._get_min_edge_num(all_edge_nums)
                infos['min_edge_num'] = self._min_edge_num

        # label number
        if 'node_label_dim' in keys:
            if self._node_label_dim is None:
                self._node_label_dim = self._get_node_label_dim()
            infos['node_label_dim'] = self._node_label_dim

        if 'node_label_nums' in keys:
            if self._node_label_nums is None:
                self._node_label_nums = {}
                for node_label in self._node_labels:
                    self._node_label_nums[node_label] = self._get_node_label_num(node_label)
            infos['node_label_nums'] = self._node_label_nums

        if 'edge_label_dim' in keys:
            if self._edge_label_dim is None:
                self._edge_label_dim = self._get_edge_label_dim()
            infos['edge_label_dim'] = self._edge_label_dim

        if 'edge_label_nums' in keys:
            if self._edge_label_nums is None:
                self._edge_label_nums = {}
                for edge_label in self._edge_labels:
                    self._edge_label_nums[edge_label] = self._get_edge_label_num(edge_label)
            infos['edge_label_nums'] = self._edge_label_nums

        if 'directed' in keys or 'substructures' in keys:
            if self._directed is None:
                self._directed = self._is_directed()
            infos['directed'] = self._directed

        # node degree
        if any(i in keys for i in ['ave_node_degree', 'max_node_degree', 'min_node_degree']):
            all_node_degrees = self._get_all_node_degrees()

            if 'ave_node_degree' in keys:
                if self._ave_node_degree is None:
                    self._ave_node_degree = self._get_ave_node_degree(all_node_degrees)
                infos['ave_node_degree'] = self._ave_node_degree

            if 'max_node_degree' in keys:
                if self._max_node_degree is None:
                    self._max_node_degree = self._get_max_node_degree(all_node_degrees)
                infos['max_node_degree'] = self._max_node_degree

            if 'min_node_degree' in keys:
                if self._min_node_degree is None:
                    self._min_node_degree = self._get_min_node_degree(all_node_degrees)
                infos['min_node_degree'] = self._min_node_degree

        # fill factor
        if any(i in keys for i in ['ave_fill_factor', 'max_fill_factor', 'min_fill_factor']):
            all_fill_factors = self._get_all_fill_factors()

            if 'ave_fill_factor' in keys:
                if self._ave_fill_factor is None:
                    self._ave_fill_factor = self._get_ave_fill_factor(all_fill_factors)
                infos['ave_fill_factor'] = self._ave_fill_factor

            if 'max_fill_factor' in keys:
                if self._max_fill_factor is None:
                    self._max_fill_factor = self._get_max_fill_factor(all_fill_factors)
                infos['max_fill_factor'] = self._max_fill_factor

            if 'min_fill_factor' in keys:
                if self._min_fill_factor is None:
                    self._min_fill_factor = self._get_min_fill_factor(all_fill_factors)
                infos['min_fill_factor'] = self._min_fill_factor

        if 'substructures' in keys:
            if self._substructures is None:
                self._substructures = self._get_substructures()
            infos['substructures'] = self._substructures

        if 'class_number' in keys:
            if self._class_number is None:
                self._class_number = self._get_class_num()
            infos['class_number'] = self._class_number

        if 'node_attr_dim' in keys:
            if self._node_attr_dim is None:
                self._node_attr_dim = self._get_node_attr_dim()
            infos['node_attr_dim'] = self._node_attr_dim

        if 'edge_attr_dim' in keys:
            if self._edge_attr_dim is None:
                self._edge_attr_dim = self._get_edge_attr_dim()
            infos['edge_attr_dim'] = self._edge_attr_dim

        # entropy of degree distribution.
        if 'all_degree_entropy' in keys:
            if params is not None and ('all_degree_entropy' in params) and ('base' in params['all_degree_entropy']):
                base = params['all_degree_entropy']['base']
            else:
                base = None
            infos['all_degree_entropy'] = self._compute_all_degree_entropy(base=base)

        if 'ave_degree_entropy' in keys:
            if params is not None and ('ave_degree_entropy' in params) and ('base' in params['ave_degree_entropy']):
                base = params['ave_degree_entropy']['base']
            else:
                base = None
            infos['ave_degree_entropy'] = np.mean(self._compute_all_degree_entropy(base=base))

        if 'task_type' in keys:
            if self._task_type is None:
                # _get_task_type() sets self._task_type as a side effect; it
                # needs the dataset name to look it up in DATASET_META.
                self._get_task_type(self._ds_name)
            infos['task_type'] = self._task_type

        return infos

    def print_graph_infos(self, infos):
        from collections import OrderedDict
        keys = list(infos.keys())
        print(OrderedDict(sorted(infos.items(), key=lambda i: keys.index(i[0]))))

    def remove_labels(self, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
        node_labels = [item for item in node_labels if item in self._node_labels]
        edge_labels = [item for item in edge_labels if item in self._edge_labels]
        node_attrs = [item for item in node_attrs if item in self._node_attrs]
        edge_attrs = [item for item in edge_attrs if item in self._edge_attrs]

        for g in self._graphs:
            for nd in g.nodes():
                for nl in node_labels:
                    del g.nodes[nd][nl]
                for na in node_attrs:
                    del g.nodes[nd][na]
            for ed in g.edges():
                for el in edge_labels:
                    del g.edges[ed][el]
                for ea in edge_attrs:
                    del g.edges[ed][ea]

        if len(node_labels) > 0:
            self._node_labels = [nl for nl in self._node_labels if nl not in node_labels]
        if len(edge_labels) > 0:
            self._edge_labels = [el for el in self._edge_labels if el not in edge_labels]
        if len(node_attrs) > 0:
            self._node_attrs = [na for na in self._node_attrs if na not in node_attrs]
        if len(edge_attrs) > 0:
            self._edge_attrs = [ea for ea in self._edge_attrs if ea not in edge_attrs]

    def clean_labels(self):
        labels = []
        for name in self._node_labels:
            label = set()
            for G in self._graphs:
                label = label | set(nx.get_node_attributes(G, name).values())
                if len(label) > 1:
                    labels.append(name)
                    break
            if len(label) < 2:
                for G in self._graphs:
                    for nd in G.nodes():
                        del G.nodes[nd][name]
        self._node_labels = labels

        labels = []
        for name in self._edge_labels:
            label = set()
            for G in self._graphs:
                label = label | set(nx.get_edge_attributes(G, name).values())
                if len(label) > 1:
                    labels.append(name)
                    break
            if len(label) < 2:
                for G in self._graphs:
                    for ed in G.edges():
                        del G.edges[ed][name]
        self._edge_labels = labels

        labels = []
        for name in self._node_attrs:
            label = set()
            for G in self._graphs:
                label = label | set(nx.get_node_attributes(G, name).values())
                if len(label) > 1:
                    labels.append(name)
                    break
            if len(label) < 2:
                for G in self._graphs:
                    for nd in G.nodes():
                        del G.nodes[nd][name]
        self._node_attrs = labels

        labels = []
        for name in self._edge_attrs:
            label = set()
            for G in self._graphs:
                label = label | set(nx.get_edge_attributes(G, name).values())
                if len(label) > 1:
                    labels.append(name)
                    break
            if len(label) < 2:
                for G in self._graphs:
                    for ed in G.edges():
                        del G.edges[ed][name]
        self._edge_attrs = labels

    def cut_graphs(self, range_):
        self._graphs = [self._graphs[i] for i in range_]
        if self._targets is not None:
            self._targets = [self._targets[i] for i in range_]
        self.clean_labels()

    def trim_dataset(self, edge_required=False):
        if edge_required:  # @todo: there is a possibility that some node labels will be removed.
            trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if (nx.number_of_nodes(g) != 0 and nx.number_of_edges(g) != 0)]
        else:
            trimed_pairs = [(idx, g) for idx, g in enumerate(self._graphs) if nx.number_of_nodes(g) != 0]
        idx = [p[0] for p in trimed_pairs]
        self._graphs = [p[1] for p in trimed_pairs]
        if self._targets is not None:
            self._targets = [self._targets[i] for i in idx]
        self.clean_labels()

    def copy(self):
        dataset = Dataset()
        graphs = [g.copy() for g in self._graphs] if self._graphs is not None else None
        target = self._targets.copy() if self._targets is not None else None
        node_labels = self._node_labels.copy() if self._node_labels is not None else None
        node_attrs = self._node_attrs.copy() if self._node_attrs is not None else None
        edge_labels = self._edge_labels.copy() if self._edge_labels is not None else None
        edge_attrs = self._edge_attrs.copy() if self._edge_attrs is not None else None
        dataset.load_graphs(graphs, target)
        dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
        # @todo: clean_labels and add other class members?
        return dataset

    def is_special_dataset(self, inputs):
        if inputs.endswith('_unlabeled'):
            return True
        if inputs == 'MAO_lite':
            return True
        if inputs == 'Monoterpens':
            return True
        return False

    def load_special_dataset(self, inputs, root, clean_labels, reload, verbose):
        if inputs.endswith('_unlabeled'):
            self.load_predefined_dataset(inputs[:len(inputs) - 10], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
            self.remove_labels(node_labels=self._node_labels,
                               edge_labels=self._edge_labels,
                               node_attrs=self._node_attrs,
                               edge_attrs=self._edge_attrs)
        elif inputs == 'MAO_lite':
            self.load_predefined_dataset(inputs[:len(inputs) - 5], root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)
            self.remove_labels(edge_labels=['bond_stereo'], node_attrs=['x', 'y'])
        elif inputs == 'Monoterpens':
            self.load_predefined_dataset('Monoterpenoides', root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)

    def get_all_node_labels(self):
        node_labels = []
        for g in self._graphs:
            for n in g.nodes():
                nl = tuple(g.nodes[n].items())
                if nl not in node_labels:
                    node_labels.append(nl)
        return node_labels

    def get_all_edge_labels(self):
        edge_labels = []
        for g in self._graphs:
            for e in g.edges():
                el = tuple(g.edges[e].items())
                if el not in edge_labels:
                    edge_labels.append(el)
        return edge_labels

    def _get_dataset_size(self):
        return len(self._graphs)

    def _get_all_node_nums(self):
        return [nx.number_of_nodes(G) for G in self._graphs]

    def _get_total_node_num(self, all_node_nums):
        return np.sum(all_node_nums)

    def _get_ave_node_num(self, all_node_nums):
        return np.mean(all_node_nums)

    def _get_min_node_num(self, all_node_nums):
        return np.amin(all_node_nums)

    def _get_max_node_num(self, all_node_nums):
        return np.amax(all_node_nums)

    def _get_all_edge_nums(self):
        return [nx.number_of_edges(G) for G in self._graphs]

    def _get_total_edge_num(self, all_edge_nums):
        return np.sum(all_edge_nums)

    def _get_ave_edge_num(self, all_edge_nums):
        return np.mean(all_edge_nums)

    def _get_min_edge_num(self, all_edge_nums):
        return np.amin(all_edge_nums)

    def _get_max_edge_num(self, all_edge_nums):
        return np.amax(all_edge_nums)

    def _get_node_label_dim(self):
        return len(self._node_labels)

    def _get_node_label_num(self, node_label):
        nl = set()
        for G in self._graphs:
            nl = nl | set(nx.get_node_attributes(G, node_label).values())
        return len(nl)

    def _get_edge_label_dim(self):
        return len(self._edge_labels)

    def _get_edge_label_num(self, edge_label):
        el = set()
        for G in self._graphs:
            el = el | set(nx.get_edge_attributes(G, edge_label).values())
        return len(el)

    def _is_directed(self):
        return nx.is_directed(self._graphs[0])

    def _get_all_node_degrees(self):
        return [np.mean(list(dict(G.degree()).values())) for G in self._graphs]

    def _get_ave_node_degree(self, all_node_degrees):
        return np.mean(all_node_degrees)

    def _get_max_node_degree(self, all_node_degrees):
        return np.amax(all_node_degrees)

    def _get_min_node_degree(self, all_node_degrees):
        return np.amin(all_node_degrees)

    def _get_all_fill_factors(self):
        """Get the fill factor of each graph, i.e., the number of edges divided
        by the squared number of nodes (the fraction of non-zero entries in the
        adjacency matrix).

        Returns
        -------
        list[float]
            List of fill factors for all graphs.
        """
        return [nx.number_of_edges(G) / (nx.number_of_nodes(G) ** 2) for G in self._graphs]

    def _get_ave_fill_factor(self, all_fill_factors):
        return np.mean(all_fill_factors)

    def _get_max_fill_factor(self, all_fill_factors):
        return np.amax(all_fill_factors)

    def _get_min_fill_factor(self, all_fill_factors):
        return np.amin(all_fill_factors)

    def _get_substructures(self):
        subs = set()
        for G in self._graphs:
            degrees = list(dict(G.degree()).values())
            if any(i == 2 for i in degrees):
                subs.add('linear')
            if np.amax(degrees) >= 3:
                subs.add('non linear')
            if 'linear' in subs and 'non linear' in subs:
                break

        if self._directed:
            for G in self._graphs:
                if len(list(nx.find_cycle(G))) > 0:
                    subs.add('cyclic')
                    break
        # else:
        #     # @todo: this method does not work for big graphs with a large amount of edges like D&D, try a better way.
        #     upper = np.amin([nx.number_of_edges(G) for G in Gn]) * 2 + 10
        #     for G in Gn:
        #         if (nx.number_of_edges(G) < upper):
        #             cyc = list(nx.simple_cycles(G.to_directed()))
        #             if any(len(i) > 2 for i in cyc):
        #                 subs.add('cyclic')
        #                 break
        #     if 'cyclic' not in subs:
        #         for G in Gn:
        #             cyc = list(nx.simple_cycles(G.to_directed()))
        #             if any(len(i) > 2 for i in cyc):
        #                 subs.add('cyclic')
        #                 break

        return subs

    def _get_class_num(self):
        return len(set(self._targets))

    def _get_node_attr_dim(self):
        return len(self._node_attrs)

    def _get_edge_attr_dim(self):
        return len(self._edge_attrs)

    def _compute_all_degree_entropy(self, base=None):
        """Compute the entropy of the degree distribution of each graph.

        Parameters
        ----------
        base : float, optional
            The logarithmic base to use. The default is ``e`` (natural logarithm).

        Returns
        -------
        degree_entropy : list[float]
            The calculated entropy for each graph.
        """
        from gklearn.utils.stats import entropy

        degree_entropy = []
        for g in self._graphs:
            degrees = list(dict(g.degree()).values())
            en = entropy(degrees, base=base)
            degree_entropy.append(en)
        return degree_entropy

    def _get_task_type(self, ds_name):
        # Guard against names that are not in DATASET_META (e.g., file-based datasets).
        if ds_name in DATASET_META and 'task_type' in DATASET_META[ds_name]:
            self._task_type = DATASET_META[ds_name]['task_type']
            if self._task_type == 'classification' and self._class_number is None and 'class_number' in DATASET_META[ds_name]:
                self._class_number = DATASET_META[ds_name]['class_number']

    @property
    def graphs(self):
        return self._graphs

    @property
    def targets(self):
        return self._targets

    @property
    def node_labels(self):
        return self._node_labels

    @property
    def edge_labels(self):
        return self._edge_labels

    @property
    def node_attrs(self):
        return self._node_attrs

    @property
    def edge_attrs(self):
        return self._edge_attrs


def split_dataset_by_target(dataset):
    from gklearn.preimage.utils import get_same_item_indices

    graphs = dataset.graphs
    targets = dataset.targets
    datasets = []
    idx_targets = get_same_item_indices(targets)
    for key, val in idx_targets.items():
        sub_graphs = [graphs[i] for i in val]
        sub_dataset = Dataset()
        sub_dataset.load_graphs(sub_graphs, [key] * len(val))
        node_labels = dataset.node_labels.copy() if dataset.node_labels is not None else None
        node_attrs = dataset.node_attrs.copy() if dataset.node_attrs is not None else None
        edge_labels = dataset.edge_labels.copy() if dataset.edge_labels is not None else None
        edge_attrs = dataset.edge_attrs.copy() if dataset.edge_attrs is not None else None
        sub_dataset.set_labels(node_labels=node_labels, node_attrs=node_attrs, edge_labels=edge_labels, edge_attrs=edge_attrs)
        datasets.append(sub_dataset)
        # @todo: clean_labels?
    return datasets
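
Below is a minimal usage sketch of the class above; it is not part of dataset.py. The import path and the dataset name 'MUTAG' are assumptions (the exact module exporting Dataset and the names available in DATASET_META depend on the installed gklearn version), so adapt them to your installation.

# usage_sketch.py -- illustrative only; import path and dataset name are assumptions.
from gklearn.utils import Dataset, split_dataset_by_target  # adjust to your install

# Load a predefined dataset by name (fetched into `root` on first use).
ds = Dataset('MUTAG', root='datasets', verbose=True)

# Compute selected statistics; `params` passes extra options per key,
# here the logarithm base for the degree-distribution entropy.
infos = ds.get_dataset_infos(
    keys=['dataset_size', 'ave_node_num', 'ave_edge_num',
          'node_label_dim', 'all_degree_entropy'],
    params={'all_degree_entropy': {'base': 2}})
ds.print_graph_infos(infos)

# Alternatively, wrap an existing list of networkx graphs:
# ds2 = Dataset(graph_list, targets=target_list, node_labels=['atom_symbol'])

# Split the dataset into one sub-dataset per target value (e.g., per class).
sub_datasets = split_dataset_by_target(ds)
print([sub.targets[0] for sub in sub_datasets])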

A Python package for graph kernels, graph edit distances and the graph pre-image problem.