diff --git a/gklearn/kernels/treeletKernel.py b/gklearn/kernels/treeletKernel.py
index 51cdd08..946f20f 100644
--- a/gklearn/kernels/treeletKernel.py
+++ b/gklearn/kernels/treeletKernel.py
@@ -310,7 +310,7 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
             for pattern in patterns[str(i)]:
                 canonlist = list(chain.from_iterable((G.nodes[node][node_label], \
                     G[node][pattern[idx+1]][edge_label]) for idx, node in enumerate(pattern[:-1])))
-                canonlist.append(G.node[pattern[-1]][node_label])
+                canonlist.append(G.nodes[pattern[-1]][node_label])
                 canonkey_t = canonlist if canonlist < canonlist[::-1] else canonlist[::-1]
                 treelet.append(tuple([str(i)] + canonkey_t))
             canonkey_l.update(Counter(treelet))
@@ -319,26 +319,26 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
         for i in range(3, 6):
             treelet = []
             for pattern in patterns[str(i) + 'star']:
-                canonlist = [tuple((G.node[leaf][node_label],
+                canonlist = [tuple((G.nodes[leaf][node_label],
                                     G[leaf][pattern[0]][edge_label])) for leaf in pattern[1:]]
                 canonlist.sort()
                 canonlist = list(chain.from_iterable(canonlist))
                 canonkey_t = tuple(['d' if i == 5 else str(i * 2)] +
-                                   [G.node[pattern[0]][node_label]] + canonlist)
+                                   [G.nodes[pattern[0]][node_label]] + canonlist)
                 treelet.append(canonkey_t)
             canonkey_l.update(Counter(treelet))
 
         # pattern 7
         treelet = []
         for pattern in patterns['7']:
-            canonlist = [tuple((G.node[leaf][node_label],
+            canonlist = [tuple((G.nodes[leaf][node_label],
                                 G[leaf][pattern[0]][edge_label])) for leaf in pattern[1:3]]
             canonlist.sort()
             canonlist = list(chain.from_iterable(canonlist))
-            canonkey_t = tuple(['7'] + [G.node[pattern[0]][node_label]] + canonlist
-                               + [G.node[pattern[3]][node_label]]
+            canonkey_t = tuple(['7'] + [G.nodes[pattern[0]][node_label]] + canonlist
+                               + [G.nodes[pattern[3]][node_label]]
                                + [G[pattern[3]][pattern[0]][edge_label]]
-                               + [G.node[pattern[4]][node_label]]
+                               + [G.nodes[pattern[4]][node_label]]
                                + [G[pattern[4]][pattern[3]][edge_label]])
             treelet.append(canonkey_t)
         canonkey_l.update(Counter(treelet))
@@ -346,14 +346,14 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
         # pattern 11
         treelet = []
         for pattern in patterns['11']:
-            canonlist = [tuple((G.node[leaf][node_label],
+            canonlist = [tuple((G.nodes[leaf][node_label],
                                 G[leaf][pattern[0]][edge_label])) for leaf in pattern[1:4]]
             canonlist.sort()
             canonlist = list(chain.from_iterable(canonlist))
-            canonkey_t = tuple(['b'] + [G.node[pattern[0]][node_label]] + canonlist
-                               + [G.node[pattern[4]][node_label]]
+            canonkey_t = tuple(['b'] + [G.nodes[pattern[0]][node_label]] + canonlist
+                               + [G.nodes[pattern[4]][node_label]]
                                + [G[pattern[4]][pattern[0]][edge_label]]
-                               + [G.node[pattern[5]][node_label]]
+                               + [G.nodes[pattern[5]][node_label]]
                                + [G[pattern[5]][pattern[4]][edge_label]])
             treelet.append(canonkey_t)
         canonkey_l.update(Counter(treelet))
@@ -361,15 +361,15 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
         # pattern 10
         treelet = []
         for pattern in patterns['10']:
-            canonkey4 = [G.node[pattern[5]][node_label], G[pattern[5]][pattern[4]][edge_label]]
-            canonlist = [tuple((G.node[leaf][node_label],
+            canonkey4 = [G.nodes[pattern[5]][node_label], G[pattern[5]][pattern[4]][edge_label]]
+            canonlist = [tuple((G.nodes[leaf][node_label],
                                 G[leaf][pattern[0]][edge_label])) for leaf in pattern[1:3]]
             canonlist.sort()
             canonkey0 = list(chain.from_iterable(canonlist))
-            canonkey_t = tuple(['a'] + [G.node[pattern[3]][node_label]]
-                               + [G.node[pattern[4]][node_label]]
+            canonkey_t = tuple(['a'] + [G.nodes[pattern[3]][node_label]]
+                               + [G.nodes[pattern[4]][node_label]]
                                + [G[pattern[4]][pattern[3]][edge_label]]
-                               + [G.node[pattern[0]][node_label]]
+                               + [G.nodes[pattern[0]][node_label]]
                                + [G[pattern[0]][pattern[3]][edge_label]]
                                + canonkey4 + canonkey0)
             treelet.append(canonkey_t)
@@ -378,23 +378,23 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
         # pattern 12
         treelet = []
         for pattern in patterns['12']:
-            canonlist0 = [tuple((G.node[leaf][node_label],
+            canonlist0 = [tuple((G.nodes[leaf][node_label],
                                  G[leaf][pattern[0]][edge_label])) for leaf in pattern[1:3]]
             canonlist0.sort()
             canonlist0 = list(chain.from_iterable(canonlist0))
-            canonlist3 = [tuple((G.node[leaf][node_label],
+            canonlist3 = [tuple((G.nodes[leaf][node_label],
                                  G[leaf][pattern[3]][edge_label])) for leaf in pattern[4:6]]
             canonlist3.sort()
             canonlist3 = list(chain.from_iterable(canonlist3))
             # 2 possible keys can be generated from the 2 nodes with extended label 3;
             # select the one with the lower lexicographic order.
-            canonkey_t1 = tuple(['c'] + [G.node[pattern[0]][node_label]] + canonlist0
-                                + [G.node[pattern[3]][node_label]]
+            canonkey_t1 = tuple(['c'] + [G.nodes[pattern[0]][node_label]] + canonlist0
+                                + [G.nodes[pattern[3]][node_label]]
                                 + [G[pattern[3]][pattern[0]][edge_label]]
                                 + canonlist3)
-            canonkey_t2 = tuple(['c'] + [G.node[pattern[3]][node_label]] + canonlist3
-                                + [G.node[pattern[0]][node_label]]
+            canonkey_t2 = tuple(['c'] + [G.nodes[pattern[3]][node_label]] + canonlist3
+                                + [G.nodes[pattern[0]][node_label]]
                                 + [G[pattern[0]][pattern[3]][edge_label]]
                                 + canonlist0)
             treelet.append(canonkey_t1 if canonkey_t1 < canonkey_t2 else canonkey_t2)
@@ -403,19 +403,19 @@ def get_canonkeys(G, node_label, edge_label, labeled, is_directed):
         # pattern 9
         treelet = []
         for pattern in patterns['9']:
-            canonkey2 = [G.node[pattern[4]][node_label], G[pattern[4]][pattern[2]][edge_label]]
-            canonkey3 = [G.node[pattern[5]][node_label], G[pattern[5]][pattern[3]][edge_label]]
-            prekey2 = [G.node[pattern[2]][node_label], G[pattern[2]][pattern[0]][edge_label]]
-            prekey3 = [G.node[pattern[3]][node_label], G[pattern[3]][pattern[0]][edge_label]]
+            canonkey2 = [G.nodes[pattern[4]][node_label], G[pattern[4]][pattern[2]][edge_label]]
+            canonkey3 = [G.nodes[pattern[5]][node_label], G[pattern[5]][pattern[3]][edge_label]]
+            prekey2 = [G.nodes[pattern[2]][node_label], G[pattern[2]][pattern[0]][edge_label]]
+            prekey3 = [G.nodes[pattern[3]][node_label], G[pattern[3]][pattern[0]][edge_label]]
             if prekey2 + canonkey2 < prekey3 + canonkey3:
-                canonkey_t = [G.node[pattern[1]][node_label]] \
+                canonkey_t = [G.nodes[pattern[1]][node_label]] \
                              + [G[pattern[1]][pattern[0]][edge_label]] \
                              + prekey2 + prekey3 + canonkey2 + canonkey3
             else:
-                canonkey_t = [G.node[pattern[1]][node_label]] \
+                canonkey_t = [G.nodes[pattern[1]][node_label]] \
                              + [G[pattern[1]][pattern[0]][edge_label]] \
                              + prekey3 + prekey2 + canonkey3 + canonkey2
-            treelet.append(tuple(['9'] + [G.node[pattern[0]][node_label]] + canonkey_t))
+            treelet.append(tuple(['9'] + [G.nodes[pattern[0]][node_label]] + canonkey_t))
         canonkey_l.update(Counter(treelet))
 
     return canonkey_l
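Every change in this file swaps `G.node` for `G.nodes`: NetworkX deprecated the `Graph.node` alias in 2.1 and removed it in 2.4, so the kernel breaks on current NetworkX without this patch. A minimal standalone illustration (not part of the patch):

    import networkx as nx

    G = nx.Graph()
    G.add_node(0, atom='C')

    print(G.nodes[0]['atom'])  # the supported view on NetworkX >= 2.0
    # G.node[0]['atom']        # old alias; removed in NetworkX 2.4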
diff --git a/gklearn/preimage/median_preimage_generator.py b/gklearn/preimage/median_preimage_generator.py
index b4f4f80..ef1d57a 100644
--- a/gklearn/preimage/median_preimage_generator.py
+++ b/gklearn/preimage/median_preimage_generator.py
@@ -99,6 +99,11 @@ class MedianPreimageGenerator(PreimageGenerator):
             self._graph_kernel.gram_matrix = self._graph_kernel.normalize_gm(np.copy(self.__gram_matrix_unnorm))
             end_precompute_gm = time.time()
             start -= self.__runtime_precompute_gm
+
+        if self.__fit_method != 'k-graphs' and self.__fit_method != 'whole-dataset':
+            start = time.time()
+            self.__runtime_precompute_gm = 0
+            end_precompute_gm = start
 
         # 2. optimize edit cost constants.
         self.__optimize_edit_cost_constants()
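The added branch keeps the timing bookkeeping consistent: fit methods other than 'k-graphs' and 'whole-dataset' never build a Gram matrix up front, so the precompute runtime is pinned to zero and the overall timer restarts. A hedged sketch of the resulting logic, with plain variables standing in for the class attributes and a hypothetical stand-in for the kernel computation:

    import time

    def precompute_gram_matrix():
        """Hypothetical stand-in for the kernel's Gram-matrix computation."""
        time.sleep(0.01)

    fit_method = 'random'  # anything other than 'k-graphs'/'whole-dataset'

    start = time.time()
    if fit_method in ('k-graphs', 'whole-dataset'):
        precompute_gram_matrix()
        end_precompute_gm = time.time()
        runtime_precompute_gm = end_precompute_gm - start
    else:
        start = time.time()         # restart the timer; nothing was precomputed
        runtime_precompute_gm = 0
        end_precompute_gm = start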
diff --git a/gklearn/preimage/utils.py b/gklearn/preimage/utils.py
index cbe00a1..a3a661e 100644
--- a/gklearn/preimage/utils.py
+++ b/gklearn/preimage/utils.py
@@ -24,7 +24,7 @@ import csv
 import networkx as nx
 
 
-def generate_median_preimages_by_class(ds_name, mpg_options, kernel_options, ged_options, mge_options, save_results=True, save_medians=True, plot_medians=True, load_gm='auto', dir_save='', irrelevant_labels=None):
+def generate_median_preimages_by_class(ds_name, mpg_options, kernel_options, ged_options, mge_options, save_results=True, save_medians=True, plot_medians=True, load_gm='auto', dir_save='', irrelevant_labels=None, edge_required=False):
     import os.path
     from gklearn.preimage import MedianPreimageGenerator
     from gklearn.utils import split_dataset_by_target
@@ -34,7 +34,8 @@ def generate_median_preimages_by_class(ds_name, mpg_options, kernel_options, ged
     print('1. getting dataset...')
     dataset_all = Dataset()
     dataset_all.load_predefined_dataset(ds_name)
-    if not irrelevant_labels is None:
+    dataset_all.trim_dataset(edge_required=edge_required)
+    if irrelevant_labels is not None:
         dataset_all.remove_labels(**irrelevant_labels)
 #    dataset_all.cut_graphs(range(0, 100))
     datasets = split_dataset_by_target(dataset_all)
@@ -228,7 +229,7 @@ def generate_median_preimages_by_class(ds_name, mpg_options, kernel_options, ged
 
     # plot median graphs.
     if plot_medians and save_medians:
-        if ds_name == 'Letter-high' or ds_name == 'Letter-med' or ds_name == 'Letter-low':
+        if ged_options['edit_cost'] == 'LETTER2' or ged_options['edit_cost'] == 'LETTER' or ds_name == 'Letter-high' or ds_name == 'Letter-med' or ds_name == 'Letter-low':
             draw_Letter_graph(mpg.set_median, fn_pre_sm)
             draw_Letter_graph(mpg.gen_median, fn_pre_gm)
             draw_Letter_graph(mpg.best_from_dataset, fn_best_dataset)
diff --git a/gklearn/utils/__init__.py b/gklearn/utils/__init__.py
index d5301c6..78832d3 100644
--- a/gklearn/utils/__init__.py
+++ b/gklearn/utils/__init__.py
@@ -16,6 +16,7 @@ __date__ = "November 2017"
 # from utils import graphfiles
 # from utils import utils
 from gklearn.utils.dataset import Dataset, split_dataset_by_target
+from gklearn.utils.graph_files import load_dataset, save_dataset
 from gklearn.utils.timer import Timer
 from gklearn.utils.utils import get_graph_kernel_by_name
 from gklearn.utils.utils import compute_gram_matrices_by_class
diff --git a/gklearn/utils/dataset.py b/gklearn/utils/dataset.py
index 7c2b732..6f5389c 100644
--- a/gklearn/utils/dataset.py
+++ b/gklearn/utils/dataset.py
@@ -7,13 +7,13 @@ Created on Thu Mar 26 18:48:27 2020
 """
 import numpy as np
 import networkx as nx
-from gklearn.utils.graphfiles import loadDataset
+from gklearn.utils.graph_files import load_dataset
 import os
 
 
 class Dataset(object):
 
-    def __init__(self, filename=None, filename_y=None, extra_params=None):
+    def __init__(self, filename=None, filename_targets=None, **kwargs):
         if filename is None:
             self.__graphs = None
             self.__targets = None
@@ -22,7 +22,7 @@ class Dataset(object):
             self.__node_attrs = None
             self.__edge_attrs = None
         else:
-            self.load_dataset(filename, filename_y=filename_y, extra_params=extra_params)
+            self.load_dataset(filename, filename_targets=filename_targets, **kwargs)
 
         self.__substructures = None
         self.__node_label_dim = None
@@ -50,9 +50,12 @@ class Dataset(object):
         self.__class_number = None
 
 
-    def load_dataset(self, filename, filename_y=None, extra_params=None):
-        self.__graphs, self.__targets = loadDataset(filename, filename_y=filename_y, extra_params=extra_params)
-        self.set_labels_attrs()
+    def load_dataset(self, filename, filename_targets=None, **kwargs):
+        self.__graphs, self.__targets, label_names = load_dataset(filename, filename_targets=filename_targets, **kwargs)
+        self.__node_labels = label_names['node_labels']
+        self.__node_attrs = label_names['node_attrs']
+        self.__edge_labels = label_names['edge_labels']
+        self.__edge_attrs = label_names['edge_attrs']
 
 
     def load_graphs(self, graphs, targets=None):
@@ -66,26 +69,26 @@ class Dataset(object):
         current_path = os.path.dirname(os.path.realpath(__file__)) + '/'
         if ds_name == 'Letter-high': # node non-symb
             ds_file = current_path + '../../datasets/Letter-high/Letter-high_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'Letter-med': # node non-symb
             ds_file = current_path + '../../datasets/Letter-high/Letter-med_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'Letter-low': # node non-symb
             ds_file = current_path + '../../datasets/Letter-high/Letter-low_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'Fingerprint':
             ds_file = current_path + '../../datasets/Fingerprint/Fingerprint_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'SYNTHETIC':
             pass
         elif ds_name == 'SYNTHETICnew':
             ds_file = current_path + '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'Synthie':
             pass
         elif ds_name == 'COIL-DEL':
             ds_file = current_path + '../../datasets/COIL-DEL/COIL-DEL_A.txt'
-            self.__graphs, self.__targets = loadDataset(ds_file)
+            self.__graphs, self.__targets, label_names = load_dataset(ds_file)
         elif ds_name == 'COIL-RAG':
             pass
         elif ds_name == 'COLORS-3':
@@ -93,7 +96,10 @@ class Dataset(object):
         elif ds_name == 'FRANKENSTEIN':
             pass
 
-        self.set_labels_attrs()
+        self.__node_labels = label_names['node_labels']
+        self.__node_attrs = label_names['node_attrs']
+        self.__edge_labels = label_names['edge_labels']
+        self.__edge_attrs = label_names['edge_attrs']
 
 
     def set_labels(self, node_labels=[], node_attrs=[], edge_labels=[], edge_attrs=[]):
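With these changes, `Dataset` no longer infers label names by scanning the graphs (the removed `set_labels_attrs` calls); it takes them from the third value returned by the new loader. A hedged usage sketch of the new API; the dataset path is hypothetical, and any TUD-format `*_A.txt` file works the same way:

    from gklearn.utils import Dataset, load_dataset

    graphs, targets, label_names = load_dataset('../../datasets/MUTAG/MUTAG_A.txt')
    print(label_names['node_labels'], label_names['edge_labels'])

    dataset = Dataset()
    dataset.load_predefined_dataset('Letter-high')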
diff --git a/gklearn/utils/graph_files.py b/gklearn/utils/graph_files.py
new file mode 100644
index 0000000..c00149e
--- /dev/null
+++ b/gklearn/utils/graph_files.py
@@ -0,0 +1,770 @@
+"""Utility functions to manage graph files.
+"""
+from os.path import dirname, splitext
+
+
+def load_dataset(filename, filename_targets=None, gformat=None, **kwargs):
+    """Read graph data from filename and load them as NetworkX graphs.
+
+    Parameters
+    ----------
+    filename : string
+        The name of the file from which the dataset is read.
+    filename_targets : string
+        The name of the file containing the targets corresponding to the graphs.
+    **kwargs :
+        Extra parameters; currently only used by the '.mat' format (see loadMAT).
+
+    Returns
+    -------
+    data : list of NetworkX graphs.
+    y : list
+        Targets corresponding to graphs.
+    label_names : dict
+        Names of the node/edge labels and attributes present in the dataset.
+
+    Notes
+    -----
+    This function supports the following graph dataset formats:
+
+    'ds': load data from a .ds file. See the comments of function loadFromDS for an example.
+
+    'cxl': load data from a Graph eXchange Language (.cxl) file. See
+        `here <http://www.gupro.de/GXL/Introduction/background.html>`__ for details.
+
+    'sdf': load data from a structured data file (.sdf file). See
+        `here <http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx>`__
+        for details.
+
+    'mat': load graph data from a MATLAB (up to version 7.1) .mat file. See
+        the README in the `downloadable file <http://mlcb.is.tuebingen.mpg.de/Mitarbeiter/Nino/WL/>`__
+        for details.
+
+    'txt': load graph data from a special .txt file. See
+        `here <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__
+        for details. Note that here filename is the name of the .txt file in
+        the dataset directory.
+    """
+    # label_names is only filled in by the TUD ('txt') loader; give the other
+    # formats an empty but well-formed dict so the return below never raises.
+    label_names = {'node_labels': [], 'node_attrs': [],
+                   'edge_labels': [], 'edge_attrs': []}
+    extension = splitext(filename)[1][1:]
+    if extension == "ds":
+        data, y = loadFromDS(filename, filename_targets)
+    elif extension == "cxl":
+        import xml.etree.ElementTree as ET
+
+        dirname_dataset = dirname(filename)
+        tree = ET.parse(filename)
+        root = tree.getroot()
+        data = []
+        y = []
+        for graph in root.iter('graph'):
+            mol_filename = graph.attrib['file']
+            mol_class = graph.attrib['class']
+            data.append(loadGXL(dirname_dataset + '/' + mol_filename))
+            y.append(mol_class)
+    elif extension == 'xml':
+        dir_dataset = kwargs.get('dirname_dataset', None)
+        data, y = loadFromXML(filename, dir_dataset)
+    elif extension == "sdf":
+        from tqdm import tqdm
+        import sys
+
+        data = loadSDF(filename)
+
+        y_raw = open(filename_targets).read().splitlines()
+        y_raw.pop(0)
+        tmp0 = []
+        tmp1 = []
+        for i in range(0, len(y_raw)):
+            tmp = y_raw[i].split(',')
+            tmp0.append(tmp[0])
+            tmp1.append(tmp[1].strip())
+
+        y = []
+        for i in tqdm(range(0, len(data)), desc='adjust data', file=sys.stdout):
+            try:
+                y.append(tmp1[tmp0.index(data[i].name)].strip())
+            except ValueError:  # if data[i].name not in tmp0
+                data[i] = []
+        data = list(filter(lambda a: a != [], data))
+    elif extension == "mat":
+        order = kwargs.get('order')
+        data, y = loadMAT(filename, order)
+    elif extension == 'txt':
+        data, y, label_names = load_tud(filename)
+
+    return data, y, label_names
+
+
+def save_dataset(Gn, y, gformat='gxl', group=None, filename='gfile', xparams=None):
+    """Save a list of graphs.
+    """
+    import os
+    dirname_ds = os.path.dirname(filename)
+    if dirname_ds != '':
+        dirname_ds += '/'
+        if not os.path.exists(dirname_ds):
+            os.makedirs(dirname_ds)
+
+    if xparams is not None and 'graph_dir' in xparams:
+        graph_dir = xparams['graph_dir'] + '/'
+        if not os.path.exists(graph_dir):
+            os.makedirs(graph_dir)
+    else:
+        graph_dir = dirname_ds
+
+    if group == 'xml' and gformat == 'gxl':
+        kwargs = {'method': xparams['method']} if xparams is not None else {}
+        with open(filename + '.xml', 'w') as fgroup:
+            fgroup.write("<?xml version=\"1.0\"?>")
+            fgroup.write("\n<!DOCTYPE GraphCollection SYSTEM \"http://www.inf.unibz.it/~blumenthal/dtd/GraphCollection.dtd\">")
+            fgroup.write("\n<GraphCollection>")
+            for idx, g in enumerate(Gn):
+                fname_tmp = "graph" + str(idx) + ".gxl"
+                saveGXL(g, graph_dir + fname_tmp, **kwargs)
+                fgroup.write("\n\t<graph file=\"" + fname_tmp + "\" class=\"" + str(y[idx]) + "\"/>")
+            fgroup.write("\n</GraphCollection>")
+            fgroup.close()
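load_dataset dispatches purely on the file extension, and save_dataset with group='xml' writes one .gxl file per graph plus a GraphCollection .xml index. A hedged round-trip sketch; the input path is taken from the demo block at the bottom of this file and the output path is hypothetical:

    from gklearn.utils.graph_files import load_dataset, save_dataset

    graphs, y, label_names = load_dataset('../../datasets/MAO/dataset.ds')
    save_dataset(graphs, y, gformat='gxl', group='xml',
                 filename='outputs/mao', xparams={'method': 'default'})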
+ """ + import networkx as nx + from os.path import basename + g = nx.Graph() + with open(filename) as f: + content = f.read().splitlines() + g = nx.Graph( + name = str(content[0]), + filename = basename(filename)) # set name of the graph + tmp = content[1].split(" ") + if tmp[0] == '': + nb_nodes = int(tmp[1]) # number of the nodes + nb_edges = int(tmp[2]) # number of the edges + else: + nb_nodes = int(tmp[0]) + nb_edges = int(tmp[1]) + # patch for compatibility : label will be removed later + for i in range(0, nb_nodes): + tmp = content[i + 2].split(" ") + tmp = [x for x in tmp if x != ''] + g.add_node(i, atom=tmp[3].strip(), + label=[item.strip() for item in tmp[3:]], + attributes=[item.strip() for item in tmp[0:3]]) + for i in range(0, nb_edges): + tmp = content[i + g.number_of_nodes() + 2].split(" ") + tmp = [x for x in tmp if x != ''] + g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1, + bond_type=tmp[2].strip(), + label=[item.strip() for item in tmp[2:]]) + return g + + +def loadGXL(filename): + from os.path import basename + import networkx as nx + import xml.etree.ElementTree as ET + + tree = ET.parse(filename) + root = tree.getroot() + index = 0 + g = nx.Graph(filename=basename(filename), name=root[0].attrib['id']) + dic = {} # used to retrieve incident nodes of edges + for node in root.iter('node'): + dic[node.attrib['id']] = index + labels = {} + for attr in node.iter('attr'): + labels[attr.attrib['name']] = attr[0].text + if 'chem' in labels: + labels['label'] = labels['chem'] + labels['atom'] = labels['chem'] + g.add_node(index, **labels) + index += 1 + + for edge in root.iter('edge'): + labels = {} + for attr in edge.iter('attr'): + labels[attr.attrib['name']] = attr[0].text + if 'valence' in labels: + labels['label'] = labels['valence'] + labels['bond_type'] = labels['valence'] + g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels) + return g + + +def saveGXL(graph, filename, method='default', node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]): + if method == 'default': + gxl_file = open(filename, 'w') + gxl_file.write("\n") + gxl_file.write("\n") + gxl_file.write("\n") + if 'name' in graph.graph: + name = str(graph.graph['name']) + else: + name = 'dummy' + gxl_file.write("\n") + for v, attrs in graph.nodes(data=True): + gxl_file.write("") + for l_name in node_labels: + gxl_file.write("" + + str(attrs[l_name]) + "") + for a_name in node_attrs: + gxl_file.write("" + + str(attrs[a_name]) + "") + gxl_file.write("\n") + for v1, v2, attrs in graph.edges(data=True): + gxl_file.write("") + for l_name in edge_labels: + gxl_file.write("" + + str(attrs[l_name]) + "") + for a_name in edge_attrs: + gxl_file.write("" + + str(attrs[a_name]) + "") + gxl_file.write("\n") + gxl_file.write("\n") + gxl_file.write("") + gxl_file.close() + elif method == 'benoit': + import xml.etree.ElementTree as ET + root_node = ET.Element('gxl') + attr = dict() + attr['id'] = str(graph.graph['name']) + attr['edgeids'] = 'true' + attr['edgemode'] = 'undirected' + graph_node = ET.SubElement(root_node, 'graph', attrib=attr) + + for v in graph: + current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)}) + for attr in graph.nodes[v].keys(): + cur_attr = ET.SubElement( + current_node, 'attr', attrib={'name': attr}) + cur_value = ET.SubElement(cur_attr, + graph.nodes[v][attr].__class__.__name__) + cur_value.text = graph.nodes[v][attr] + + for v1 in graph: + for v2 in graph[v1]: + if (v1 < v2): # Non oriented graphs + cur_edge = ET.SubElement( + graph_node, + 'edge', + attrib={ 
+def saveGXL(graph, filename, method='default', node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
+    # Note: the XML payloads written below follow the standard GXL layout
+    # (gupro GXL 1.0 DTD); see the gedlib reference linked in the branches.
+    if method == 'default':
+        gxl_file = open(filename, 'w')
+        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
+        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
+        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
+        if 'name' in graph.graph:
+            name = str(graph.graph['name'])
+        else:
+            name = 'dummy'
+        gxl_file.write("<graph id=\"" + name + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
+        for v, attrs in graph.nodes(data=True):
+            gxl_file.write("<node id=\"_" + str(v) + "\">")
+            for l_name in node_labels:
+                gxl_file.write("<attr name=\"" + l_name + "\"><int>"
+                               + str(attrs[l_name]) + "</int></attr>")
+            for a_name in node_attrs:
+                gxl_file.write("<attr name=\"" + a_name + "\"><float>"
+                               + str(attrs[a_name]) + "</float></attr>")
+            gxl_file.write("</node>\n")
+        for v1, v2, attrs in graph.edges(data=True):
+            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
+            for l_name in edge_labels:
+                gxl_file.write("<attr name=\"" + l_name + "\"><int>"
+                               + str(attrs[l_name]) + "</int></attr>")
+            for a_name in edge_attrs:
+                gxl_file.write("<attr name=\"" + a_name + "\"><float>"
+                               + str(attrs[a_name]) + "</float></attr>")
+            gxl_file.write("</edge>\n")
+        gxl_file.write("</graph>\n")
+        gxl_file.write("</gxl>")
+        gxl_file.close()
+    elif method == 'benoit':
+        import xml.etree.ElementTree as ET
+        root_node = ET.Element('gxl')
+        attr = dict()
+        attr['id'] = str(graph.graph['name'])
+        attr['edgeids'] = 'true'
+        attr['edgemode'] = 'undirected'
+        graph_node = ET.SubElement(root_node, 'graph', attrib=attr)
+
+        for v in graph:
+            current_node = ET.SubElement(graph_node, 'node', attrib={'id': str(v)})
+            for attr in graph.nodes[v].keys():
+                cur_attr = ET.SubElement(
+                    current_node, 'attr', attrib={'name': attr})
+                cur_value = ET.SubElement(cur_attr,
+                                          graph.nodes[v][attr].__class__.__name__)
+                cur_value.text = graph.nodes[v][attr]
+
+        for v1 in graph:
+            for v2 in graph[v1]:
+                if v1 < v2:  # non-oriented graphs
+                    cur_edge = ET.SubElement(
+                        graph_node,
+                        'edge',
+                        attrib={
+                            'from': str(v1),
+                            'to': str(v2)
+                        })
+                    for attr in graph[v1][v2].keys():
+                        cur_attr = ET.SubElement(
+                            cur_edge, 'attr', attrib={'name': attr})
+                        cur_value = ET.SubElement(
+                            cur_attr, graph[v1][v2][attr].__class__.__name__)
+                        cur_value.text = str(graph[v1][v2][attr])
+
+        tree = ET.ElementTree(root_node)
+        tree.write(filename)
+    elif method == 'gedlib':
+        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
+#        pass
+        gxl_file = open(filename, 'w')
+        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
+        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
+        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
+        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"true\" edgemode=\"undirected\">\n")
+        for v, attrs in graph.nodes(data=True):
+            gxl_file.write("<node id=\"_" + str(v) + "\">")
+            gxl_file.write("<attr name=\"chem\"><int>" + str(attrs['chem']) + "</int></attr>")
+            gxl_file.write("</node>\n")
+        for v1, v2, attrs in graph.edges(data=True):
+            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\">")
+            gxl_file.write("<attr name=\"valence\"><int>" + str(attrs['valence']) + "</int></attr>")
+#            gxl_file.write("<attr name=\"valence\"><int>" + "1" + "</int></attr>")
+            gxl_file.write("</edge>\n")
+        gxl_file.write("</graph>\n")
+        gxl_file.write("</gxl>")
+        gxl_file.close()
+    elif method == 'gedlib-letter':
+        # reference: https://github.com/dbblumenthal/gedlib/blob/master/data/generate_molecules.py#L22
+        # and https://github.com/dbblumenthal/gedlib/blob/master/data/datasets/Letter/HIGH/AP1_0000.gxl
+        gxl_file = open(filename, 'w')
+        gxl_file.write("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
+        gxl_file.write("<!DOCTYPE gxl SYSTEM \"http://www.gupro.de/GXL/gxl-1.0.dtd\">\n")
+        gxl_file.write("<gxl xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n")
+        gxl_file.write("<graph id=\"" + str(graph.graph['name']) + "\" edgeids=\"false\" edgemode=\"undirected\">\n")
+        for v, attrs in graph.nodes(data=True):
+            gxl_file.write("<node id=\"_" + str(v) + "\">")
+            gxl_file.write("<attr name=\"x\"><float>" + str(attrs['attributes'][0]) + "</float></attr>")
+            gxl_file.write("<attr name=\"y\"><float>" + str(attrs['attributes'][1]) + "</float></attr>")
+            gxl_file.write("</node>\n")
+        for v1, v2, attrs in graph.edges(data=True):
+            gxl_file.write("<edge from=\"_" + str(v1) + "\" to=\"_" + str(v2) + "\"/>\n")
+        gxl_file.write("</graph>\n")
+        gxl_file.write("</gxl>")
+        gxl_file.close()
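A hedged GXL round-trip through the gedlib-style writer and loadGXL; it assumes the chem/valence labels that that method hard-codes:

    import networkx as nx
    from gklearn.utils.graph_files import saveGXL, loadGXL

    g = nx.Graph(name='demo')
    g.add_node(0, chem='C')
    g.add_node(1, chem='O')
    g.add_edge(0, 1, valence='1')

    saveGXL(g, '/tmp/demo.gxl', method='gedlib')
    h = loadGXL('/tmp/demo.gxl')
    print(h.nodes(data=True))  # chem labels come back (plus label/atom aliases)
    print(h.edges(data=True))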
+ """ + from scipy.io import loadmat + import numpy as np + import networkx as nx + data = [] + content = loadmat(filename) + # print(content) + # print('----') + for key, value in content.items(): + if key[0] == 'l': # class label + y = np.transpose(value)[0].tolist() + # print(y) + elif key[0] != '_': + # print(value[0][0][0]) + # print() + # print(value[0][0][1]) + # print() + # print(value[0][0][2]) + # print() + # if len(value[0][0]) > 3: + # print(value[0][0][3]) + # print('----') + # if adjacency matrix is not compressed / edge label exists + if order[1] == 0: + for i, item in enumerate(value[0]): + # print(item) + # print('------') + g = nx.Graph(name=i) # set name of the graph + nl = np.transpose(item[order[3]][0][0][0]) # node label + # print(item[order[3]]) + # print() + for index, label in enumerate(nl[0]): + g.add_node(index, atom=str(label)) + el = item[order[4]][0][0][0] # edge label + for edge in el: + g.add_edge( + edge[0] - 1, edge[1] - 1, bond_type=str(edge[2])) + data.append(g) + else: + from scipy.sparse import csc_matrix + for i, item in enumerate(value[0]): + # print(item) + # print('------') + g = nx.Graph(name=i) # set name of the graph + nl = np.transpose(item[order[3]][0][0][0]) # node label + # print(nl) + # print() + for index, label in enumerate(nl[0]): + g.add_node(index, atom=str(label)) + sam = item[order[0]] # sparse adjacency matrix + index_no0 = sam.nonzero() + for col, row in zip(index_no0[0], index_no0[1]): + # print(col) + # print(row) + g.add_edge(col, row) + data.append(g) + # print(g.edges(data=True)) + return data, y + + +def load_tud(filename): + """Load graph data from TUD dataset files. + + Notes + ------ + The graph data is loaded from separate files. + Check README in `downloadable file `__, 2018 for detailed structure. + """ + import networkx as nx + from os import listdir + from os.path import dirname, basename + + + def get_infos_from_readme(frm): # @todo: add README (cuniform), maybe node/edge label maps. + """Get information from DS_label_readme.txt file. + """ + + def get_label_names_from_line(line): + """Get names of labels/attributes from a line. + """ + str_names = line.split('[')[1].split(']')[0] + names = str_names.split(',') + names = [attr.strip() for attr in names] + return names + + + def get_class_label_map(label_map_strings): + label_map = {} + for string in label_map_strings: + integer, label = string.split('\t') + label_map[int(integer.strip())] = label.strip() + return label_map + + + label_names = {'node_labels': [], 'node_attrs': [], + 'edge_labels': [], 'edge_attrs': []} + class_label_map_strings = [] + content_rm = open(frm).read().splitlines() + i = 0 + while i < len(content_rm): + line = content_rm[i].strip() + # get node/edge labels and attributes. + if line.startswith('Node labels:'): + label_names['node_labels'] = get_label_names_from_line(line) + elif line.startswith('Node attributes:'): + label_names['node_attrs'] = get_label_names_from_line(line) + elif line.startswith('Edge labels:'): + label_names['edge_labels'] = get_label_names_from_line(line) + elif line.startswith('Edge attributes:'): + label_names['edge_attrs'] = get_label_names_from_line(line) + # get class label map. 
+def load_tud(filename):
+    """Load graph data from TUD dataset files.
+
+    Notes
+    -----
+    The graph data is loaded from separate files.
+    Check the README in the `downloadable file <https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets>`__, 2018 for detailed structure.
+    """
+    import networkx as nx
+    from os import listdir
+    from os.path import dirname, basename
+
+
+    def get_infos_from_readme(frm):  # @todo: add README (cuniform), maybe node/edge label maps.
+        """Get information from the DS_label_readme.txt file.
+        """
+
+        def get_label_names_from_line(line):
+            """Get names of labels/attributes from a line.
+            """
+            str_names = line.split('[')[1].split(']')[0]
+            names = str_names.split(',')
+            names = [attr.strip() for attr in names]
+            return names
+
+
+        def get_class_label_map(label_map_strings):
+            label_map = {}
+            for string in label_map_strings:
+                integer, label = string.split('\t')
+                label_map[int(integer.strip())] = label.strip()
+            return label_map
+
+
+        label_names = {'node_labels': [], 'node_attrs': [],
+                       'edge_labels': [], 'edge_attrs': []}
+        class_label_map = None  # stays None if the README has no map section.
+        class_label_map_strings = []
+        content_rm = open(frm).read().splitlines()
+        i = 0
+        while i < len(content_rm):
+            line = content_rm[i].strip()
+            # get node/edge labels and attributes.
+            if line.startswith('Node labels:'):
+                label_names['node_labels'] = get_label_names_from_line(line)
+            elif line.startswith('Node attributes:'):
+                label_names['node_attrs'] = get_label_names_from_line(line)
+            elif line.startswith('Edge labels:'):
+                label_names['edge_labels'] = get_label_names_from_line(line)
+            elif line.startswith('Edge attributes:'):
+                label_names['edge_attrs'] = get_label_names_from_line(line)
+            # get class label map.
+            elif line.startswith('Class labels were converted to integer values using this map:'):
+                i += 2
+                # collect map lines until the first blank line (or end of file).
+                while i < len(content_rm) and content_rm[i].strip() != '':
+                    class_label_map_strings.append(content_rm[i].strip())
+                    i += 1
+                class_label_map = get_class_label_map(class_label_map_strings)
+            i += 1
+
+        return label_names, class_label_map
+
+
+    # get dataset name.
+    dirname_dataset = dirname(filename)
+    filename = basename(filename)
+    fn_split = filename.split('_A')
+    ds_name = fn_split[0].strip()
+
+    # load data file names.
+    for name in listdir(dirname_dataset):
+        if ds_name + '_A' in name:
+            fam = dirname_dataset + '/' + name
+        elif ds_name + '_graph_indicator' in name:
+            fgi = dirname_dataset + '/' + name
+        elif ds_name + '_graph_labels' in name:
+            fgl = dirname_dataset + '/' + name
+        elif ds_name + '_node_labels' in name:
+            fnl = dirname_dataset + '/' + name
+        elif ds_name + '_edge_labels' in name:
+            fel = dirname_dataset + '/' + name
+        elif ds_name + '_edge_attributes' in name:
+            fea = dirname_dataset + '/' + name
+        elif ds_name + '_node_attributes' in name:
+            fna = dirname_dataset + '/' + name
+        elif ds_name + '_graph_attributes' in name:
+            fga = dirname_dataset + '/' + name
+        elif ds_name + '_label_readme' in name:
+            frm = dirname_dataset + '/' + name
+        # this is supposed to be the node attrs; make sure to keep it as the last 'elif'.
+        elif ds_name + '_attributes' in name:
+            fna = dirname_dataset + '/' + name
+
+    # get labels and attributes names.
+    if 'frm' in locals():
+        label_names, class_label_map = get_infos_from_readme(frm)
+    else:
+        label_names = {'node_labels': [], 'node_attrs': [],
+                       'edge_labels': [], 'edge_attrs': []}
+
+    content_gi = open(fgi).read().splitlines()  # graph indicator
+    content_am = open(fam).read().splitlines()  # adjacency matrix
+    content_gl = open(fgl).read().splitlines()  # graph labels
+
+    # create graphs and add nodes.
+    data = [nx.Graph(name=str(i)) for i in range(0, len(content_gl))]
+    if 'fnl' in locals():
+        content_nl = open(fnl).read().splitlines()  # node labels
+        for idx, line in enumerate(content_gi):
+            # transfer to int first in case of unexpected blanks.
+            data[int(line) - 1].add_node(idx)
+            labels = [l.strip() for l in content_nl[idx].split(',')]
+            if label_names['node_labels'] == []:
+                for i, label in enumerate(labels):
+                    l_name = 'label_' + str(i)
+                    data[int(line) - 1].nodes[idx][l_name] = label
+                    label_names['node_labels'].append(l_name)
+            else:
+                for i, l_name in enumerate(label_names['node_labels']):
+                    data[int(line) - 1].nodes[idx][l_name] = labels[i]
+    else:
+        for i, line in enumerate(content_gi):
+            data[int(line) - 1].add_node(i)
+
+    # add edges.
+    for line in content_am:
+        tmp = line.split(',')
+        n1 = int(tmp[0]) - 1
+        n2 = int(tmp[1]) - 1
+        # ignore edge weight here.
+        g = int(content_gi[n1]) - 1
+        data[g].add_edge(n1, n2)
+
+    # add edge labels.
+    if 'fel' in locals():
+        content_el = open(fel).read().splitlines()
+        for idx, line in enumerate(content_el):
+            labels = [l.strip() for l in line.split(',')]
+            n = [int(i) - 1 for i in content_am[idx].split(',')]
+            g = int(content_gi[n[0]]) - 1
+            if label_names['edge_labels'] == []:
+                for i, label in enumerate(labels):
+                    l_name = 'label_' + str(i)
+                    data[g].edges[n[0], n[1]][l_name] = label
+                    label_names['edge_labels'].append(l_name)
+            else:
+                for i, l_name in enumerate(label_names['edge_labels']):
+                    data[g].edges[n[0], n[1]][l_name] = labels[i]
+
+    # add node attributes.
+    if 'fna' in locals():
+        content_na = open(fna).read().splitlines()
+        for idx, line in enumerate(content_na):
+            attrs = [a.strip() for a in line.split(',')]
+            g = int(content_gi[idx]) - 1
+            if label_names['node_attrs'] == []:
+                for i, attr in enumerate(attrs):
+                    a_name = 'attr_' + str(i)
+                    data[g].nodes[idx][a_name] = attr
+                    label_names['node_attrs'].append(a_name)
+            else:
+                for i, a_name in enumerate(label_names['node_attrs']):
+                    data[g].nodes[idx][a_name] = attrs[i]
+
+    # add edge attributes.
+    if 'fea' in locals():
+        content_ea = open(fea).read().splitlines()
+        for idx, line in enumerate(content_ea):
+            attrs = [a.strip() for a in line.split(',')]
+            n = [int(i) - 1 for i in content_am[idx].split(',')]
+            g = int(content_gi[n[0]]) - 1
+            if label_names['edge_attrs'] == []:
+                for i, attr in enumerate(attrs):
+                    a_name = 'attr_' + str(i)
+                    data[g].edges[n[0], n[1]][a_name] = attr
+                    label_names['edge_attrs'].append(a_name)
+            else:
+                for i, a_name in enumerate(label_names['edge_attrs']):
+                    data[g].edges[n[0], n[1]][a_name] = attrs[i]
+
+    # load targets.
+    targets = [int(i) for i in content_gl]
+    if 'class_label_map' in locals() and class_label_map is not None:
+        targets = [class_label_map[t] for t in targets]
+
+    return data, targets, label_names
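All TUD index files are 1-based while the NetworkX graphs built above use 0-based node ids, and graph_indicator maps each node to its owning graph. A hedged sketch of that arithmetic in isolation:

    # One row of DS_A.txt ('n1, n2') plus a tiny graph_indicator column.
    content_gi = ['1', '1', '2']   # node i + 1 belongs to graph content_gi[i]
    line = '1, 2'
    n1, n2 = [int(t) - 1 for t in line.split(',')]
    g = int(content_gi[n1]) - 1    # graph owning the edge
    print(n1, n2, g)               # -> 0 1 0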
+def loadFromXML(filename, dir_dataset=None):
+    import xml.etree.ElementTree as ET
+
+    if dir_dataset is None:
+        dir_dataset = dirname(filename)
+    tree = ET.parse(filename)
+    root = tree.getroot()
+    data = []
+    y = []
+    for graph in root.iter('graph'):
+        mol_filename = graph.attrib['file']
+        mol_class = graph.attrib['class']
+        data.append(loadGXL(dir_dataset + '/' + mol_filename))
+        y.append(mol_class)
+
+    return data, y
+ """ + dirname_dataset = dirname(filename) + data = [] + y = [] + content = open(filename).read().splitlines() + extension = splitext(content[0].split(' ')[0])[1][1:] + if filename_y is None or filename_y == '': + if extension == 'ct': + for i in range(0, len(content)): + tmp = content[i].split(' ') + # remove the '#'s in file names + data.append( + loadCT(dirname_dataset + '/' + tmp[0].replace('#', '', 1))) + y.append(float(tmp[1])) + elif extension == 'gxl': + for i in range(0, len(content)): + tmp = content[i].split(' ') + # remove the '#'s in file names + data.append( + loadGXL(dirname_dataset + '/' + tmp[0].replace('#', '', 1))) + y.append(float(tmp[1])) + else: # y in a seperate file + if extension == 'ct': + for i in range(0, len(content)): + tmp = content[i] + # remove the '#'s in file names + data.append( + loadCT(dirname_dataset + '/' + tmp.replace('#', '', 1))) + elif extension == 'gxl': + for i in range(0, len(content)): + tmp = content[i] + # remove the '#'s in file names + data.append( + loadGXL(dirname_dataset + '/' + tmp.replace('#', '', 1))) + + content_y = open(filename_y).read().splitlines() + # assume entries in filename and filename_y have the same order. + for item in content_y: + tmp = item.split(' ') + # assume the 3rd entry in a line is y (for Alkane dataset) + y.append(float(tmp[2])) + + return data, y + + +if __name__ == '__main__': +# ### Load dataset from .ds file. +# # .ct files. +# ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds', +# 'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'} +# Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y']) +## ds = {'name': 'Acyclic', 'dataset': '../../datasets/acyclic/dataset_bps.ds'} # node symb +## Gn, y = loadDataset(ds['dataset']) +## ds = {'name': 'MAO', 'dataset': '../../datasets/MAO/dataset.ds'} # node/edge symb +## Gn, y = loadDataset(ds['dataset']) +## ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'} # unlabeled +## Gn, y = loadDataset(ds['dataset']) +# print(Gn[1].nodes(data=True)) +# print(Gn[1].edges(data=True)) +# print(y[1]) + +# # .gxl file. +# ds = {'name': 'monoterpenoides', +# 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb +# Gn, y = loadDataset(ds['dataset']) +# print(Gn[1].nodes(data=True)) +# print(Gn[1].edges(data=True)) +# print(y[1]) + +# ### Convert graph from one format to another. +# # .gxl file. +# import networkx as nx +# ds = {'name': 'monoterpenoides', +# 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb +# Gn, y = loadDataset(ds['dataset']) +# y = [int(i) for i in y] +# print(Gn[1].nodes(data=True)) +# print(Gn[1].edges(data=True)) +# print(y[1]) +# # Convert a graph to the proper NetworkX format that can be recognized by library gedlib. +# Gn_new = [] +# for G in Gn: +# G_new = nx.Graph() +# for nd, attrs in G.nodes(data=True): +# G_new.add_node(str(nd), chem=attrs['atom']) +# for nd1, nd2, attrs in G.edges(data=True): +# G_new.add_edge(str(nd1), str(nd2), valence=attrs['bond_type']) +## G_new.add_edge(str(nd1), str(nd2)) +# Gn_new.append(G_new) +# print(Gn_new[1].nodes(data=True)) +# print(Gn_new[1].edges(data=True)) +# print(Gn_new[1]) +# filename = '/media/ljia/DATA/research-repo/codes/others/gedlib/tests_linlin/generated_datsets/monoterpenoides/gxl/monoterpenoides' +# xparams = {'method': 'gedlib'} +# saveDataset(Gn, y, gformat='gxl', group='xml', filename=filename, xparams=xparams) + + # save dataset. 
+#    ds = {'name': 'MUTAG', 'dataset': '../../datasets/MUTAG/MUTAG.mat',
+#          'extra_params': {'am_sp_al_nl_el': [0, 0, 3, 1, 2]}}  # node/edge symb
+#    Gn, y = loadDataset(ds['dataset'], extra_params=ds['extra_params'])
+#    saveDataset(Gn, y, group='xml', filename='temp/temp')
+
+    # test - new way to add labels and attributes.
+#    dataset = '../../datasets/SYNTHETICnew/SYNTHETICnew_A.txt'
+#    filename = '../../datasets/Fingerprint/Fingerprint_A.txt'
+#    dataset = '../../datasets/Letter-med/Letter-med_A.txt'
+#    dataset = '../../datasets/AIDS/AIDS_A.txt'
+#    dataset = '../../datasets/ENZYMES_txt/ENZYMES_A_sparse.txt'
+#    Gn, targets = load_dataset(filename)
+    pass