From 703688228c52c0f7186afd65ccdf39d403c637b4 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Thu, 12 Nov 2020 17:52:35 +0100 Subject: [PATCH] Fix bugs in the dataset module. --- gklearn/dataset/data_fetcher.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/gklearn/dataset/data_fetcher.py b/gklearn/dataset/data_fetcher.py index 1cdacc2..8f3f167 100644 --- a/gklearn/dataset/data_fetcher.py +++ b/gklearn/dataset/data_fetcher.py @@ -15,7 +15,7 @@ import os.path as osp import urllib import tarfile from zipfile import ZipFile -from gklearn.utils.graphfiles import loadDataset +# from gklearn.utils.graphfiles import loadDataset import torch.nn.functional as F import networkx as nx import torch @@ -152,21 +152,27 @@ class DataFetcher(): with tarfile.open(filename_archive, 'r:gz') as tar: if self._reload and self._verbose: print(filename + ' Downloaded.') + subpath = os.path.join(path, tar.getnames()[0]) + if not osp.exists(subpath) or self._reload: tar.extractall(path = path) - return os.path.join(path, tar.getnames()[0]) + return subpath elif filename.endswith('.tar'): if tarfile.is_tarfile(filename_archive): with tarfile.open(filename_archive, 'r:') as tar: if self._reload and self._verbose: print(filename + ' Downloaded.') + subpath = os.path.join(path, tar.getnames()[0]) + if not osp.exists(subpath) or self._reload: tar.extractall(path = path) - return os.path.join(path, tar.getnames()[0]) + return subpath elif filename.endswith('.zip'): with ZipFile(filename_archive, 'r') as zip_ref: if self._reload and self._verbose: print(filename + ' Downloaded.') + subpath = os.path.join(path, zip_ref.namelist()[0]) + if not osp.exists(subpath) or self._reload: zip_ref.extractall(path) - return os.path.join(path, zip_ref.namelist()[0]) + return subpath else: raise ValueError(filename + ' Unsupported file.') @@ -261,6 +267,11 @@ class DataFetcher(): else: geometry = geo_txt + # url. + url = td_node[11].xpath('a')[0].attrib['href'].strip() + pos_zip = url.rfind('.zip') + url = url[:pos_zip + 4] + infos[td_node[0].xpath('strong')[0].text.strip()] = { 'database': 'tudataset', 'reference': td_node[1].text.strip(), @@ -274,7 +285,7 @@ class DataFetcher(): 'node_attr_dim': node_attr_dim, 'geometry': geometry, 'edge_attr_dim': edge_attr_dim, - 'url': td_node[11].xpath('a')[0].attrib['href'].strip(), + 'url': url, 'domain': domain }