Browse Source

Merge pull request #42 from jajupmochi/v0.2.x

V0.2.x
master
linlin GitHub 4 years ago
parent
commit
e9154d579f
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 253 additions and 144 deletions
  1. +10
    -1
      gklearn/dataset/dataset.py
  2. +228
    -132
      gklearn/dataset/file_managers.py
  3. +15
    -11
      gklearn/dataset/metadata.py

+ 10
- 1
gklearn/dataset/dataset.py View File

@@ -115,7 +115,12 @@ class Dataset(object):
ds_file = [os.path.join(path, fn) for fn in load_files[0]]
fn_targets = os.path.join(path, load_files[1]) if len(load_files) == 2 else None

self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets).data
if 'extra_params' in DATASET_META[ds_name]:
kwargs = DATASET_META[ds_name]['extra_params']
else:
kwargs = {}

self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data

self._node_labels = label_names['node_labels']
self._node_attrs = label_names['node_attrs']
@@ -561,6 +566,8 @@ class Dataset(object):
return True
if inputs == 'MAO_lite':
return True
if inputs == 'Monoterpens':
return True
return False


@@ -578,6 +585,8 @@ class Dataset(object):

self.remove_labels(edge_labels=['bond_stereo'], node_attrs=['x', 'y'])

elif inputs == 'Monoterpens':
self.load_predefined_dataset('Monoterpenoides', root=root, clean_labels=clean_labels, reload=reload, verbose=verbose)


def get_all_node_labels(self):


+ 228
- 132
gklearn/dataset/file_managers.py View File

@@ -4,35 +4,35 @@ from os.path import dirname, splitext


