Browse Source

Fix bugs in the dataset module.

v0.2.x
jajupmochi 4 years ago
parent
commit
703688228c
1 changed file with 16 additions and 5 deletions
  1. +16
    -5
      gklearn/dataset/data_fetcher.py

+ 16
- 5
gklearn/dataset/data_fetcher.py View File

@@ -15,7 +15,7 @@ import os.path as osp
import urllib
import tarfile
from zipfile import ZipFile
from gklearn.utils.graphfiles import loadDataset
# from gklearn.utils.graphfiles import loadDataset
import torch.nn.functional as F
import networkx as nx
import torch
@@ -152,21 +152,27 @@ class DataFetcher():
with tarfile.open(filename_archive, 'r:gz') as tar:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
subpath = os.path.join(path, tar.getnames()[0])
if not osp.exists(subpath) or self._reload:
tar.extractall(path = path)
return os.path.join(path, tar.getnames()[0])
return subpath
elif filename.endswith('.tar'):
if tarfile.is_tarfile(filename_archive):
with tarfile.open(filename_archive, 'r:') as tar:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
subpath = os.path.join(path, tar.getnames()[0])
if not osp.exists(subpath) or self._reload:
tar.extractall(path = path)
return os.path.join(path, tar.getnames()[0])
return subpath
elif filename.endswith('.zip'):
with ZipFile(filename_archive, 'r') as zip_ref:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
subpath = os.path.join(path, zip_ref.namelist()[0])
if not osp.exists(subpath) or self._reload:
zip_ref.extractall(path)
return os.path.join(path, zip_ref.namelist()[0])
return subpath
else:
raise ValueError(filename + ' Unsupported file.')
@@ -261,6 +267,11 @@ class DataFetcher():
else:
geometry = geo_txt
# url.
url = td_node[11].xpath('a')[0].attrib['href'].strip()
pos_zip = url.rfind('.zip')
url = url[:pos_zip + 4]
infos[td_node[0].xpath('strong')[0].text.strip()] = {
'database': 'tudataset',
'reference': td_node[1].text.strip(),
@@ -274,7 +285,7 @@ class DataFetcher():
'node_attr_dim': node_attr_dim,
'geometry': geometry,
'edge_attr_dim': edge_attr_dim,
'url': td_node[11].xpath('a')[0].attrib['href'].strip(),
'url': url,
'domain': domain
}



Loading…
Cancel
Save