@@ -7,15 +7,9 @@ Created on Tue Oct 20 14:25:49 2020
Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr
Luc Brun luc.brun@ensicaen.fr
Sebastien Bougleux sebastien.bougleux@unicaen.fr
Benoit Gaüzère benoit.gauzere@insa-rouen.fr
Linlin Jia linlin.jia@insa-rouen.fr
"""
import numpy as np
import networkx as nx
from gklearn.utils.graph_files import load_dataset as loadDataset  # aliased: the legacy helpers below still call loadDataset
import os
import os.path as osp
import urllib
@@ -29,299 +23,154 @@ import random
import sys
from lxml import etree
import re
from tqdm import tqdm
from gklearn.dataset import DATABASES, DATASET_META
class DataFetcher():
def __init__(self,name='Ace',root = 'data',downloadAll = False,reload = False,mode = 'Networkx', option = None): # option : number, gender, letter
self.name = name
self.dir_name = "_".join(name.split("-"))
self.root = root
self.option = option
self.mode = mode
if not osp.exists(self.root) :
os.makedirs(self.root)
self.url = "https://brunl01.users.greyc.fr/CHEMISTRY/"
self.urliam = "https://iapr-tc15.greyc.fr/IAM/"
self.downloadAll = downloadAll
self.reload = reload
self.list_database = {
# "Ace" : (self.url,"ACEDataset.tar"),
# "Acyclic" : (self.url,"Acyclic.tar.gz"),
# "Aids" : (self.urliam,"AIDS.zip"),
# "Alkane" : (self.url,"alkane_dataset.tar.gz"),
# "Chiral" : (self.url,"DatasetAcyclicChiral.tar"),
# "Coil_Del" : (self.urliam,"COIL-DEL.zip"),
# "Coil_Rag" : (self.urliam,"COIL-RAG.zip"),
# "Fingerprint" : (self.urliam,"Fingerprint.zip"),
# "Grec" : (self.urliam,"GREC.zip"),
# "Letter" : (self.urliam,"Letter.zip"),
# "Mao" : (self.url,"mao.tgz"),
# "Monoterpenoides" : (self.url,"monoterpenoides.tar.gz"),
# "Mutagenicity" : (self.urliam,"Mutagenicity.zip"),
# "Pah" : (self.url,"PAH.tar.gz"),
# "Protein" : (self.urliam,"Protein.zip"),
# "Ptc" : (self.url,"ptc.tgz"),
# "Steroid" : (self.url,"SteroidDataset.tar"),
# "Vitamin" : (self.url,"DatasetVitamin.tar"),
# "Web" : (self.urliam,"Web.zip")
}
self.data_to_use_in_datasets = {
# "Acyclic" : ("Acyclic/dataset_bps.ds"),
# "Aids" : ("AIDS_A.txt"),
# "Alkane" : ("Alkane/dataset.ds","Alkane/dataset_boiling_point_names.txt"),
# "Mao" : ("MAO/dataset.ds"),
# "Monoterpenoides" : ("monoterpenoides/dataset_10+.ds"), #('monoterpenoides/dataset.ds'),('monoterpenoides/dataset_9.ds'),('monoterpenoides/trainset_9.ds')
}
def __init__(self, name=None, root='datasets', reload=False, verbose=False):
self._name = name
self._root = root
if not osp.exists(self._root):
os.makedirs(self._root)
self._reload = reload
self._verbose = verbose
# self.has_train_valid_test = {
# "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'),
# "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'),
# "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'),
# # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'),
# "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'),
# 'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'),
# 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl')
# },
# "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'),
# # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'],
# "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'),
# # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl')
# }
# if not self.name :
# raise ValueError("No dataset entered" )
# if self.name not in self.list_database:
# message = "Invalid Dataset name " + self.name
# message += '\n Available datasets are as follows : \n\n'
#
# message += '\n'.join(database for database in self.list_database)
# raise ValueError(message)
# if self.downloadAll :
# print('Waiting...')
# for database in self.list_database :
# self.write_archive_file(database)
# print('Finished')
# else:
# self.write_archive_file(self.name)
# self.max_for_letter = 0
# self.dataset = self.open_files()
self.info_dataset = {
# 'Ace' : "This dataset is not available yet",
# 'Acyclic' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic,root = ...') \nGs,y = dataloader.dataset ",
# 'Aids' : "This dataset is not available yet",
# 'Alkane' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic',root = ...) \nGs,y = dataloader.dataset ",
# 'Chiral' : "This dataset is not available yet",
# "Coil-Del" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Deg', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
# "Coil-Rag" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Rag', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid",
# "Fingerprint" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Fingerprint', root = ...). \ntest,train,valid = dataloader.dataset. \nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid",
# "Grec" : "This dataset has test,train,valid datasets. Write dataloader = DataLoader('Grec', root = ...). \ntest,train,valid = dataloader.dataset. \nGs_test,y_test = test\n Gs_train,y_train = train\n Gs_valid,y_valid = valid",
# "Letter" : "This dataset has test,train,valid datasets. Choose between high,low,med dataset. \ndataloader = DataLoader('Letter', root = ..., option = 'high') \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
# 'Mao' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Mao',root= ...) \nGs,y = dataloader.dataset ",
# 'Monoterpenoides': "This dataset isn't composed of valid, test, train dataset but one whole dataset\n Write dataloader = DataLoader('Monoterpenoides',root= ...) \nGs,y = dataloader.dataset ",
# 'Mutagenicity' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Mutagenicity', root = ...) \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test\n Gs_train,y_train = train \nGs_valid,y_valid = valid",
# 'Pah' : 'This dataset is composed of test and train datasets. '+ str(self.max_for_letter + 1) + ' datasets are available. \nChoose number between 0 and ' + str(self.max_for_letter) + "\ndataloader = DataLoader('Pah', root = ...,option = 0) \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n ",
# "Protein" : "This dataset has test,train,valid dataset. \ndataloader = DataLoader('Protein', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
# "Ptc" : "This dataset has test and train datasets. Select gender between mm, fm, mr, fr. \ndataloader = DataLoader('Ptc',root = ...,option = 'mm') \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train_,y_train = train",
# "Steroid" : "This dataset is not available yet",
# 'Vitamin' : "This dataset is not available yet",
# 'Web' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Web', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid",
}
if mode == "Pytorch":
if self.name in self.data_to_use_in_datasets :
Gs,y = self.dataset
inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y)
#print(inputs,adjs)
self.pytorch_dataset = inputs,adjs,y
elif self.name == "Pah":
self.pytorch_dataset = []
test,train = self.dataset
Gs_test,y_test = test
Gs_train,y_train = train
self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
elif self.name in self.has_train_valid_test:
self.pytorch_dataset = []
#[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
test,train,valid = self.dataset
Gs_test,y_test = test
Gs_train,y_train = train
Gs_valid,y_valid = valid
self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid))
#############
"""
for G in Gs :
for e in G.edges():
print(G[e[0]])
"""
##############
if self._name is None:
if self._verbose:
print('No dataset name entered. All possible datasets will be loaded.')
self._name, self._path = [], []
for idx, ds_name in enumerate(DATASET_META):
if self._verbose:
print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ')
self._name.append(ds_name)
success = self.write_archive_file(ds_name)
if success:
self._path.append(self.open_files(ds_name))
else:
self._path.append(None)
if self._verbose and self._path[-1] is not None and not self._reload:
print('Fetched.')
if self._verbose:
print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.')
elif self._name not in DATASET_META:
message = 'Invalid Dataset name "' + self._name + '".'
message += '\nAvailable datasets are as follows: \n\n'
message += '\n'.join(ds for ds in sorted(DATASET_META))
raise ValueError(message)
else:
self.write_archive_file(self._name)
self._path = self.open_files(self._name)
# self.max_for_letter = 0
# if mode == 'Pytorch':
# if self._name in self.data_to_use_in_datasets :
# Gs,y = self.dataset
# inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y)
# #print(inputs,adjs)
# self.pytorch_dataset = inputs,adjs,y
# elif self._name == "Pah":
# self.pytorch_dataset = []
# test,train = self.dataset
# Gs_test,y_test = test
# Gs_train,y_train = train
# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
# elif self._name in self.has_train_valid_test:
# self.pytorch_dataset = []
# #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])
# test,train,valid = self.dataset
# Gs_test,y_test = test
#
# Gs_train,y_train = train
# Gs_valid,y_valid = valid
# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test))
# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train))
# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid))
# #############
# """
# for G in Gs :
# for e in G.edges():
# print(G[e[0]])
# """
# ##############
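# A minimal usage sketch of the constructor above (hedged: it assumes 'Acyclic' is one of
# the keys of gklearn.dataset.DATASET_META and that DataFetcher is importable from
# gklearn.dataset; adapt the names to your installation):
#
# >>> from gklearn.dataset import DataFetcher
# >>> fetcher = DataFetcher(name='Acyclic', root='datasets', reload=False, verbose=True)
# >>> fetcher.path          # local path of the extracted dataset files
# >>> DataFetcher(root='datasets', verbose=True)   # name=None fetches every dataset in DATASET_META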
def download_file(self, url):
try:
response = urllib.request.urlopen(url)
except urllib.error.HTTPError:
print('"', url.split('/')[-1], '" is not available or incorrect http link.')
return
except urllib.error.URLError:
print('Network is unreachable.')
return
return response
def write_archive_file(self, ds_name):
path = osp.join(self._root, ds_name)
url = DATASET_META[ds_name]['url']
# filename_dir = osp.join(path,filename)
if not osp.exists(path) or self._reload:
response = self.download_file(url)
if response is None:
return False
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, url.split('/')[-1]), 'wb') as outfile:
outfile.write(response.read())
return True
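# Note: write_archive_file() only relies on the 'url' field of a DATASET_META entry and
# saves the archive named by the last URL component under <root>/<ds_name>/. A hypothetical
# entry of that shape (the real metadata may hold more fields) could look like:
#
# DATASET_META['Acyclic'] = {'url': 'https://brunl01.users.greyc.fr/CHEMISTRY/Acyclic.tar.gz', ...}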
def open_files(self, ds_name=None):
if ds_name is None:
ds_name = (self._name if isinstance(self._name, str) else self._name[0])
filename = DATASET_META[ds_name]['url'].split('/')[-1]
path = osp.join(self._root, ds_name)
filename_archive = osp.join(path, filename)
if filename.endswith('gz'):
if tarfile.is_tarfile(filename_archive):
with tarfile.open(filename_archive, 'r:gz') as tar:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
tar.extractall(path = path)
return os.path.join(path, tar.getnames()[0])
elif filename.endswith('.tar'):
if tarfile.is_tarfile(filename_archive):
with tarfile.open(filename_archive, 'r:') as tar:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
tar.extractall(path = path)
return os.path.join(path, tar.getnames()[0])
elif filename.endswith('.zip'):
with ZipFile(filename_archive, 'r') as zip_ref:
if self._reload and self._verbose:
print(filename + ' Downloaded.')
zip_ref.extractall(path)
return os.path.join(path, zip_ref.namelist()[0])
else:
raise ValueError(filename + ' Unsupported file.')
def get_all_ds_infos(self, database):
"""Get information of all datasets from a database.
@@ -342,6 +191,7 @@ class DataFetcher():
msg = 'Invalid Database name "' + database + '"'
msg += '\n Available databases are as follows: \n\n'
msg += '\n'.join(db for db in sorted(DATABASES))
msg += '\nCheck "gklearn.dataset.DATASET_META" for more details.'
raise ValueError(msg)
return infos
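# Hedged usage sketch for get_all_ds_infos(): `database` must be one of the keys of
# gklearn.dataset.DATABASES (the exact key names depend on the installed gklearn version):
#
# >>> infos = fetcher.get_all_ds_infos(sorted(DATABASES)[0])
# >>> infos   # information of every dataset hosted by that database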
@@ -457,6 +307,146 @@ class DataFetcher():
p_str += '}'
return p_str
@property
def path(self):
return self._path
def dataset(self):
if self.mode == "Tensorflow":
return #something
if self.mode == "Pytorch":
return self.pytorch_dataset
return self.dataset
def info(self):
print(self.info_dataset[self._name])
def iter_load_dataset(self,data):
results = []
for datasets in data :
results.append(loadDataset(osp.join(self._root,self._name,datasets)))
return results
def load_dataset(self,list_files):
if self._name == "Ptc":
if type(self.option) != str or self.option.upper() not in ['FR','FM','MM','MR']:
raise ValueError('option for Ptc dataset needs to be one of : \n fr fm mm mr')
results = []
results.append(loadDataset(osp.join(self._root,self._name,'PTC/Test',self.gender + '.ds')))
results.append(loadDataset(osp.join(self._root,self._name,'PTC/Train',self.gender + '.ds')))
return results
if self.name == "Pah":
maximum_sets = 0
for file in list_files:
if file.endswith('ds'):
maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0]))
self.max_for_letter = maximum_sets
if not type(self.option) == int or self.option > maximum_sets or self.option < 0:
raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets))
data = self.has_train_valid_test["Pah"]
data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds'
data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds'
return self.iter_load_dataset(data)
if self.name == "Letter":
if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]:
data = self.has_train_valid_test["Letter"][self.option.upper()]
else:
message = "The parameter for letter is incorrect choose between : "
message += "\nhigh med low"
raise ValueError(message)
return self.iter_load_dataset(data)
if self._name in self.has_train_valid_test : #common IAM dataset with train, valid and test
data = self.has_train_valid_test[self._name]
return self.iter_load_dataset(data)
else: #common dataset without train,valid and test, only dataset.ds file
data = self.data_to_use_in_datasets[self._name]
if len(data) > 1 and data[0] in list_files and data[1] in list_files: #case for Alkane
return loadDataset(osp.join(self._root,self._name,data[0]),filename_y = osp.join(self._root,self._name,data[1]))
if data in list_files:
return loadDataset(osp.join(self._root,self._name,data))
def build_dictionary(self,Gs):
labels = set()
#next line : from DeepGraphWithNNTorch
#bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs])))
sizes = set()
for G in Gs :
for _,node in G.nodes(data = True): # or for node in nx.nodes(G)
#print(_,node)
labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0]) #what do we use for IAM datasets (they don't have bond_type or event label) ?
sizes.add(G.order())
label_dict = {}
#print("labels : ", labels, bond_type_number_maxi)
for i,label in enumerate(labels):
label_dict[label] = [0.]*len(labels)
label_dict[label][i] = 1.
return label_dict
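# Illustration (grounded in the MAO example below): if the node labels found in Gs are
# {'C', 'N', 'O'}, build_dictionary() returns a one-hot map such as
# {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
# (which index each label receives depends on set iteration order, so it is not fixed).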
def from_networkx_to_pytorch(self,Gs,y):
# example for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]}
# code from https://github.com/bgauzere/pygnn/blob/master/utils.py
atom_to_onehot = self.build_dictionary(Gs)
max_size = 30
adjs = []
inputs = []
for i, G in enumerate(Gs):
I = torch.eye(G.order(), G.order())
#A = torch.Tensor(nx.adjacency_matrix(G).todense())
#A = torch.Tensor(nx.to_numpy_matrix(G))
A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int) #what do we use for IAM datasets (they don't have bond_type or event label) ?
adj = F.pad(A, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ? if yes : F.pad(A + I,pad = (...))
adjs.append(adj)
f_0 = []
for _, label in G.nodes(data=True):
#print(_,label)
cur_label = atom_to_onehot[label['label'][0]].copy()
f_0.append(cur_label)
X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order()))
inputs.append(X)
return inputs,adjs,y
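# Sketch of how the conversion above could be used (assumes every graph carries a 'label'
# node attribute and a 'bond_type' edge attribute, and has at most max_size=30 nodes):
#
# >>> inputs, adjs, y = fetcher.from_networkx_to_pytorch(Gs, y)
# >>> inputs[0].shape   # torch.Size([30, n_labels]): zero-padded one-hot node features
# >>> adjs[0].shape     # torch.Size([30, 30]): adjacency padded and weighted by 'bond_type'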
def from_pytorch_to_tensorflow(self,batch_size):
seed = random.randrange(sys.maxsize)
random.seed(seed)
tf_inputs = random.sample(self.pytorch_dataset[0],batch_size)
random.seed(seed)
tf_y = random.sample(self.pytorch_dataset[2],batch_size)
return tf_inputs, tf_y
def from_networkx_to_tensor(self,G,dict):
A=nx.to_numpy_matrix(G)
lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)]
return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab))
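# Illustration (hedged sketch; label_dict as produced by build_dictionary() above, and G
# carrying a 'label' node attribute):
#
# >>> adj_flat, labels = fetcher.from_networkx_to_tensor(G, label_dict)
# >>> adj_flat.shape    # torch.Size([1, n*n]) for a graph with n nodes
# >>> labels.shape      # torch.Size([n, n_labels]): one one-hot row per node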