class DataLoader():
def __init__(self, filename, filename_targets=None, gformat=None, **kwargs):
"""Read graph data from filename and load them as NetworkX graphs.
Parameters
----------
filename : string
The name of the file from where the dataset is read.
filename_targets : string
The name of file of the targets corresponding to graphs.
Notes
-----
This function supports following graph dataset formats:
'ds': load data from .ds file. See comments of function loadFromDS for an example.
'cxl': load data from Graph eXchange Language file (.cxl file). See
'cxl': load data from Graph eXchange Language file (.cxl file). See
`here <http://www.gupro.de/GXL/Introduction/background.html>`__ for detail.
'sdf': load data from structured data file (.sdf file). See
'sdf': load data from structured data file (.sdf file). See
`here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
for details.
'mat': Load graph data from a MATLAB (up to version 7.1) .mat file. See
README in `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
for details.
'txt': Load graph data from the TUDataset. See
`here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
for details. Note here filename is the name of either .txt file in
@@ -42,7 +42,7 @@ class DataLoader():
extension = splitext(filename)[1][1:]
else: # filename is a list of files.
extension = splitext(filename[0])[1][1:]
if extension == "ds":
self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)
elif extension == "cxl":
@@ -55,21 +55,25 @@ class DataLoader():
order = kwargs.get('order')
self._graphs, self._targets, self._label_names = self.load_mat(filename, order)
elif extension == 'txt':
self._graphs, self._targets, self._label_names = self.load_tud(filename)
if gformat is None:
self._graphs, self._targets, self._label_names = self.load_tud(filename)
elif gformat == 'cml':
self._graphs, self._targets, self._label_names = self.load_from_ds(filename, filename_targets)

else:
raise ValueError('The input file with the extension ".', extension, '" is not supported. The supported extensions includes: ".ds", ".cxl", ".xml", ".mat", ".txt".')
def load_from_ds(self, filename, filename_targets):
"""Load data from .ds file.
Possible graph formats include:
'.ct': see function load_ct for detail.
'.gxl': see function load_gxl for detail.
Note these graph formats are checked automatically by the extensions of
Note these graph formats are checked automatically by the extensions of
graph files.
"""
if isinstance(filename, str):
@@ -94,14 +98,16 @@ class DataLoader():
load_file_fun = self.load_ct
elif extension == 'gxl' or extension == 'sdf': # @todo: .sdf not tested yet.
load_file_fun = self.load_gxl
elif extension == 'cml': # dataset "Chiral"
load_file_fun = self.load_cml

if filename_targets is None or filename_targets == '':
for i in range(0, len(content)):
tmp = content[i].split(' ')
# remove the '#'s in file names
g, l_names = load_file_fun(dirname_dataset + '/' + tmp[0].replace('#', '', 1))
data.append(g)
self._append_label_names(label_names, l_names)
self._append_label_names(label_names, l_names) # @todo: this is so redundant.
y.append(float(tmp[1]))
else: # targets in a seperate file
for i in range(0, len(content)):
@@ -110,7 +116,7 @@ class DataLoader():
g, l_names = load_file_fun(dirname_dataset + '/' + tmp.replace('#', '', 1))
data.append(g)
self._append_label_names(label_names, l_names)
with open(filename_targets) as fnt:
content_y = fnt.read().splitlines()
# assume entries in filename and filename_targets have the same order.
@@ -118,36 +124,51 @@ class DataLoader():
tmp = item.split(' ')
# assume the 3rd entry in a line is y (for Alkane dataset)
y.append(float(tmp[2]))
return data, y, label_names


def load_from_xml(self, filename, dir_dataset=None):
import xml.etree.ElementTree as ET
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename)
tree = ET.parse(filename)
root = tree.getroot()

def load_one_file(filename, data, y, label_names):
tree = ET.parse(filename)
root = tree.getroot()
for graph in root.iter('graph') if root.find('graph') is not None else root.iter('print'): # "graph" for ... I forgot; "print" for datasets GREC and Web.
mol_filename = graph.attrib['file']
mol_class = graph.attrib['class']
g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename)
data.append(g)
self._append_label_names(label_names, l_names)
y.append(mol_class)

data = []
y = []
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
for graph in root.iter('graph'):
mol_filename = graph.attrib['file']
mol_class = graph.attrib['class']
g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename)
data.append(g)
self._append_label_names(label_names, l_names)
y.append(mol_class)

if isinstance(filename, str):
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename)
load_one_file(filename, data, y, label_names)


else: # filename is a list of files.
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename[0])

for fn in filename:
load_one_file(fn, data, y, label_names)

return data, y, label_names
def load_mat(self, filename, order): # @todo: need to be updated (auto order) or deprecated.
"""Load graph data from a MATLAB (up to version 7.1) .mat file.
Notes
------
A MAT file contains a struct array containing graphs, and a column vector lx containing a class label for each graph.
@@ -184,17 +205,17 @@ class DataLoader():
for col, row in zip(index_no0[0], index_no0[1]):
g.add_edge(col, row)
data.append(g)
label_names = {'node_labels': ['label_1'], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
if order[1] == 0:
label_names['edge_labels'].append('label_1')
return data, y, label_names
def load_tud(self, filename):
"""Load graph data from TUD dataset files.
Notes
------
The graph data is loaded from separate files.
@@ -203,12 +224,12 @@ class DataLoader():
import networkx as nx
from os import listdir
from os.path import dirname, basename
def get_infos_from_readme(frm): # @todo: add README (cuniform), maybe node/edge label maps.
"""Get information from DS_label_readme.txt file.
"""
def get_label_names_from_line(line):
"""Get names of labels/attributes from a line.
"""
@@ -216,17 +237,17 @@ class DataLoader():
names = str_names.split(',')
names = [attr.strip() for attr in names]
return names
def get_class_label_map(label_map_strings):
label_map = {}
for string in label_map_strings:
integer, label = string.split('\t')
label_map[int(integer.strip())] = label.strip()
return label_map
label_names = {'node_labels': [], 'node_attrs': [],
label_names = {'node_labels': [], 'node_attrs': [],
'edge_labels': [], 'edge_attrs': []}
class_label_map = None
class_label_map_strings = []
@@ -254,16 +275,16 @@ class DataLoader():
line = content_rm[i].strip()
class_label_map = get_class_label_map(class_label_map_strings)
i += 1
return label_names, class_label_map
# get dataset name.
dirname_dataset = dirname(filename)
filename = basename(filename)
fn_split = filename.split('_A')
ds_name = fn_split[0].strip()
# load data file names
for name in listdir(dirname_dataset):
if ds_name + '_A' in name:
@@ -287,20 +308,20 @@ class DataLoader():
# this is supposed to be the node attrs, make sure to put this as the last 'elif'
elif ds_name + '_attributes' in name:
fna = dirname_dataset + '/' + name
# get labels and attributes names.
if 'frm' in locals():
label_names, class_label_map = get_infos_from_readme(frm)
else:
label_names = {'node_labels': [], 'node_attrs': [],
label_names = {'node_labels': [], 'node_attrs': [],
'edge_labels': [], 'edge_attrs': []}
class_label_map = None
with open(fgi) as gi:
content_gi = gi.read().splitlines() # graph indicator
with open(fam) as am:
content_am = am.read().splitlines() # adjacency matrix
# load targets.
if 'fgl' in locals():
with open(fgl) as gl:
@@ -314,7 +335,7 @@ class DataLoader():
raise Exception('Can not find targets file. Please make sure there is a "', ds_name, '_graph_labels.txt" or "', ds_name, '_graph_attributes.txt"', 'file in your dataset folder.')
if class_label_map is not None:
targets = [class_label_map[t] for t in targets]
# create graphs and add nodes
data = [nx.Graph(name=str(i)) for i in range(0, len(content_targets))]
if 'fnl' in locals():
@@ -335,7 +356,7 @@ class DataLoader():
else:
for i, line in enumerate(content_gi):
data[int(line) - 1].add_node(i)
# add edges
for line in content_am:
tmp = line.split(',')
@@ -344,7 +365,7 @@ class DataLoader():
# ignore edge weight here.
g = int(content_gi[n1]) - 1
data[g].add_edge(n1, n2)
# add edge labels
if 'fel' in locals():
with open(fel) as el:
@@ -361,7 +382,7 @@ class DataLoader():
else:
for i, l_name in enumerate(label_names['edge_labels']):
data[g].edges[n[0], n[1]][l_name] = labels[i]
# add node attributes
if 'fna' in locals():
with open(fna) as na:
@@ -377,7 +398,7 @@ class DataLoader():
else:
for i, a_name in enumerate(label_names['node_attrs']):
data[g].nodes[idx][a_name] = attrs[i]
# add edge attributes
if 'fea' in locals():
with open(fea) as ea:
@@ -394,29 +415,29 @@ class DataLoader():
else:
for i, a_name in enumerate(label_names['edge_attrs']):
data[g].edges[n[0], n[1]][a_name] = attrs[i]
return data, targets, label_names
def load_ct(self, filename): # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds are considered.)
"""load data from a Chemical Table (.ct) file.
Notes
------
a typical example of data in .ct is like this:
3 2 <- number of nodes and edges
0.0000 0.0000 0.0000 C <- each line describes a node (x,y,z + label)
0.0000 0.0000 0.0000 C
0.0000 0.0000 0.0000 O
1 3 1 1 <- each line describes an edge : to, from, bond type, bond stereo
2 3 1 1
Check `CTFile Formats file <https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=10&ved=2ahUKEwivhaSdjsTlAhVhx4UKHczHA8gQFjAJegQIARAC&url=https%3A%2F%2Fwww.daylight.com%2Fmeetings%2Fmug05%2FKappler%2Fctfile.pdf&usg=AOvVaw1cDNrrmMClkFPqodlF2inS>`__
for detailed format description.
"""
@@ -426,7 +447,7 @@ class DataLoader():
with open(filename) as f:
content = f.read().splitlines()
g = nx.Graph(name=str(content[0]), filename=basename(filename)) # set name of the graph
# read the counts line.
tmp = content[1].split(' ')
tmp = [x for x in tmp if x != '']
@@ -438,7 +459,7 @@ class DataLoader():
if count_line_tags[i] != '': # if not obsoleted
g.graph[count_line_tags[i]] = tmp[i].strip()
i += 1
# read the atom block.
atom_tags = ['x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', 'atom_stereo_parity', 'hydrogen_count_plus_1', 'stereo_care_box', 'valence', 'h0_designator', '', '', 'atom_atom_mapping_number', 'inversion_retention_flag', 'exact_change_flag']
for i in range(0, nb_atoms):
@@ -450,7 +471,7 @@ class DataLoader():
if atom_tags[j] != '':
g.nodes[i][atom_tags[j]] = tmp[j].strip()
j += 1
# read the bond block.
bond_tags = ['first_atom_number', 'second_atom_number', 'bond_type', 'bond_stereo', '', 'bond_topology', 'reacting_center_status']
for i in range(0, nb_bonds):
@@ -463,7 +484,7 @@ class DataLoader():
if bond_tags[j] != '':
g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip()
j += 1
# get label names.
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1]
@@ -482,15 +503,15 @@ class DataLoader():
else:
label_names['edge_attrs'].append(key)
break
return g, label_names
def load_gxl(self, filename): # @todo: directed graphs.
from os.path import basename
import networkx as nx
import xml.etree.ElementTree as ET
tree = ET.parse(filename)
root = tree.getroot()
index = 0
@@ -499,65 +520,140 @@ class DataLoader():
for node in root.iter('node'):
dic[node.attrib['id']] = index
labels = {}
for attr in node.iter('attr'):
for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens".
labels[attr.attrib['name']] = attr[0].text
for attr in node.iter('attribute'): # for dataset "Web".
labels[attr.attrib['name']] = attr.attrib['value']
g.add_node(index, **labels)
index += 1
for edge in root.iter('edge'):
labels = {}
for attr in edge.iter('attr'):
for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens".
labels[attr.attrib['name']] = attr[0].text
for attr in edge.iter('attribute'): # for dataset "Web".
labels[attr.attrib['name']] = attr.attrib['value']
g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
# get label names.
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
# @todo: possible loss of label names if some nodes miss some labels.
for node in root.iter('node'):
for attr in node.iter('attr'):
if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens".
if attr[0].tag == 'int' or attr.attrib['name'] == 'type': # @todo: this maybe wrong, and slow. "type" is for dataset GREC; "int" is for dataset "Monoterpens".
label_names['node_labels'].append(attr.attrib['name'])
else:
label_names['node_attrs'].append(attr.attrib['name'])

for attr in node.iter('attribute'): # for dataset "Web".
label_names['node_attrs'].append(attr.attrib['name'])
# @todo: is id useful in dataset "Web"? is "FREQUENCY" symbolic or not?
break

for edge in root.iter('edge'):
for attr in edge.iter('attr'):
if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens".
if attr[0].tag == 'int' or attr.attrib['name'] == 'frequency' or 'type' in attr.attrib['name']: # @todo: this maybe wrong, and slow. "frequency" and "type" are for dataset GREC; "int" is for dataset "Monoterpens".
label_names['edge_labels'].append(attr.attrib['name'])
else:
label_names['edge_attrs'].append(attr.attrib['name'])

for attr in edge.iter('attribute'): # for dataset "Web".
label_names['edge_attrs'].append(attr.attrib['name'])
break
return g, label_names


def load_cml(self, filename): # @todo: directed graphs.
    """Load a single graph from a Chemical Markup Language (.cml) file.

    Every ``atom`` element becomes a node and every ``bond`` element becomes
    an edge of an undirected NetworkX graph. The XML attributes of an atom
    (except ``id``) and of a bond (except ``atomRefs2`` and ``id``) are
    stored unchanged, as strings, on the corresponding node or edge.

    Parameters
    ----------
    filename : string
        Path of the .cml file to read.

    Returns
    -------
    g : networkx.Graph
        The loaded graph. Its graph attributes ``filename`` and ``name`` hold
        the base name of the file and the ``id`` of the molecule element.
        Nodes are numbered from 0 in document order.
    label_names : dict
        Dictionary with the keys 'node_labels', 'edge_labels', 'node_attrs'
        and 'edge_attrs', each mapping to a list of attribute names. A name
        is listed as a label when its value is not parsable as a float or
        consists only of digits; otherwise it is listed as an attr.

    Notes
    -----
    Label names are derived from the first node and the first edge only, so
    names that appear only on later nodes/edges are not reported.
    """
    # @todo: what is "atomParity" and "bondStereo" in the data file?
    from os.path import basename
    import networkx as nx
    import xml.etree.ElementTree as ET

    def is_symbolic(val):
        """Return True if ``val`` is to be treated as a symbolic label."""
        try:
            float(val)
        except ValueError: # only a failed conversion means "symbolic".
            return True
        return val.isdigit()

    # ElementTree reports namespaced tags as "{namespace-uri}tag", so the
    # CML namespace has to be prepended to every tag that is looked up.
    xmlns = '{http://www.xml-cml.org/schema}'
    root = ET.parse(filename).getroot()
    if root.tag == xmlns + 'molecule':
        g_id = root.attrib['id']
    else:
        g_id = root.find(xmlns + 'molecule').attrib['id']
    g = nx.Graph(filename=basename(filename), name=g_id)

    dic = {} # used to retrieve incident nodes of edges
    for index, atom in enumerate(root.iter(xmlns + 'atom')):
        dic[atom.attrib['id']] = index
        labels = {key: val for key, val in atom.attrib.items() if key != 'id'}
        g.add_node(index, **labels)

    for bond in root.iter(xmlns + 'bond'):
        # "id" is in dataset "ACE".
        labels = {key: val for key, val in bond.attrib.items()
                  if key not in ('atomRefs2', 'id')}
        n1, n2 = bond.attrib['atomRefs2'].strip().split(' ')
        g.add_edge(dic[n1], dic[n2], **labels)

    # get label names.
    label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
    # @todo: possible loss of label names if some nodes miss some labels.
    if g.number_of_nodes() > 0: # a molecule without atoms has no node 0.
        for key, val in g.nodes[0].items():
            kind = 'node_labels' if is_symbolic(val) else 'node_attrs'
            label_names[kind].append(key)

    # Only the first edge (if any) is inspected.
    first_edge = next(iter(g.edges(data=True)), None)
    if first_edge is not None:
        for key, val in first_edge[2].items():
            kind = 'edge_labels' if is_symbolic(val) else 'edge_attrs'
            label_names[kind].append(key)

    return g, label_names


def _append_label_names(self, label_names, new_names):
    """Merge the names in ``new_names`` into ``label_names`` in place.

    For every key of ``label_names``, the entries of ``new_names[key]`` that
    are not yet present in ``label_names[key]`` are appended to that list.
    """
    for kind in label_names:
        known = label_names[kind]
        # Build the additions first, then extend the existing list object.
        fresh = [candidate for candidate in new_names[kind] if candidate not in known]
        known.extend(fresh)
@property
def data(self):
    """Return the loaded graphs, targets and label names as one tuple."""
    graphs, targets, names = self._graphs, self._targets, self._label_names
    return graphs, targets, names
@property
def graphs(self):
    """Return the graphs stored on this loader."""
    loaded_graphs = self._graphs
    return loaded_graphs
@property
def targets(self):
    """Return the targets stored on this loader."""
    loaded_targets = self._targets
    return loaded_targets
@property
def label_names(self):
    """Return the label/attribute names stored on this loader."""
    stored_names = self._label_names
    return stored_names
class DataSaver():
def __init__(self, graphs, targets=None, filename='gfile', gformat='gxl', group=None, **kwargs):
"""Save list of graphs.
"""
@@ -566,14 +662,14 @@ class DataSaver():
if dirname_ds != '':
dirname_ds += '/'
os.makedirs(dirname_ds, exist_ok=True)
if 'graph_dir' in kwargs:
graph_dir = kwargs['graph_dir'] + '/'
os.makedirs(graph_dir, exist_ok=True)
del kwargs['graph_dir']
else:
graph_dir = dirname_ds
graph_dir = dirname_ds
if group == 'xml' and gformat == 'gxl':
with open(filename + '.xml', 'w') as fgroup:
fgroup.write("<?xml version=\"1.0\"?>")
@@ -600,20 +696,20 @@ class DataSaver():
gxl_file.write("<graph id=\"" + name + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
for v, attrs in graph.nodes(data=True):
gxl_file.write("<node id=\"_" + str(v) + "\">")
for l_name in node_labels:
gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
for l_name in node_labels:
gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
str(attrs[l_name]) + "</int></attr>")
for a_name in node_attrs:
gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
for a_name in node_attrs:
gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
str(attrs[a_name]) + "</float></attr>")
gxl_file.write("</node>\n")
for v1, v2, attrs in graph.edges(data=True):
gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
for l_name in edge_labels:
gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
for l_name in edge_labels:
gxl_file.write("<attr name=\"" + l_name + "\"><int>" +
str(attrs[l_name]) + "</int></attr>")
for a_name in edge_attrs:
gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
for a_name in edge_attrs:
gxl_file.write("<attr name=\"" + a_name + "\"><float>" +
str(attrs[a_name]) + "</float></attr>")
gxl_file.write("</edge>\n")
gxl_file.write("</graph>\n")
@@ -627,7 +723,7 @@ class DataSaver():
attr['edgeids'] = 'true'
attr['edgemode'] = 'undirected'
graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
for v in graph:
current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
for attr in graph.nodes[v].keys():
@@ -636,7 +732,7 @@ class DataSaver():
cur_value = ET.SubElement(cur_attr,
graph.nodes[v][attr].__class__.__name__)
cur_value.text = graph.nodes[v][attr]
for v1 in graph:
for v2 in graph[v1]:
if (v1 < v2): # Non oriented graphs
@@ -653,7 +749,7 @@ class DataSaver():
cur_value = ET.SubElement(
cur_attr, graph[v1][v2][attr].__class__.__name__)
cur_value.text = str(graph[v1][v2][attr])
tree = ET.ElementTree(root_node)
tree.write(filename)
elif method == 'gedlib':
@@ -748,7 +844,7 @@ class DataSaver():

# def load_from_cxl(filename):
# import xml.etree.ElementTree as ET
#
#
# dirname_dataset = dirname(filename)
# tree = ET.parse(filename)
# root = tree.getroot()
@@ -759,9 +855,9 @@ class DataSaver():
# mol_class = graph.attrib['class']
# data.append(load_gxl(dirname_dataset + '/' + mol_filename))
# y.append(mol_class)
if __name__ == '__main__':
if __name__ == '__main__':
# ### Load dataset from .ds file.
# # .ct files.
# ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds',
@@ -777,7 +873,7 @@ if __name__ == '__main__':
# print(Gn[1].nodes(data=True))
# print(Gn[1].edges(data=True))
# print(targets[1])
# # .gxl file.
# ds_file = '../../datasets/monoterpenoides/dataset_10+.ds' # node/edge symb
# Gn, y, label_names = load_dataset(ds_file)
@@ -799,7 +895,7 @@ if __name__ == '__main__':
# ### Convert graph from one format to another.
# # .gxl file.
# import networkx as nx
# ds = {'name': 'monoterpenoides',
# ds = {'name': 'monoterpenoides',
# 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb
# Gn, y = loadDataset(ds['dataset'])
# y = [int(i) for i in y]
@@ -822,13 +918,13 @@ if __name__ == '__main__':
# filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides'
# xparams = {'method': 'gedlib'}
# saveDataset(Gn, y, gformat='gxl', group='xml', filename=filename, xparams=xparams)
# save dataset.
# ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
# 'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}} # node/edge symb
# Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
# saveDataset(Gn, y, group='xml', filename='temp/temp')
# test - new way to add labels and attributes.
# dataset = '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
# filename = '../../datasets/Fingerprint/Fingerprint_A.txt'


+ 15
- 11
gklearn/dataset/metadata.py View File

@@ -33,6 +33,7 @@ GREYC_META = {
'train_valid_test': [],
'stereoisomerism': True,
'load_files': ['data.ds'],
'extra_params': {'gformat': 'cml'}
},
'Acyclic': {
'database': 'greyc',
@@ -108,7 +109,8 @@ GREYC_META = {
'domain': 'small molecules',
'train_valid_test': [],
'stereoisomerism': True,
'load_files': [],
'load_files': ['data.txt'],
'extra_params': {'gformat': 'cml'}
},
'MAO': {
'database': 'greyc',
@@ -204,7 +206,8 @@ GREYC_META = {
'domain': 'small molecules',
'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'],
'stereoisomerism': False,
'load_files': [],
'load_files': ['dataWithOutsider.ds'],
'extra_params': {'gformat': 'cml'}
},
'Vitamin_D': {
'database': 'greyc',
@@ -223,7 +226,8 @@ GREYC_META = {
'domain': 'small molecules',
'train_valid_test': [],
'stereoisomerism': True,
'load_files': [],
'load_files': ['data.txt'],
'extra_params': {'gformat': 'cml'}
},
}

@@ -247,7 +251,7 @@ IAM_META = {
'url': 'https://iapr-tc15.greyc.fr/IAM/GREC.zip',
'domain': None,
'train_valid_test': ['data/test.cxl','data/train.cxl', 'data/valid.cxl'],
'load_files': [],
'load_files': [['data/test.cxl','data/train.cxl', 'data/valid.cxl']],
},
'Web': {
'database': 'iam',
@@ -265,13 +269,13 @@ IAM_META = {
'url': 'https://iapr-tc15.greyc.fr/IAM/Web.zip',
'domain': None,
'train_valid_test': ['data/test.cxl', 'data/train.cxl', 'data/valid.cxl'],
'load_files': [],
'load_files': [['data/test.cxl','data/train.cxl', 'data/valid.cxl']],
},
}

### -------- database tudataset -------- ###
TUDataset_META = {
TUDataset_META = {
### small molecules
'AIDS': {
'database': 'tudataset',
@@ -1697,7 +1701,7 @@ TUDataset_META = {
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/ZINC_val.zip',
'domain': 'small molecules',
},
### bioinformatics
'DD': {
'database': 'tudataset',
@@ -1811,7 +1815,7 @@ TUDataset_META = {
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/PROTEINS_full.zip',
'domain': 'bioinformatics',
},
### computer vision
'COIL-DEL': {
'database': 'tudataset',
@@ -1989,7 +1993,7 @@ TUDataset_META = {
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/MSRC_21C.zip',
'domain': 'computer vision',
},
### social networks
'COLLAB': {
'database': 'tudataset',
@@ -2375,7 +2379,7 @@ TUDataset_META = {
'url': 'https://www.chrsmrrs.com/graphkerneldatasets/TWITTER-Real-Graph-Partial.zip',
'domain': 'social networks',
},
### synthetic
'COLORS-3': {
'database': 'tudataset',


Loading…
Cancel
Save