@@ -131,22 +131,37 @@ class DataLoader():
def load_from_xml(self, filename, dir_dataset=None):
import xml.etree.ElementTree as ET
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename)
tree = ET.parse(filename)
root = tree.getroot()
def load_one_file(filename, data, y, label_names):
tree = ET.parse(filename)
root = tree.getroot()
for graph in (root.iter('graph') if root.find('graph') is not None else root.iter('print')): # the collection element is usually "graph"; the GREC and Web datasets use "print".
mol_filename = graph.attrib['file']
mol_class = graph.attrib['class']
g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename)
data.append(g)
self._append_label_names(label_names, l_names)
y.append(mol_class)
data = []
y = []
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
for graph in root.iter('graph'):
mol_filename = graph.attrib['file']
mol_class = graph.attrib['class']
g, l_names = self.load_gxl(dir_dataset + '/' + mol_filename)
data.append(g)
self._append_label_names(label_names, l_names)
y.append(mol_class)
if isinstance(filename, str):
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename)
load_one_file(filename, data, y, label_names)
else: # filename is a list of files.
if dir_dataset is not None:
dir_dataset = dir_dataset
else:
dir_dataset = dirname(filename[0])
for fn in filename:
load_one_file(fn, data, y, label_names)
return data, y, label_names
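# Illustrative sketch (not part of the patch): the shape of the collection XML
# this method reads. Some collections list <graph> elements; GREC and Web list
# <print> elements, which is why load_one_file checks for both. The XML below is
# a made-up minimal example; real files reference .gxl graphs on disk.
import xml.etree.ElementTree as ET

_cxl = """<GraphCollection>
  <print file="graph0001.gxl" class="class_A"/>
  <print file="graph0002.gxl" class="class_B"/>
</GraphCollection>"""
_root = ET.fromstring(_cxl)
_graphs = _root.iter('graph') if _root.find('graph') is not None else _root.iter('print')
for _g in _graphs:
	print(_g.attrib['file'], _g.attrib['class'])  # graph0001.gxl class_A, then graph0002.gxl class_B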
@@ -505,32 +520,45 @@ class DataLoader():
for node in root.iter('node'):
dic[node.attrib['id']] = index
labels = {}
for attr in node.iter('attr'):
for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens".
labels[attr.attrib['name']] = attr[0].text
for attr in node.iter('attribute'): # for dataset "Web".
labels[attr.attrib['name']] = attr.attrib['value']
g.add_node(index, **labels)
index += 1
for edge in root.iter('edge'):
labels = {}
for attr in edge.iter('attr'):
for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens".
labels[attr.attrib['name']] = attr[0].text
for attr in edge.iter('attribute'): # for dataset "Web".
labels[attr.attrib['name']] = attr.attrib['value']
g.add_edge(dic[edge.attrib['from']], dic[edge.attrib['to']], **labels)
# get label names.
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
# @todo: possible loss of label names if some nodes miss some labels.
for node in root.iter('node'):
for attr in node.iter('attr'):
if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
for attr in node.iter('attr'): # for datasets "GREC" and "Monoterpens".
if attr[0].tag == 'int' or attr.attrib['name'] == 'type': # @todo: this may be wrong, and slow. "type" is for dataset "GREC"; "int" is for dataset "Monoterpens".
label_names['node_labels'].append(attr.attrib['name'])
else:
label_names['node_attrs'].append(attr.attrib['name'])
for attr in node.iter('attribute'): # for dataset "Web".
label_names['node_attrs'].append(attr.attrib['name'])
# @todo: is "id" useful in dataset "Web"? Is "FREQUENCY" symbolic or not?
break
for edge in root.iter('edge'):
for attr in edge.iter('attr'):
if attr[0].tag == 'int': # @todo: this maybe wrong, and slow.
for attr in edge.iter('attr'): # for datasets "GREC" and "Monoterpens".
if attr[0].tag == 'int' or attr.attrib['name'] == 'frequency' or 'type' in attr.attrib['name']: # @todo: this may be wrong, and slow. "frequency" and "type" are for dataset "GREC"; "int" is for dataset "Monoterpens".
label_names['edge_labels'].append(attr.attrib['name'])
else:
label_names['edge_attrs'].append(attr.attrib['name'])
for attr in edge.iter('attribute'): # for dataset "Web".
label_names['edge_attrs'].append(attr.attrib['name'])
break
return g, label_names
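# Illustrative sketch (not part of the patch): the two node/edge markup styles
# this hunk distinguishes. GREC and Monoterpens wrap each value in an
# <attr name="..."><TYPE>value</TYPE></attr> element, while Web puts name/value
# pairs directly on <attribute> elements. Both snippets are made-up minimal examples.
import xml.etree.ElementTree as ET

_grec_node = ET.fromstring(
	'<node id="_0">'
	'<attr name="type"><string>corner</string></attr>'
	'<attr name="x"><float>1.5</float></attr>'
	'</node>')
for _a in _grec_node.iter('attr'):
	print(_a.attrib['name'], _a[0].tag, _a[0].text)  # e.g. "type string corner"

_web_node = ET.fromstring('<node id="_0"><attribute name="FREQUENCY" value="3"/></node>')
for _a in _web_node.iter('attribute'):
	print(_a.attrib['name'], _a.attrib['value'])  # "FREQUENCY 3"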
@@ -546,7 +574,10 @@ class DataLoader():
tree = ET.parse(filename)
root = tree.getroot()
index = 0
g_id = root.find(xmlns + 'molecule').attrib['id']
if root.tag == xmlns + 'molecule': # in some files the root element is the molecule itself, with no wrapper element.
g_id = root.attrib['id']
else:
g_id = root.find(xmlns + 'molecule').attrib['id']
g = nx.Graph(filename=basename(filename), name=g_id)
dic = {} # used to retrieve incident nodes of edges
for atom in root.iter(xmlns + 'atom'):
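# Illustrative sketch (not part of the patch): why the branch on root.tag is
# needed. Some CML files wrap <molecule> in a parent element while others use
# <molecule> itself as the document root. Both snippets are made up and omit the
# CML namespace prefix (xmlns) that the real code prepends to every tag.
import xml.etree.ElementTree as ET

_wrapped = ET.fromstring('<cml><molecule id="m1"/></cml>')
_bare = ET.fromstring('<molecule id="m2"/>')
for _root in (_wrapped, _bare):
	_g_id = _root.attrib['id'] if _root.tag == 'molecule' else _root.find('molecule').attrib['id']
	print(_g_id)  # m1, then m2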
@@ -561,13 +592,14 @@ class DataLoader():
for bond in root.iter(xmlns + 'bond'):
labels = {}
for key, val in bond.attrib.items():
if key != 'atomRefs2':
if key != 'atomRefs2' and key != 'id': # the "id" attribute appears on bonds in dataset "ACE".
labels[key] = val
n1, n2 = bond.attrib['atomRefs2'].strip().split(' ')
g.add_edge(dic[n1], dic[n2], **labels)
# get label names.
label_names = {'node_labels': [], 'edge_labels': [], 'node_attrs': [], 'edge_attrs': []}
# @todo: possible loss of label names if some nodes miss some labels.
for key, val in g.nodes[0].items():
try:
float(val)
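# Minimal sketch (an assumption, not part of the patch) of the heuristic the
# truncated loop above appears to apply when collecting label names: a value
# that casts to float is treated as a numeric attribute, anything else as a
# symbolic label.
def _looks_numeric(val):
	try:
		float(val)
		return True
	except (TypeError, ValueError):
		return False

print(_looks_numeric('1.5'), _looks_numeric('C'))  # True False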