diff --git a/gklearn/utils/graph_files.py b/gklearn/utils/graph_files.py index 5206ee4..fafecf6 100644 --- a/gklearn/utils/graph_files.py +++ b/gklearn/utils/graph_files.py @@ -127,7 +127,7 @@ def save_dataset(Gn, y, gformat='gxl', group=None, filename='gfile', xparams=Non fgroup.close() -def load_ct(filename): +def load_ct(filename): # @todo: this function is only tested on CTFile V2000; header not considered; only simple cases (atoms and bonds) are considered. """load data from a Chemical Table (.ct) file. Notes @@ -154,30 +154,65 @@ def load_ct(filename): g = nx.Graph() with open(filename) as f: content = f.read().splitlines() - g = nx.Graph( - name = str(content[0]), - filename = basename(filename)) # set name of the graph - tmp = content[1].split(" ") - if tmp[0] == '': - nb_nodes = int(tmp[1]) # number of the nodes - nb_edges = int(tmp[2]) # number of the edges - else: - nb_nodes = int(tmp[0]) - nb_edges = int(tmp[1]) - # patch for compatibility : label will be removed later - for i in range(0, nb_nodes): - tmp = content[i + 2].split(" ") + g = nx.Graph(name=str(content[0]), filename=basename(filename)) # set name of the graph + + # read the counts line. + tmp = content[1].split(' ') + tmp = [x for x in tmp if x != ''] + nb_atoms = int(tmp[0].strip()) # number of atoms + nb_bonds = int(tmp[1].strip()) # number of bonds + count_line_tags = ['number_of_atoms', 'number_of_bonds', 'number_of_atom_lists', '', 'chiral_flag', 'number_of_stext_entries', '', '', '', '', 'number_of_properties', 'CT_version'] + i = 0 + while i < len(tmp): + if count_line_tags[i] != '': # if not obsoleted + g.graph[count_line_tags[i]] = tmp[i].strip() + i += 1 + + # read the atom block. 
+ atom_tags = ['x', 'y', 'z', 'atom_symbol', 'mass_difference', 'charge', 'atom_stereo_parity', 'hydrogen_count_plus_1', 'stereo_care_box', 'valence', 'h0_designator', '', '', 'atom_atom_mapping_number', 'inversion_retention_flag', 'exact_change_flag'] + for i in range(0, nb_atoms): + tmp = content[i + 2].split(' ') tmp = [x for x in tmp if x != ''] - g.add_node(i, atom=tmp[3].strip(), - label=[item.strip() for item in tmp[3:]], - attributes=[item.strip() for item in tmp[0:3]]) - for i in range(0, nb_edges): - tmp = content[i + g.number_of_nodes() + 2].split(" ") + g.add_node(i) + j = 0 + while j < len(tmp): + if atom_tags[j] != '': + g.nodes[i][atom_tags[j]] = tmp[j].strip() + j += 1 + + # read the bond block. + bond_tags = ['first_atom_number', 'second_atom_number', 'bond_type', 'bond_stereo', '', 'bond_topology', 'reacting_center_status'] + for i in range(0, nb_bonds): + tmp = content[i + g.number_of_nodes() + 2].split(' ') tmp = [x for x in tmp if x != ''] - g.add_edge(int(tmp[0]) - 1, int(tmp[1]) - 1, - bond_type=tmp[2].strip(), - label=[item.strip() for item in tmp[2:]]) - return g + n1, n2 = int(tmp[0].strip()) - 1, int(tmp[1].strip()) - 1 + g.add_edge(n1, n2) + j = 2 + while j < len(tmp): + if bond_tags[j] != '': + g.edges[(n1, n2)][bond_tags[j]] = tmp[j].strip() + j += 1 + + # get label names. + label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} + atom_symbolic = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1, 1, 1] + for nd in g.nodes(): + for key in g.nodes[nd]: + if atom_symbolic[atom_tags.index(key)] == 1: + label_names['node_labels'].append(key) + else: + label_names['node_attrs'].append(key) + break + bond_symbolic = [None, None, 1, 1, None, 1, 1] + for ed in g.edges(): + for key in g.edges[ed]: + if bond_symbolic[bond_tags.index(key)] == 1: + label_names['edge_labels'].append(key) + else: + label_names['edge_attrs'].append(key) + break + + return g, label_names def load_gxl(filename): # @todo: directed graphs. 
@@ -678,11 +713,7 @@ def load_from_ds(filename, filename_targets): Note these graph formats are checked automatically by the extensions of graph files. - """ - def append_label_names(label_names, new_names): - for key, val in label_names.items(): - label_names[key] += [name for name in new_names[key] if name not in val] - + """ dirname_dataset = dirname(filename) data = [] y = [] @@ -694,8 +725,9 @@ def load_from_ds(filename, filename_targets): for i in range(0, len(content)): tmp = content[i].split(' ') # remove the '#'s in file names - data.append( - load_ct(dirname_dataset + '/' + tmp[0].replace('#', '', 1))) + g, l_names = load_ct(dirname_dataset + '/' + tmp[0].replace('#', '', 1)) + data.append(g) + __append_label_names(label_names, l_names) y.append(float(tmp[1])) elif extension == 'gxl': for i in range(0, len(content)): @@ -703,22 +735,23 @@ def load_from_ds(filename, filename_targets): # remove the '#'s in file names g, l_names = load_gxl(dirname_dataset + '/' + tmp[0].replace('#', '', 1)) data.append(g) - append_label_names(label_names, l_names) + __append_label_names(label_names, l_names) y.append(float(tmp[1])) else: # y in a seperate file if extension == 'ct': for i in range(0, len(content)): tmp = content[i] # remove the '#'s in file names - data.append( - load_ct(dirname_dataset + '/' + tmp.replace('#', '', 1))) + g, l_names = load_ct(dirname_dataset + '/' + tmp.replace('#', '', 1)) + data.append(g) + __append_label_names(label_names, l_names) elif extension == 'gxl': for i in range(0, len(content)): tmp = content[i] # remove the '#'s in file names - g, l_names = load_gxl(dirname_dataset + '/' + tmp[0].replace('#', '', 1)) + g, l_names = load_gxl(dirname_dataset + '/' + tmp.replace('#', '', 1)) data.append(g) - append_label_names(label_names, l_names) + __append_label_names(label_names, l_names) content_y = open(filename_targets).read().splitlines() # assume entries in filename and filename_targets have the same order. 
@@ -728,7 +761,12 @@ def load_from_ds(filename, filename_targets): y.append(float(tmp[2])) return data, y, label_names - + + +def __append_label_names(label_names, new_names): + for key, val in label_names.items(): + label_names[key] += [name for name in new_names[key] if name not in val] + if __name__ == '__main__': # ### Load dataset from .ds file. @@ -736,24 +774,24 @@ if __name__ == '__main__': # ds = {'name': 'Alkane', 'dataset': '../../datasets/Alkane/dataset.ds', # 'dataset_y': '../../datasets/Alkane/dataset_boiling_point_names.txt'} # Gn, y = loadDataset(ds['dataset'], filename_y=ds['dataset_y']) -## ds = {'name': 'Acyclic', 'dataset': '../../datasets/acyclic/dataset_bps.ds'} # node symb -## Gn, y = loadDataset(ds['dataset']) + ds_file = '../../datasets/acyclic/dataset_bps.ds' # node symb + Gn, targets, label_names = load_dataset(ds_file) ## ds = {'name': 'MAO', 'dataset': '../../datasets/MAO/dataset.ds'} # node/edge symb ## Gn, y = loadDataset(ds['dataset']) ## ds = {'name': 'PAH', 'dataset': '../../datasets/PAH/dataset.ds'} # unlabeled ## Gn, y = loadDataset(ds['dataset']) -# print(Gn[1].nodes(data=True)) -# print(Gn[1].edges(data=True)) -# print(y[1]) - - # .gxl file. - ds = {'name': 'monoterpenoides', - 'dataset': '../../datasets/monoterpenoides/dataset_10+.ds'} # node/edge symb - Gn, y, label_names = load_dataset(ds['dataset']) print(Gn[1].graph) print(Gn[1].nodes(data=True)) print(Gn[1].edges(data=True)) - print(y[1]) + print(targets[1]) + +# # .gxl file. +# ds_file = '../../datasets/monoterpenoides/dataset_10+.ds' # node/edge symb +# Gn, y, label_names = load_dataset(ds_file) +# print(Gn[1].graph) +# print(Gn[1].nodes(data=True)) +# print(Gn[1].edges(data=True)) +# print(y[1]) # ### Convert graph from one format to another. # # .gxl file.