diff --git a/gklearn/dataset/dataset.py b/gklearn/dataset/dataset.py index f0f02e8..faca89b 100644 --- a/gklearn/dataset/dataset.py +++ b/gklearn/dataset/dataset.py @@ -117,6 +117,8 @@ class Dataset(object): if 'extra_params' in DATASET_META[ds_name]: kwargs = DATASET_META[ds_name]['extra_params'] + else: + kwargs = {} self._graphs, self._targets, label_names = DataLoader(ds_file, filename_targets=fn_targets, **kwargs).data diff --git a/gklearn/dataset/file_managers.py b/gklearn/dataset/file_managers.py index 3a6e831..9a804f5 100644 --- a/gklearn/dataset/file_managers.py +++ b/gklearn/dataset/file_managers.py @@ -131,22 +131,37 @@ class DataLoader(): def load_from_xml(self, filename, dir_dataset=None): import xml.etree.ElementTree as ET - if dir_dataset is not None: - dir_dataset = dir_dataset - else: - dir_dataset = dirname(filename) - tree = ET.parse(filename) - root = tree.getroot() + def load_one_file(filename, data, y, label_names): + tree = ET.parse(filename) + root = tree.getroot() + for graph in root.iter('graph') if root.find('graph') is not None else root.iter('print'): # "graph" for ... I forgot; "print" for datasets GREC and Web. + mol_filename = graph.attrib['file'] + mol_class = graph.attrib['class'] + g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename) + data.append(g) + self._append_label_names(label_names, l_names) + y.append(mol_class) + data = [] y = [] label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} - for graph in root.iter('graph'): - mol_filename = graph.attrib['file'] - mol_class = graph.attrib['class'] - g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename) - data.append(g) - self._append_label_names(label_names, l_names) - y.append(mol_class) + + if isinstance(filename, str): + if dir_dataset is not None: + dir_dataset = dir_dataset + else: + dir_dataset = dirname(filename) + load_one_file(filename, data, y, label_names) + + + else: # filename is a list of files. + if dir_dataset is not None: + dir_dataset = dir_dataset + else: + dir_dataset = dirname(filename[0]) + + for fn in filename: + load_one_file(fn, data, y, label_names) return data, y, label_names @@ -505,32 +520,45 @@ class DataLoader(): for node in root.iter('node'): dic[node.attrib['id']] = index labels = {} - for attr in node.iter('attr'): + for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens". labels[attr.attrib['name']] = attr[0].text + for attr in node.iter('attribute'): # for dataset "Web". + labels[attr.attrib['name']] = attr.attrib['value'] g.add_node(index, **labels) index += 1 for edge in root.iter('edge'): labels = {} - for attr in edge.iter('attr'): + for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens". labels[attr.attrib['name']] = attr[0].text + for attr in edge.iter('attribute'): # for dataset "Web". + labels[attr.attrib['name']] = attr.attrib['value'] g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels) # get label names. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} + # @todo: possible loss of label names if some nodes miss some labels. for node in root.iter('node'): - for attr in node.iter('attr'): - if attr[0].tag == 'int': # @todo: this maybe wrong, and slow. + for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens". + if attr[0].tag == 'int' or attr.attrib['name'] == 'type': # @todo: this maybe wrong, and slow. "type" is for dataset GREC; "int" is for dataset "Monoterpens". label_names['node_labels'].append(attr.attrib['name']) else: label_names['node_attrs'].append(attr.attrib['name']) + + for attr in node.iter('attribute'): # for dataset "Web". + label_names['node_attrs'].append(attr.attrib['name']) + # @todo: is id useful in dataset "Web"? is "FREQUENCY" symbolic or not? break + for edge in root.iter('edge'): - for attr in edge.iter('attr'): - if attr[0].tag == 'int': # @todo: this maybe wrong, and slow. + for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens". + if attr[0].tag == 'int' or attr.attrib['name'] == 'frequency' or 'type' in attr.attrib['name']: # @todo: this maybe wrong, and slow. "frequency" and "type" are for dataset GREC; "int" is for dataset "Monoterpens". label_names['edge_labels'].append(attr.attrib['name']) else: label_names['edge_attrs'].append(attr.attrib['name']) + + for attr in edge.iter('attribute'): # for dataset "Web". + label_names['edge_attrs'].append(attr.attrib['name']) break return g, label_names @@ -546,7 +574,10 @@ class DataLoader(): tree = ET.parse(filename) root = tree.getroot() index = 0 - g_id = root.find(xmlns + 'molecule').attrib['id'] + if root.tag == xmlns + 'molecule': + g_id = root.attrib['id'] + else: + g_id = root.find(xmlns + 'molecule').attrib['id'] g = nx.Graph(filename=basename(filename), name=g_id) dic = {} # used to retrieve incident nodes of edges for atom in root.iter(xmlns + 'atom'): @@ -561,13 +592,14 @@ class DataLoader(): for bond in root.iter(xmlns + 'bond'): labels = {} for key, val in bond.attrib.items(): - if key != 'atomRefs2': + if key != 'atomRefs2' and key != 'id': # "id" is in dataset "ACE". labels[key] = val n1, n2 = bond.attrib['atomRefs2'].strip().split(' ') g.add_edge(dic[n1], dic[n2], **labels) # get label names. label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []} + # @todo: possible loss of label names if some nodes miss some labels. for key, val in g.nodes[0].items(): try: float(val) diff --git a/gklearn/dataset/metadata.py b/gklearn/dataset/metadata.py index ed3af41..27ccf3b 100644 --- a/gklearn/dataset/metadata.py +++ b/gklearn/dataset/metadata.py @@ -206,7 +206,8 @@ GREYC_META = { 'domain': 'small molecules', 'train_valid_test': ['trainset_0.ds', None, 'testset_0.ds'], 'stereoisomerism': False, - 'load_files': [], + 'load_files': ['dataWithOutsider.ds'], + 'extra_params': {'gformat': 'cml'} }, 'Vitamin_D': { 'database': 'greyc', @@ -250,7 +251,7 @@ IAM_META = { 'url': 'https://iapr-tc15.greyc.fr/IAM/GREC.zip', 'domain': None, 'train_valid_test': ['data/test.cxl','data/train.cxl', 'data/valid.cxl'], - 'load_files': [], + 'load_files': [['data/test.cxl','data/train.cxl', 'data/valid.cxl']], }, 'Web': { 'database': 'iam', @@ -268,7 +269,7 @@ IAM_META = { 'url': 'https://iapr-tc15.greyc.fr/IAM/Web.zip', 'domain': None, 'train_valid_test': ['data/test.cxl', 'data/train.cxl', 'data/valid.cxl'], - 'load_files': [], + 'load_files': [['data/test.cxl','data/train.cxl', 'data/valid.cxl']], }, }