From 35bb1d3f951d2ecd392b7cd3ef5195db4e61f958 Mon Sep 17 00:00:00 2001 From: jajupmochi Date: Thu, 12 Nov 2020 17:26:26 +0100 Subject: [PATCH] Update DataFetcher: do not re-unzip data file if the file exists. --- gklearn/dataset/data_fetcher.py | 552 ++++++++++++++++++++-------------------- 1 file changed, 271 insertions(+), 281 deletions(-) diff --git a/gklearn/dataset/data_fetcher.py b/gklearn/dataset/data_fetcher.py index 5d35388..1cdacc2 100644 --- a/gklearn/dataset/data_fetcher.py +++ b/gklearn/dataset/data_fetcher.py @@ -7,15 +7,9 @@ Created on Tue Oct 20 14:25:49 2020 Paul Zanoncelli, paul.zanoncelli@ecole.ensicaen.fr Luc Brun luc.brun@ensicaen.fr Sebastien Bougleux sebastien.bougleux@unicaen.fr - benoit gaüzère benoit.gauzere@insa-rouen.fr + Benoit Gaüzère benoit.gauzere@insa-rouen.fr Linlin Jia linlin.jia@insa-rouen.fr """ -import numpy as np -import networkx as nx -from gklearn.utils.graph_files import load_dataset -import os - - import os import os.path as osp import urllib @@ -29,299 +23,154 @@ import random import sys from lxml import etree import re - -from gklearn.dataset import DATABASES +from tqdm import tqdm +from gklearn.dataset import DATABASES, DATASET_META class DataFetcher(): - def __init__(self,name='Ace',root = 'data',downloadAll = False,reload = False,mode = 'Networkx', option = None): # option : number, gender, letter - self.name = name - self.dir_name = "_".join(name.split("-")) - self.root = root - self.option = option - self.mode = mode - if not osp.exists(self.root) : - os.makedirs(self.root) - self.url = "https://brunl01.users.greyc.fr/CHEMISTRY/" - self.urliam = "https://iapr-tc15.greyc.fr/IAM/" - self.downloadAll = downloadAll - self.reload = reload - self.list_database = { -# "Ace" : (self.url,"ACEDataset.tar"), -# "Acyclic" : (self.url,"Acyclic.tar.gz"), -# "Aids" : (self.urliam,"AIDS.zip"), -# "Alkane" : (self.url,"alkane_dataset.tar.gz"), -# "Chiral" : (self.url,"DatasetAcyclicChiral.tar"), -# "Coil_Del" : (self.urliam,"COIL-DEL.zip"), -# "Coil_Rag" : (self.urliam,"COIL-RAG.zip"), -# "Fingerprint" : (self.urliam,"Fingerprint.zip"), -# "Grec" : (self.urliam,"GREC.zip"), -# "Letter" : (self.urliam,"Letter.zip"), -# "Mao" : (self.url,"mao.tgz"), -# "Monoterpenoides" : (self.url,"monoterpenoides.tar.gz"), -# "Mutagenicity" : (self.urliam,"Mutagenicity.zip"), -# "Pah" : (self.url,"PAH.tar.gz"), -# "Protein" : (self.urliam,"Protein.zip"), -# "Ptc" : (self.url,"ptc.tgz"), -# "Steroid" : (self.url,"SteroidDataset.tar"), -# "Vitamin" : (self.url,"DatasetVitamin.tar"), -# "Web" : (self.urliam,"Web.zip") - } - - self.data_to_use_in_datasets = { -# "Acyclic" : ("Acyclic/dataset_bps.ds"), -# "Aids" : ("AIDS_A.txt"), -# "Alkane" : ("Alkane/dataset.ds","Alkane/dataset_boiling_point_names.txt"), -# "Mao" : ("MAO/dataset.ds"), -# "Monoterpenoides" : ("monoterpenoides/dataset_10+.ds"), #('monoterpenoides/dataset.ds'),('monoterpenoides/dataset_9.ds'),('monoterpenoides/trainset_9.ds') - - } - self.has_train_valid_test = { - "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'), - "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'), - "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'), -# "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'), - "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'), - 'MED' : 
('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'), - 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl') - }, - "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'), -# "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'], - "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'), -# "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl') - } + def __init__(self, name=None, root='datasets', reload=False, verbose=False): + self._name = name + self._root = root + if not osp.exists(self._root): + os.makedirs(self._root) + self._reload = reload + self._verbose = verbose +# self.has_train_valid_test = { +# "Coil_Del" : ('COIL-DEL/data/test.cxl','COIL-DEL/data/train.cxl','COIL-DEL/data/valid.cxl'), +# "Coil_Rag" : ('COIL-RAG/data/test.cxl','COIL-RAG/data/train.cxl','COIL-RAG/data/valid.cxl'), +# "Fingerprint" : ('Fingerprint/data/test.cxl','Fingerprint/data/train.cxl','Fingerprint/data/valid.cxl'), +# # "Grec" : ('GREC/data/test.cxl','GREC/data/train.cxl','GREC/data/valid.cxl'), +# "Letter" : {'HIGH' : ('Letter/HIGH/test.cxl','Letter/HIGH/train.cxl','Letter/HIGH/validation.cxl'), +# 'MED' : ('Letter/MED/test.cxl','Letter/MED/train.cxl','Letter/MED/validation.cxl'), +# 'LOW' : ('Letter/LOW/test.cxl','Letter/LOW/train.cxl','Letter/LOW/validation.cxl') +# }, +# "Mutagenicity" : ('Mutagenicity/data/test.cxl','Mutagenicity/data/train.cxl','Mutagenicity/data/validation.cxl'), +# # "Pah" : ['PAH/testset_0.ds','PAH/trainset_0.ds'], +# "Protein" : ('Protein/data/test.cxl','Protein/data/train.cxl','Protein/data/valid.cxl'), +# # "Web" : ('Web/data/test.cxl','Web/data/train.cxl','Web/data/valid.cxl') +# } -# if not self.name : -# raise ValueError("No dataset entered" ) -# if self.name not in self.list_database: -# message = "Invalid Dataset name " + self.name -# message += '\n Available datasets are as follows : \n\n' -# -# message += '\n'.join(database for database in self.list_database) -# raise ValueError(message) -# if self.downloadAll : -# print('Waiting...') -# for database in self.list_database : -# self.write_archive_file(database) -# print('Finished') -# else: -# self.write_archive_file(self.name) -# self.max_for_letter = 0 -# self.dataset = self.open_files() - self.info_dataset = { - # 'Ace' : "This dataset is not available yet", - # 'Acyclic' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic,root = ...') \nGs,y = dataloader.dataset ", - # 'Aids' : "This dataset is not available yet", - # 'Alkane' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Acyclic',root = ...) \nGs,y = dataloader.dataset ", - # 'Chiral' : "This dataset is not available yet", - # "Coil-Del" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Deg', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid", - # "Coil-Rag" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Coil-Rag', root = ...). \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid", - # "Fingerprint" : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Fingerprint', root = ...). \ntest,train,valid = dataloader.dataset. 
\nGs_test,y_test = test \nGs_train,y_train = train\n Gs_valid,y_valid = valid", - # "Grec" : "This dataset has test,train,valid datasets. Write dataloader = DataLoader('Grec', root = ...). \ntest,train,valid = dataloader.dataset. \nGs_test,y_test = test\n Gs_train,y_train = train\n Gs_valid,y_valid = valid", - # "Letter" : "This dataset has test,train,valid datasets. Choose between high,low,med dataset. \ndataloader = DataLoader('Letter', root = ..., option = 'high') \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid", - # 'Mao' : "This dataset isn't composed of valid, test, train dataset but one whole dataset \ndataloader = DataLoader('Mao',root= ...) \nGs,y = dataloader.dataset ", - # 'Monoterpenoides': "This dataset isn't composed of valid, test, train dataset but one whole dataset\n Write dataloader = DataLoader('Monoterpenoides',root= ...) \nGs,y = dataloader.dataset ", - # 'Mutagenicity' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Mutagenicity', root = ...) \ntest,train,valid = dataloader.dataset \nGs_test,y_test = test\n Gs_train,y_train = train \nGs_valid,y_valid = valid", - # 'Pah' : 'This dataset is composed of test and train datasets. '+ str(self.max_for_letter + 1) + ' datasets are available. \nChoose number between 0 and ' + str(self.max_for_letter) + "\ndataloader = DataLoader('Pah', root = ...,option = 0) \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train\n ", - # "Protein" : "This dataset has test,train,valid dataset. \ndataloader = DataLoader('Protein', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid", - # "Ptc" : "This dataset has test and train datasets. Select gender between mm, fm, mr, fr. \ndataloader = DataLoader('Ptc',root = ...,option = 'mm') \ntest,train = dataloader.dataset \nGs_test,y_test = test \nGs_train_,y_train = train", - # "Steroid" : "This dataset is not available yet", - # 'Vitamin' : "This dataset is not available yet", - # 'Web' : "This dataset has test,train,valid datasets. \ndataloader = DataLoader('Web', root = ...) \n test,train,valid = dataloader.dataset \nGs_test,y_test = test \nGs_train,y_train = train \nGs_valid,y_valid = valid", - } - - if mode == "Pytorch": - if self.name in self.data_to_use_in_datasets : - Gs,y = self.dataset - inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y) - #print(inputs,adjs) - self.pytorch_dataset = inputs,adjs,y - elif self.name == "Pah": - self.pytorch_dataset = [] - test,train = self.dataset - Gs_test,y_test = test - Gs_train,y_train = train - self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) - self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) - elif self.name in self.has_train_valid_test: - self.pytorch_dataset = [] - #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs]) - test,train,valid = self.dataset - Gs_test,y_test = test - - Gs_train,y_train = train - Gs_valid,y_valid = valid - self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) - self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) - self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid)) - ############# - """ - for G in Gs : - for e in G.edges(): - print(G[e[0]]) - """ - ############## + if self._name is None: + if self._verbose: + print('No dataset name entered. 
All possible datasets will be loaded.') + self._name, self._path = [], [] + for idx, ds_name in enumerate(DATASET_META): + if self._verbose: + print(str(idx + 1), '/', str(len(DATASET_META)), 'Fetching', ds_name, end='... ') + self._name.append(ds_name) + success = self.write_archive_file(ds_name) + if success: + self._path.append(self.open_files(ds_name)) + else: + self._path.append(None) + if self._verbose and self._path[-1] is not None and not self._reload: + print('Fetched.') + + if self._verbose: + print('Finished.', str(sum(v is not None for v in self._path)), 'of', str(len(self._path)), 'datasets are successfully fetched.') + + elif self._name not in DATASET_META: + message = 'Invalid Dataset name "' + self._name + '".' + message += '\nAvailable datasets are as follows: \n\n' + message += '\n'.join(ds for ds in sorted(DATASET_META)) + raise ValueError(message) + else: + self.write_archive_file(self._name) + self._path = self.open_files(self._name) + +# self.max_for_letter = 0 +# if mode == 'Pytorch': +# if self._name in self.data_to_use_in_datasets : +# Gs,y = self.dataset +# inputs,adjs,y = self.from_networkx_to_pytorch(Gs,y) +# #print(inputs,adjs) +# self.pytorch_dataset = inputs,adjs,y +# elif self._name == "Pah": +# self.pytorch_dataset = [] +# test,train = self.dataset +# Gs_test,y_test = test +# Gs_train,y_train = train +# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) +# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) +# elif self._name in self.has_train_valid_test: +# self.pytorch_dataset = [] +# #[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs]) +# test,train,valid = self.dataset +# Gs_test,y_test = test +# +# Gs_train,y_train = train +# Gs_valid,y_valid = valid +# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_test,y_test)) +# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_train,y_train)) +# self.pytorch_dataset.append(self.from_networkx_to_pytorch(Gs_valid,y_valid)) +# ############# +# """ +# for G in Gs : +# for e in G.edges(): +# print(G[e[0]]) +# """ +# ############## - def download_file(self,url,filename): + + def download_file(self, url): try : - response = urllib.request.urlopen(url + filename) + response = urllib.request.urlopen(url) except urllib.error.HTTPError: - print(filename + " not available or incorrect http link") + print('"', url.split('/')[-1], '" is not available or incorrect http link.') + return + except urllib.error.URLError: + print('Network is unreachable.') return return response - def write_archive_file(self,database): - path = osp.join(self.root,database) - url,filename = self.list_database[database] - filename_dir = osp.join(path,filename) - if not osp.exists(filename_dir) or self.reload: - response = self.download_file(url,filename) - if response is None : - return - if not osp.exists(path) : - os.makedirs(path) - with open(filename_dir,'wb') as outfile : - outfile.write(response.read()) - - def dataset(self): - if self.mode == "Tensorflow": - return #something - if self.mode == "Pytorch": - return self.pytorch_dataset - return self.dataset - - def info(self): - print(self.info_dataset[self.name]) - - def iter_load_dataset(self,data): - results = [] - for datasets in data : - results.append(loadDataset(osp.join(self.root,self.name,datasets))) - return results - def load_dataset(self,list_files): - if self.name == "Ptc": - if type(self.option) != str or self.option.upper() not in ['FR','FM','MM','MR']: - raise ValueError('option for Ptc dataset 
needs to be one of : \n fr fm mm mr') - results = [] - results.append(loadDataset(osp.join(self.root,self.name,'PTC/Test',self.gender + '.ds'))) - results.append(loadDataset(osp.join(self.root,self.name,'PTC/Train',self.gender + '.ds'))) - return results - if self.name == "Pah": - maximum_sets = 0 - for file in list_files: - if file.endswith('ds'): - maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0])) - self.max_for_letter = maximum_sets - if not type(self.option) == int or self.option > maximum_sets or self.option < 0: - raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets)) - data = self.has_train_valid_test["Pah"] - data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds' - data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds' - return self.iter_load_dataset(data) - if self.name == "Letter": - if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]: - data = self.has_train_valid_test["Letter"][self.option.upper()] - else: - message = "The parameter for letter is incorrect choose between : " - message += "\nhigh med low" - raise ValueError(message) - return self.iter_load_dataset(data) - if self.name in self.has_train_valid_test : #common IAM dataset with train, valid and test - data = self.has_train_valid_test[self.name] - return self.iter_load_dataset(data) - else: #common dataset without train,valid and test, only dataset.ds file - data = self.data_to_use_in_datasets[self.name] - if len(data) > 1 and data[0] in list_files and data[1] in list_files: #case for Alkane - return loadDataset(osp.join(self.root,self.name,data[0]),filename_y = osp.join(self.root,self.name,data[1])) - if data in list_files: - return loadDataset(osp.join(self.root,self.name,data)) - - def open_files(self): - filename = self.list_database[self.name][1] - path = osp.join(self.root,self.name) - filename_archive = osp.join(path,filename) + def write_archive_file(self, ds_name): + path = osp.join(self._root, ds_name) + url = DATASET_META[ds_name]['url'] +# filename_dir = osp.join(path,filename) + if not osp.exists(path) or self._reload: + response = self.download_file(url) + if response is None: + return False + os.makedirs(path, exist_ok=True) + with open(os.path.join(path, url.split('/')[-1]), 'wb') as outfile: + outfile.write(response.read()) + + return True + + + def open_files(self, ds_name=None): + if ds_name is None: + ds_name = (self._name if isinstance(self._name, str) else self._name[0]) + filename = DATASET_META[ds_name]['url'].split('/')[-1] + path = osp.join(self._root, ds_name) + filename_archive = osp.join(path, filename) if filename.endswith('gz'): if tarfile.is_tarfile(filename_archive): - with tarfile.open(filename_archive,"r:gz") as tar: - if self.reload: - print(filename + " Downloaded") - tar.extractall(path = path) - return self.load_dataset(tar.getnames()) + with tarfile.open(filename_archive, 'r:gz') as tar: + if self._reload and self._verbose: + print(filename + ' Downloaded.') + tar.extractall(path = path) + return os.path.join(path, tar.getnames()[0]) elif filename.endswith('.tar'): if tarfile.is_tarfile(filename_archive): - with tarfile.open(filename_archive,"r:") as tar: - if self.reload : - print(filename + " Downloaded") - tar.extractall(path = path) - return self.load_dataset(tar.getnames()) + with tarfile.open(filename_archive, 'r:') as tar: + if self._reload and self._verbose: + print(filename + ' Downloaded.') + 
tar.extractall(path = path) + return os.path.join(path, tar.getnames()[0]) elif filename.endswith('.zip'): - with ZipFile(filename_archive,"r") as zip_ref: - if self.reload : - print(filename + " Downloaded") - zip_ref.extractall(path) - return self.load_dataset(zip_ref.namelist()) + with ZipFile(filename_archive, 'r') as zip_ref: + if self._reload and self._verbose: + print(filename + ' Downloaded.') + zip_ref.extractall(path) + return os.path.join(path, zip_ref.namelist()[0]) else: - print(filename + " Unsupported file") - - - def build_dictionary(self,Gs): - labels = set() - #next line : from DeepGraphWithNNTorch - #bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs]))) - sizes = set() - for G in Gs : - for _,node in G.nodes(data = True): # or for node in nx.nodes(G) - #print(_,node) - labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0]) #what do we use for IAM datasets (they don't have bond_type or event label) ? - sizes.add(G.order()) - label_dict = {} - #print("labels : ", labels, bond_type_number_maxi) - for i,label in enumerate(labels): - label_dict[label] = [0.]*len(labels) - label_dict[label][i] = 1. - return label_dict - - def from_networkx_to_pytorch(self,Gs,y): - #exemple for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]} - # code from https://github.com/bgauzere/pygnn/blob/master/utils.py - atom_to_onehot = self.build_dictionary(Gs) - max_size = 30 - adjs = [] - inputs = [] - for i, G in enumerate(Gs): - I = torch.eye(G.order(), G.order()) - #A = torch.Tensor(nx.adjacency_matrix(G).todense()) - #A = torch.Tensor(nx.to_numpy_matrix(G)) - A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int) #what do we use for IAM datasets (they don't have bond_type or event label) ? - adj = F.pad(A, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ? if yes : F.pad(A + I,pad = (...)) - adjs.append(adj) - - f_0 = [] - for _, label in G.nodes(data=True): - #print(_,label) - cur_label = atom_to_onehot[label['label'][0]].copy() - f_0.append(cur_label) - - X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order())) - inputs.append(X) - return inputs,adjs,y - - def from_pytorch_to_tensorflow(self,batch_size): - seed = random.randrange(sys.maxsize) - random.seed(seed) - tf_inputs = random.sample(self.pytorch_dataset[0],batch_size) - random.seed(seed) - tf_y = random.sample(self.pytorch_dataset[2],batch_size) - - def from_networkx_to_tensor(self,G,dict): - A=nx.to_numpy_matrix(G) - lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)] - return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab)) - - + raise ValueError(filename + ' Unsupported file.') + + def get_all_ds_infos(self, database): """Get information of all datasets from a database. @@ -342,6 +191,7 @@ class DataFetcher(): msg = 'Invalid Database name "' + database + '"' msg += '\n Available databases are as follows: \n\n' msg += '\n'.join(db for db in sorted(DATABASES)) + msg += 'Check "gklearn.dataset.DATASET_META" for more details.' 
raise ValueError(msg) return infos @@ -457,6 +307,146 @@ class DataFetcher(): p_str += '}' return p_str + + + @property + def path(self): + return self._path + + + + + + + + + + + + + + + + + + + + + def dataset(self): + if self.mode == "Tensorflow": + return #something + if self.mode == "Pytorch": + return self.pytorch_dataset + return self.dataset + + + def info(self): + print(self.info_dataset[self._name]) + + + def iter_load_dataset(self,data): + results = [] + for datasets in data : + results.append(loadDataset(osp.join(self._root,self._name,datasets))) + return results + + + def load_dataset(self,list_files): + if self._name == "Ptc": + if type(self.option) != str or self.option.upper() not in ['FR','FM','MM','MR']: + raise ValueError('option for Ptc dataset needs to be one of : \n fr fm mm mr') + results = [] + results.append(loadDataset(osp.join(self.root,self._name,'PTC/Test',self.gender + '.ds'))) + results.append(loadDataset(osp.join(self.root,self._name,'PTC/Train',self.gender + '.ds'))) + return results + if self.name == "Pah": + maximum_sets = 0 + for file in list_files: + if file.endswith('ds'): + maximum_sets = max(maximum_sets,int(file.split('_')[1].split('.')[0])) + self.max_for_letter = maximum_sets + if not type(self.option) == int or self.option > maximum_sets or self.option < 0: + raise ValueError('option needs to be an integer between 0 and ' + str(maximum_sets)) + data = self.has_train_valid_test["Pah"] + data[0] = self.has_train_valid_test["Pah"][0].split('_')[0] + '_' + str(self.option) + '.ds' + data[1] = self.has_train_valid_test["Pah"][1].split('_')[0] + '_' + str(self.option) + '.ds' + return self.iter_load_dataset(data) + if self.name == "Letter": + if type(self.option) == str and self.option.upper() in self.has_train_valid_test["Letter"]: + data = self.has_train_valid_test["Letter"][self.option.upper()] + else: + message = "The parameter for letter is incorrect choose between : " + message += "\nhigh med low" + raise ValueError(message) + return self.iter_load_dataset(data) + if self.name in self.has_train_valid_test : #common IAM dataset with train, valid and test + data = self.has_train_valid_test[self.name] + return self.iter_load_dataset(data) + else: #common dataset without train,valid and test, only dataset.ds file + data = self.data_to_use_in_datasets[self.name] + if len(data) > 1 and data[0] in list_files and data[1] in list_files: #case for Alkane + return loadDataset(osp.join(self.root,self.name,data[0]),filename_y = osp.join(self.root,self.name,data[1])) + if data in list_files: + return loadDataset(osp.join(self.root,self.name,data)) + + + def build_dictionary(self,Gs): + labels = set() + #next line : from DeepGraphWithNNTorch + #bond_type_number_maxi = int(max(max([[G[e[0]][e[1]]['bond_type'] for e in G.edges()] for G in Gs]))) + sizes = set() + for G in Gs : + for _,node in G.nodes(data = True): # or for node in nx.nodes(G) + #print(_,node) + labels.add(node["label"][0]) # labels.add(G.nodes[node]["label"][0]) #what do we use for IAM datasets (they don't have bond_type or event label) ? + sizes.add(G.order()) + label_dict = {} + #print("labels : ", labels, bond_type_number_maxi) + for i,label in enumerate(labels): + label_dict[label] = [0.]*len(labels) + label_dict[label][i] = 1. 
+ return label_dict + + + def from_networkx_to_pytorch(self,Gs,y): + #exemple for MAO: atom_to_onehot = {'C': [1., 0., 0.], 'N': [0., 1., 0.], 'O': [0., 0., 1.]} + # code from https://github.com/bgauzere/pygnn/blob/master/utils.py + atom_to_onehot = self.build_dictionary(Gs) + max_size = 30 + adjs = [] + inputs = [] + for i, G in enumerate(Gs): + I = torch.eye(G.order(), G.order()) + #A = torch.Tensor(nx.adjacency_matrix(G).todense()) + #A = torch.Tensor(nx.to_numpy_matrix(G)) + A = torch.tensor(nx.to_scipy_sparse_matrix(G,dtype = int,weight = 'bond_type').todense(),dtype = torch.int) #what do we use for IAM datasets (they don't have bond_type or event label) ? + adj = F.pad(A, pad=(0, max_size-G.order(), 0, max_size-G.order())) #add I now ? if yes : F.pad(A + I,pad = (...)) + adjs.append(adj) + + f_0 = [] + for _, label in G.nodes(data=True): + #print(_,label) + cur_label = atom_to_onehot[label['label'][0]].copy() + f_0.append(cur_label) + + X = F.pad(torch.Tensor(f_0), pad=(0, 0, 0, max_size-G.order())) + inputs.append(X) + return inputs,adjs,y + + + def from_pytorch_to_tensorflow(self,batch_size): + seed = random.randrange(sys.maxsize) + random.seed(seed) + tf_inputs = random.sample(self.pytorch_dataset[0],batch_size) + random.seed(seed) + tf_y = random.sample(self.pytorch_dataset[2],batch_size) + + + def from_networkx_to_tensor(self,G,dict): + A=nx.to_numpy_matrix(G) + lab=[dict[G.nodes[v]['label'][0]] for v in nx.nodes(G)] + return (torch.tensor(A).view(1,A.shape[0]*A.shape[1]),torch.tensor(lab)) +
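
Usage sketch (not part of the patch): a minimal example of driving the updated fetcher, assuming 'Acyclic' is one of the keys of gklearn.dataset.DATASET_META and that DataFetcher is importable from gklearn.dataset.data_fetcher. With this change the archive is downloaded only when the dataset directory under root is missing or reload=True is passed; omitting the name fetches every dataset registered in DATASET_META.

    from gklearn.dataset.data_fetcher import DataFetcher

    # Fetch a single dataset; the archive is downloaded only if
    # 'datasets/Acyclic' does not exist yet (or if reload=True is passed).
    # 'Acyclic' is assumed to be a valid DATASET_META key; substitute any
    # name listed there.
    fetcher = DataFetcher(name='Acyclic', root='datasets', reload=False, verbose=True)
    print(fetcher.path)  # path to the first entry of the extracted archive

    # With no name, every dataset listed in DATASET_META is fetched (this
    # downloads everything, so it is commented out here); datasets whose
    # download fails appear as None in the resulting path list.
    # all_fetcher = DataFetcher(root='datasets', verbose=True)
    # print(sum(p is not None for p in all_fetcher.path), 'datasets available locally')

The path property introduced by this patch exposes either a single extracted path (when a name is given) or, in the no-name case, a list with one entry per dataset, with None marking datasets that could not be fetched.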