|
|
@@ -15,7 +15,7 @@ import os.path as osp |
|
|
|
import urllib |
|
|
|
import tarfile |
|
|
|
from zipfile import ZipFile |
|
|
|
from gklearn.utils.graphfiles import loadDataset |
|
|
|
# from gklearn.utils.graphfiles import loadDataset |
|
|
|
import torch.nn.functional as F |
|
|
|
import networkx as nx |
|
|
|
import torch |
|
|
@@ -152,21 +152,27 @@ class DataFetcher(): |
|
|
|
with tarfile.open(filename_archive, 'r:gz') as tar: |
|
|
|
if self._reload and self._verbose: |
|
|
|
print(filename + ' Downloaded.') |
|
|
|
subpath = os.path.join(path, tar.getnames()[0]) |
|
|
|
if not osp.exists(subpath) or self._reload: |
|
|
|
tar.extractall(path = path) |
|
|
|
return os.path.join(path, tar.getnames()[0]) |
|
|
|
return subpath |
|
|
|
elif filename.endswith('.tar'): |
|
|
|
if tarfile.is_tarfile(filename_archive): |
|
|
|
with tarfile.open(filename_archive, 'r:') as tar: |
|
|
|
if self._reload and self._verbose: |
|
|
|
print(filename + ' Downloaded.') |
|
|
|
subpath = os.path.join(path, tar.getnames()[0]) |
|
|
|
if not osp.exists(subpath) or self._reload: |
|
|
|
tar.extractall(path = path) |
|
|
|
return os.path.join(path, tar.getnames()[0]) |
|
|
|
return subpath |
|
|
|
elif filename.endswith('.zip'): |
|
|
|
with ZipFile(filename_archive, 'r') as zip_ref: |
|
|
|
if self._reload and self._verbose: |
|
|
|
print(filename + ' Downloaded.') |
|
|
|
subpath = os.path.join(path, zip_ref.namelist()[0]) |
|
|
|
if not osp.exists(subpath) or self._reload: |
|
|
|
zip_ref.extractall(path) |
|
|
|
return os.path.join(path, zip_ref.namelist()[0]) |
|
|
|
return subpath |
|
|
|
else: |
|
|
|
raise ValueError(filename + ' Unsupported file.') |
|
|
|
|
|
|
@@ -261,6 +267,11 @@ class DataFetcher(): |
|
|
|
else: |
|
|
|
geometry = geo_txt |
|
|
|
|
|
|
|
# url. |
|
|
|
url = td_node[11].xpath('a')[0].attrib['href'].strip() |
|
|
|
pos_zip = url.rfind('.zip') |
|
|
|
url = url[:pos_zip + 4] |
|
|
|
|
|
|
|
infos[td_node[0].xpath('strong')[0].text.strip()] = { |
|
|
|
'database': 'tudataset', |
|
|
|
'reference': td_node[1].text.strip(), |
|
|
@@ -274,7 +285,7 @@ class DataFetcher(): |
|
|
|
'node_attr_dim': node_attr_dim, |
|
|
|
'geometry': geometry, |
|
|
|
'edge_attr_dim': edge_attr_dim, |
|
|
|
'url': td_node[11].xpath('a')[0].attrib['href'].strip(), |
|
|
|
'url': url, |
|
|
|
'domain': domain |
|
|
|
} |
|
|
|
|
|
|
|