@@ -4,7 +4,7 @@
 Created on Wed Oct 20 11:48:02 2020
 
 @author: ljia
-"""
+"""
 # This script tests the influence of the ratios between node costs and edge costs on the stability of the GED computation, where the base edit costs are [1, 1, 1, 1, 1, 1].
 
 import os
@@ -13,15 +13,15 @@ import pickle
 import logging
 from gklearn.ged.util import compute_geds
 import time
-from utils import get_dataset
+from utils import get_dataset, set_edit_cost_consts
 import sys
-from group_results import group_trials
+from group_results import group_trials, check_group_existence, update_group_marker
 
 
 def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
 
 	save_file_suffix = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.trial_' + str(trial)
 
 	# Return if the file exists.
 	if os.path.isfile(save_dir + 'ged_matrix' + save_file_suffix + '.pkl'):
 		return None, None
@@ -41,8 +41,11 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
 		'threads': multiprocessing.cpu_count(),
 		'init_option': 'EAGER_WITHOUT_SHUFFLED_COPIES'
 	}
 
-	edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
+	edit_cost_constants = set_edit_cost_consts(ratio,
+			node_labeled=len(dataset.node_labels),
+			edge_labeled=len(dataset.edge_labels),
+			mode='uniform')
 # 	edit_cost_constants = [item * 0.01 for item in edit_cost_constants]
 # 	pickle.dump(edit_cost_constants, open(save_dir + "edit_costs" + save_file_suffix + ".pkl", "wb"))
@@ -53,7 +56,7 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
 	options['node_attrs'] = dataset.node_attrs
 	options['edge_attrs'] = dataset.edge_attrs
 	parallel = True # if num_solutions == 1 else False
 
 	"""**5. Compute GED matrix.**"""
 	ged_mat = 'error'
 	runtime = 0
@@ -67,9 +70,9 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
 		logging.basicConfig(filename=LOG_FILENAME, level=logging.DEBUG)
 		logging.exception(save_file_suffix)
 		print(repr(exp))
 
 	"""**6. Get results.**"""
 	with open(save_dir + 'ged_matrix' + save_file_suffix + '.pkl', 'wb') as f:
 		pickle.dump(ged_mat, f)
 	with open(save_dir + 'runtime' + save_file_suffix + '.pkl', 'wb') as f:
@@ -77,66 +80,76 @@ def xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial):
 	return ged_mat, runtime
 
 
 def save_trials_as_group(dataset, ds_name, num_solutions, ratio):
 	# Return if the group file exists.
 	name_middle = '.' + ds_name + '.num_sols_' + str(num_solutions) + '.ratio_' + "{:.2f}".format(ratio) + '.'
 	name_group = save_dir + 'groups/ged_mats' + name_middle + 'npy'
-	if os.path.isfile(name_group):
+	if check_group_existence(name_group):
 		return
 
 	ged_mats = []
 	runtimes = []
-	for trial in range(1, 101):
+	num_trials = 100
+	for trial in range(1, num_trials + 1):
 		print()
 		print('Trial:', trial)
 		ged_mat, runtime = xp_compute_ged_matrix(dataset, ds_name, num_solutions, ratio, trial)
 		ged_mats.append(ged_mat)
 		runtimes.append(runtime)
 
 	# Group trials and Remove single files.
-	# @todo: if the program stops between the following lines, then there may be errors.
 	name_prefix = 'ged_matrix' + name_middle
-	group_trials(save_dir, name_prefix, True, True, False)
+	group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
 	name_prefix = 'runtime' + name_middle
-	group_trials(save_dir, name_prefix, True, True, False)
+	group_trials(save_dir, name_prefix, True, True, False, num_trials=num_trials)
+	update_group_marker(name_group)
 
 
 def results_for_a_dataset(ds_name):
 	"""**1. Get dataset.**"""
 	dataset = get_dataset(ds_name)
 
-	for num_solutions in num_solutions_list:
+	for ratio in ratio_list:
 		print()
-		print('# of solutions:', num_solutions)
-		for ratio in ratio_list:
+		print('Ratio:', ratio)
+		for num_solutions in num_solutions_list:
 			print()
-			print('Ratio:', ratio)
+			print('# of solutions:', num_solutions)
 			save_trials_as_group(dataset, ds_name, num_solutions, ratio)
 
 
-def get_param_lists(ds_name):
+def get_param_lists(ds_name, test=False):
+	if test:
+		num_solutions_list = [1, 10, 20, 30, 40, 50]
+		ratio_list = [10]
+		return num_solutions_list, ratio_list
+
 	if ds_name == 'AIDS_symb':
 		num_solutions_list = [1, 20, 40, 60, 80, 100]
 		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9]
 	else:
-		num_solutions_list = [1, 20, 40, 60, 80, 100]
-		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9]
+		num_solutions_list = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]  # [1, 20, 40, 60, 80, 100]
+		ratio_list = [0.1, 0.3, 0.5, 0.7, 0.9, 1, 3, 5, 7, 9, 10][::-1]
 
 	return num_solutions_list, ratio_list
 
 
 if __name__ == '__main__':
 	if len(sys.argv) > 1:
 		ds_name_list = sys.argv[1:]
 	else:
-		ds_name_list = ['MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
+		ds_name_list = ['Acyclic', 'Alkane_unlabeled', 'MAO_lite', 'Monoterpenoides', 'MUTAG']
+# 		ds_name_list = ['Acyclic'] # 'Alkane_unlabeled']
+# 		ds_name_list = ['Acyclic', 'MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
 
-	save_dir = 'outputs/edit_costs.num_sols.ratios.IPFP/'
+	save_dir = 'outputs/edit_costs.real_data.num_sols.ratios.IPFP/'
 	os.makedirs(save_dir, exist_ok=True)
 	os.makedirs(save_dir + 'groups/', exist_ok=True)
 
 	for ds_name in ds_name_list:
 		print()
 		print('Dataset:', ds_name)
-		num_solutions_list, ratio_list = get_param_lists(ds_name)
+		num_solutions_list, ratio_list = get_param_lists(ds_name, test=False)
 		results_for_a_dataset(ds_name)
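
A note on the new cost helper used above: judging from the indices zeroed in its definition, the six GED edit cost constants are the three node costs followed by the three edge costs, and 'uniform' mode scales the node side by `ratio`, zeroing a substitution cost when the corresponding labels are absent. A minimal, self-contained sketch of the behavior, mirroring the `set_edit_cost_consts` definition added to utils.py later in this patch:

	def set_edit_cost_consts(ratio, node_labeled=True, edge_labeled=True, mode='uniform'):
		if mode == 'uniform':
			costs = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
			if not node_labeled:
				costs[2] = 0  # no node labels: node substitution costs nothing
			if not edge_labeled:
				costs[5] = 0  # no edge labels: edge substitution costs nothing
			return costs

	print(set_edit_cost_consts(3))                      # [3, 3, 3, 1, 1, 1]
	print(set_edit_cost_consts(3, node_labeled=False))  # [3, 3, 0, 1, 1, 1]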
@@ -5,7 +5,7 @@ Created on Thu Oct 29 17:26:43 2020
 
 @author: ljia
 
-This script groups results together into a single file for the sake of faster
+This script groups results together into a single file for the sake of faster
 searching and loading.
 """
 import os
@@ -16,9 +16,55 @@ from tqdm import tqdm
 import sys
 
 
+def check_group_existence(file_name):
+	path, name = os.path.split(file_name)
+	marker_fn = os.path.join(path, 'group_names_finished.pkl')
+	if os.path.isfile(marker_fn):
+		with open(marker_fn, 'rb') as f:
+			fns = pickle.load(f)
+		if name in fns:
+			return True
+
+	if os.path.isfile(file_name):
+		return True
+
+	return False
+
+
+def update_group_marker(file_name):
+	# Record the group file's name in the marker file of finished groups.
+	path, name = os.path.split(file_name)
+	marker_fn = os.path.join(path, 'group_names_finished.pkl')
+	if os.path.isfile(marker_fn):
+		with open(marker_fn, 'rb') as f:
+			fns = pickle.load(f)
+		if name in fns:
+			return
+		else:
+			fns.add(name)
+	else:
+		fns = {name}
+	with open(marker_fn, 'wb') as f:
+		pickle.dump(fns, f)
+
+
+def create_group_marker_file(dir_folder, overwrite=True):
+	if not overwrite:
+		return
+
+	fns = set()
+	for file in sorted(os.listdir(dir_folder)):
+		if os.path.isfile(os.path.join(dir_folder, file)):
+			if file.endswith('.npy'):
+				fns.add(file)
+
+	marker_fn = os.path.join(dir_folder, 'group_names_finished.pkl')
+	with open(marker_fn, 'wb') as f:
+		pickle.dump(fns, f)
 
 
 # This function is used by other scripts. Modify it carefully.
-def group_trials(dir_folder, name_prefix, override, clear, backup):
+def group_trials(dir_folder, name_prefix, overwrite, clear, backup, num_trials=100):
 
 	# Get group name.
 	label_name = name_prefix.split('.')[0]
 	if label_name == 'ged_matrix':
@@ -33,10 +79,10 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
 	else:
 		name_group = dir_folder + 'groups/' + group_label + name_suffix + 'pkl'
 
-	if not override and os.path.isfile(name_group):
+	if not overwrite and os.path.isfile(name_group):
 		# Check if all trial files exist.
 		trials_complete = True
-		for trial in range(1, 101):
+		for trial in range(1, num_trials + 1):
 			file_name = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
 			if not os.path.isfile(file_name):
 				trials_complete = False
@@ -44,7 +90,7 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
 	else:
 		# Get data.
 		data_group = []
-		for trial in range(1, 101):
+		for trial in range(1, num_trials + 1):
 			file_name = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
 			if os.path.isfile(file_name):
 				with open(file_name, 'rb') as f:
@@ -64,7 +110,7 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
 			else: # Not all trials are completed.
 				return
 
 		# Write groups.
 		if label_name == 'ged_matrix':
 			data_group = np.array(data_group)
@@ -73,31 +119,31 @@ def group_trials(dir_folder, name_prefix, override, clear, backup):
 		else:
 			with open(name_group, 'wb') as f:
 				pickle.dump(data_group, f)
 		trials_complete = True
 
 	if trials_complete:
 		# Backup.
 		if backup:
-			for trial in range(1, 101):
+			for trial in range(1, num_trials + 1):
 				src = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
 				dst = dir_folder + 'backups/' + name_prefix + 'trial_' + str(trial) + '.pkl'
 				copyfile(src, dst)
 
 		# Clear.
 		if clear:
-			for trial in range(1, 101):
+			for trial in range(1, num_trials + 1):
 				src = dir_folder + name_prefix + 'trial_' + str(trial) + '.pkl'
 				os.remove(src)
 
 
-def group_all_in_folder(dir_folder, override=False, clear=True, backup=True):
+def group_all_in_folder(dir_folder, overwrite=False, clear=True, backup=True):
 
 	# Create folders.
 	os.makedirs(dir_folder + 'groups/', exist_ok=True)
 	if backup:
 		os.makedirs(dir_folder + 'backups', exist_ok=True)
 
 	# Iterate all files.
 	cur_file_prefix = ''
 	for file in tqdm(sorted(os.listdir(dir_folder)), desc='Grouping', file=sys.stdout):
@@ -106,20 +152,23 @@ def group_all_in_folder(dir_folder, override=False, clear=True, backup=True):
 # 			print(name)
 # 			print(name_prefix)
 			if name_prefix != cur_file_prefix:
-				group_trials(dir_folder, name_prefix, override, clear, backup)
+				group_trials(dir_folder, name_prefix, overwrite, clear, backup)
 				cur_file_prefix = name_prefix
 
 
 if __name__ == '__main__':
-	dir_folder = 'outputs/CRIANN/edit_costs.num_sols.ratios.IPFP/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.IPFP/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.max_num_sols.ratios.bipartite/'
-	group_all_in_folder(dir_folder)
-
-	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.bipartite/'
-	group_all_in_folder(dir_folder)
+# 	dir_folder = 'outputs/CRIANN/edit_costs.num_sols.ratios.IPFP/'
+# 	group_all_in_folder(dir_folder)
+
+# 	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.IPFP/'
+# 	group_all_in_folder(dir_folder)
+
+# 	dir_folder = 'outputs/CRIANN/edit_costs.max_num_sols.ratios.bipartite/'
+# 	group_all_in_folder(dir_folder)
+
+# 	dir_folder = 'outputs/CRIANN/edit_costs.repeats.ratios.bipartite/'
+# 	group_all_in_folder(dir_folder)
+
+	dir_folder = 'outputs/edit_costs.real_data.num_sols.ratios.IPFP/groups/'
+	create_group_marker_file(dir_folder)
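
The marker file introduced here addresses the `@todo` removed in the first file: with a plain `os.path.isfile` check, a run interrupted between grouping the GED matrices and grouping the runtimes could leave a half-finished group that later looks complete. `update_group_marker` records a group name only after all group files are written, and `check_group_existence` consults that record first. A short usage sketch of the intended protocol (the group path is illustrative, following the naming pattern of the experiment script):

	from group_results import check_group_existence, update_group_marker

	name_group = 'outputs/edit_costs.real_data.num_sols.ratios.IPFP/groups/ged_mats.MUTAG.num_sols_1.ratio_1.00.npy'
	if not check_group_existence(name_group):
		# ... run all trials and write the grouped .npy/.pkl files here ...
		update_group_marker(name_group)  # record completion only after the writes succeed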
@@ -15,30 +15,30 @@ def get_job_script(arg):
 #SBATCH --exclusive
 #SBATCH --job-name="st.""" + arg + r""".IPFP"
-#SBATCH --partition=tlong
+#SBATCH --partition=court
 #SBATCH --mail-type=ALL
 #SBATCH --mail-user=jajupmochi@gmail.com
-#SBATCH --output="outputs/output_edit_costs.nums_sols.ratios.IPFP.""" + arg + """.txt"
-#SBATCH --error="errors/error_edit_costs.nums_sols.ratios.IPFP.""" + arg + """.txt"
+#SBATCH --output="outputs/output_edit_costs.real_data.nums_sols.ratios.IPFP.""" + arg + """.txt"
+#SBATCH --error="errors/error_edit_costs.real_data.nums_sols.ratios.IPFP.""" + arg + """.txt"
 #
 #SBATCH --ntasks=1
 #SBATCH --nodes=1
 #SBATCH --cpus-per-task=1
-#SBATCH --time=300:00:00
+#SBATCH --time=48:00:00
 #SBATCH --mem-per-cpu=4000
 
 srun hostname
 srun cd /home/2019015/ljia02/graphkit-learn/gklearn/experiments/ged/stability
-srun python3 edit_costs.nums_sols.ratios.IPFP.py """ + arg
+srun python3 edit_costs.real_data.nums_sols.ratios.IPFP.py """ + arg
 	script = script.strip()
 	script = re.sub('\n\t+', '\n', script)
 	script = re.sub('\n +', '\n', script)
 
 	return script
 
 
 if __name__ == '__main__':
-	ds_list = ['MAO', 'Monoterpenoides', 'MUTAG', 'AIDS_symb']
-	for ds_name in [ds_list[i] for i in [0, 3]]:
+	ds_list = ['Acyclic', 'Alkane_unlabeled', 'MAO_lite', 'Monoterpenoides', 'MUTAG']
+	for ds_name in [ds_list[i] for i in [0, 1, 2, 3, 4]]:
 		job_script = get_job_script(ds_name)
 		command = 'sbatch <<EOF\n' + job_script + '\nEOF'
 # 		print(command)
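
The generated script is submitted to SLURM through a shell here-document (the `command` built above); the submission call itself sits below the hunk shown. A minimal sketch of how such a heredoc submission can be run from Python, assuming a standard `sbatch` on PATH:

	import subprocess

	def submit(job_script):
		command = 'sbatch <<EOF\n' + job_script + '\nEOF'
		# shell=True is required so the <<EOF here-document is interpreted by the shell.
		subprocess.run(command, shell=True, check=True)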
@@ -5,26 +5,251 @@ Created on Thu Oct 29 19:17:36 2020
 
 @author: ljia
 """
-from gklearn.utils import Dataset
+import os
+import pickle
+import numpy as np
+from tqdm import tqdm
+import sys
+from gklearn.dataset import Dataset
+from gklearn.experiments import DATASET_ROOT
 
 
 def get_dataset(ds_name):
 	# The node/edge labels that will not be used in the computation.
-	if ds_name == 'MAO':
-		irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
-	elif ds_name == 'Monoterpenoides':
-		irrelevant_labels = {'edge_labels': ['valence']}
-	elif ds_name == 'MUTAG':
-		irrelevant_labels = {'edge_labels': ['label_0']}
-	elif ds_name == 'AIDS_symb':
+# 	if ds_name == 'MAO':
+# 		irrelevant_labels = {'node_attrs': ['x', 'y', 'z'], 'edge_labels': ['bond_stereo']}
+# 	if ds_name == 'Monoterpenoides':
+# 		irrelevant_labels = {'edge_labels': ['valence']}
+# 	elif ds_name == 'MUTAG':
+# 		irrelevant_labels = {'edge_labels': ['label_0']}
+	if ds_name == 'AIDS_symb':
 		irrelevant_labels = {'node_attrs': ['chem', 'charge', 'x', 'y'], 'edge_labels': ['valence']}
 		ds_name = 'AIDS'
+	else:
+		irrelevant_labels = {}
 
-	# Initialize a Dataset.
-	dataset = Dataset()
 	# Load predefined dataset.
-	dataset.load_predefined_dataset(ds_name)
+	dataset = Dataset(ds_name, root=DATASET_ROOT)
 	# Remove irrelevant labels.
 	dataset.remove_labels(**irrelevant_labels)
 
 	print('dataset size:', len(dataset.graphs))
-	return dataset
+	return dataset
+
+
+def set_edit_cost_consts(ratio, node_labeled=True, edge_labeled=True, mode='uniform'):
+	# Return the six edit cost constants: the three node costs followed by
+	# the three edge costs. In 'uniform' mode the node costs are the edge
+	# costs scaled by `ratio`; a substitution cost is zeroed when the
+	# corresponding labels are absent.
+	if mode == 'uniform':
+		edit_cost_constants = [i * ratio for i in [1, 1, 1]] + [1, 1, 1]
+		if not node_labeled:
+			edit_cost_constants[2] = 0
+		if not edge_labeled:
+			edit_cost_constants[5] = 0
+
+		return edit_cost_constants
+
+
+def nested_keys_exists(element, *keys):
+	'''
+	Check if *keys (nested) exist in `element` (dict).
+	'''
+	if not isinstance(element, dict):
+		raise AttributeError('nested_keys_exists() expects dict as first argument.')
+	if len(keys) == 0:
+		raise AttributeError('nested_keys_exists() expects at least two arguments, one given.')
+
+	_element = element
+	for key in keys:
+		try:
+			_element = _element[key]
+		except KeyError:
+			return False
+	return True
+
+
+# Check average relative error along elements in two ged matrices.
+def matrices_ave_relative_error(m1, m2):
+	error = 0
+	base = 0
+	for i in range(m1.shape[0]):
+		for j in range(m1.shape[1]):
+			error += np.abs(m1[i, j] - m2[i, j])
+			base += (np.abs(m1[i, j]) + np.abs(m2[i, j])) / 2
+	return error / base
+
+
+def compute_relative_error(ged_mats):
+	if len(ged_mats) != 0:
+		# Get the smallest "correct" GED matrix.
+		ged_mat_s = np.ones(ged_mats[0].shape) * np.inf
+		for i in range(ged_mats[0].shape[0]):
+			for j in range(ged_mats[0].shape[1]):
+				ged_mat_s[i, j] = np.min([mat[i, j] for mat in ged_mats])
+
+		# Compute average error.
+		errors = []
+		for i, mat in enumerate(ged_mats):
+			err = matrices_ave_relative_error(mat, ged_mat_s)
+			# if not per_correct:
+			# 	print('matrix # ', str(i))
+			# 	pass
+			errors.append(err)
+	else:
+		errors = [0]
+
+	return np.mean(errors)
+
+
+def parse_group_file_name(fn):
+	splits_all = fn.split('.')
+	key1 = splits_all[1]
+
+	pos2 = splits_all[2].rfind('_')
+# 	key2 = splits_all[2][:pos2]
+	val2 = splits_all[2][pos2+1:]
+
+	pos3 = splits_all[3].rfind('_')
+# 	key3 = splits_all[3][:pos3]
+	val3 = splits_all[3][pos3+1:] + '.' + splits_all[4]
+
+	return key1, val2, val3
+
+
+def get_all_errors(save_dir, errors):
+	# Loop over each grouped GED matrix file.
+	for file in tqdm(sorted(os.listdir(save_dir)), desc='Getting errors', file=sys.stdout):
+		if os.path.isfile(os.path.join(save_dir, file)) and file.startswith('ged_mats.'):
+			keys = parse_group_file_name(file)
+
+			# Check if the results are already in `errors`.
+			if not keys[0] in errors:
+				errors[keys[0]] = {}
+			if not keys[1] in errors[keys[0]]:
+				errors[keys[0]][keys[1]] = {}
+			# Compute the error if it does not exist yet.
+			if not keys[2] in errors[keys[0]][keys[1]]:
+				ged_mats = np.load(os.path.join(save_dir, file))
+				errors[keys[0]][keys[1]][keys[2]] = compute_relative_error(ged_mats)
+
+	return errors
+
+
+def get_relative_errors(save_dir, overwrite=False):
+	"""Read relative errors from a previously computed and saved file; create
+	the file, compute the errors, or add newly computed errors to the file as
+	necessary.
+
+	Parameters
+	----------
+	save_dir : string
+		The directory containing the grouped GED matrix files.
+	overwrite : bool, optional
+		Whether to recompute all errors from scratch. The default is False.
+
+	Returns
+	-------
+	errors : dict
+		The nested dict of relative errors.
+	"""
+	fn_err = save_dir + '/relative_errors.pkl'
+	if not overwrite and os.path.isfile(fn_err):
+		# The error file exists: load it and add any missing errors.
+		with open(fn_err, 'rb') as f:
+			errors = pickle.load(f)
+		errors = get_all_errors(save_dir, errors)
+	else:
+		errors = get_all_errors(save_dir, {})
+
+	with open(fn_err, 'wb') as f:
+		pickle.dump(errors, f)
+
+	return errors
+
+
+def interpolate_result(Z, method='linear'):
+	# Linearly interpolate missing (NaN) entries of Z from the nearest
+	# non-NaN neighbors, first along the row, then along the column.
+	values = Z.copy()
+	for i in range(Z.shape[0]):
+		for j in range(Z.shape[1]):
+			if np.isnan(Z[i, j]):
+				# Get the nearest non-NaN values on each side in the row.
+				x_neg = np.nan
+				for idx, val in enumerate(Z[i, :][j::-1]):
+					if not np.isnan(val):
+						x_neg = val
+						x_neg_off = idx
+						break
+				x_pos = np.nan
+				for idx, val in enumerate(Z[i, :][j:]):
+					if not np.isnan(val):
+						x_pos = val
+						x_pos_off = idx
+						break
+
+				# Interpolate.
+				if not np.isnan(x_neg) and not np.isnan(x_pos):
+					val_int = (x_pos_off / (x_neg_off + x_pos_off)) * (x_neg - x_pos) + x_pos
+					values[i, j] = val_int
+					continue
+
+				# Fall back to the column neighbors.
+				y_neg = np.nan
+				for idx, val in enumerate(Z[:, j][i::-1]):
+					if not np.isnan(val):
+						y_neg = val
+						y_neg_off = idx
+						break
+				y_pos = np.nan
+				for idx, val in enumerate(Z[:, j][i:]):
+					if not np.isnan(val):
+						y_pos = val
+						y_pos_off = idx
+						break
+
+				# Interpolate.
+				if not np.isnan(y_neg) and not np.isnan(y_pos):
+					val_int = (y_pos_off / (y_neg_off + y_pos_off)) * (y_neg - y_pos) + y_pos
+					values[i, j] = val_int
+
+	return values
+
+
+def set_axis_style(ax):
+	ax.set_axisbelow(True)
+	ax.spines['top'].set_visible(False)
+	ax.spines['bottom'].set_visible(False)
+	ax.spines['right'].set_visible(False)
+	ax.spines['left'].set_visible(False)
+	ax.xaxis.set_ticks_position('none')
+	ax.yaxis.set_ticks_position('none')
+	ax.tick_params(labelsize=8, color='w', pad=1, grid_color='w')
+	ax.tick_params(axis='x', pad=-2)
+	ax.tick_params(axis='y', labelrotation=-40, pad=-2)
+# 	ax.zaxis._axinfo['juggled'] = (1, 2, 0)
+	ax.set_xlabel(ax.get_xlabel(), fontsize=10, labelpad=-3)
+	ax.set_ylabel(ax.get_ylabel(), fontsize=10, labelpad=-2, rotation=50)
+	ax.set_zlabel(ax.get_zlabel(), fontsize=10, labelpad=-2)
+	ax.set_title(ax.get_title(), pad=30, fontsize=15)
+	return
+
+
+if __name__ == '__main__':
+	root_dir = 'outputs/CRIANN/'
+# 	for dir_ in sorted(os.listdir(root_dir)):
+# 		if os.path.isdir(root_dir):
+# 			full_dir = os.path.join(root_dir, dir_)
+# 			print('---', full_dir, ':')
+# 			save_dir = os.path.join(full_dir, 'groups/')
+# 			if os.path.exists(save_dir):
+# 				try:
+# 					get_relative_errors(save_dir)
+# 				except Exception as exp:
+# 					print('An exception occurred when running this experiment:')
+# 					print(repr(exp))
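
A tiny worked example of the error measure defined in utils.py above: the element-wise minimum over all trial matrices serves as the reference (IPFP yields upper bounds on the exact GED, so the smallest observed value is the best available estimate), and each trial's deviation is then averaged relative to the mean magnitude of the compared entries. The vectorized helper below is a hypothetical but equivalent form of the `matrices_ave_relative_error` loops:

	import numpy as np

	ged_mats = [np.array([[0., 4.], [4., 0.]]),
	            np.array([[0., 6.], [5., 0.]])]
	# Element-wise minimum across trials: [[0, 4], [4, 0]].
	ged_mat_s = np.min(np.stack(ged_mats), axis=0)

	def ave_relative_error(m1, m2):
		return np.abs(m1 - m2).sum() / ((np.abs(m1) + np.abs(m2)) / 2).sum()

	errors = [ave_relative_error(m, ged_mat_s) for m in ged_mats]
	print(np.mean(errors))  # ~0.158: mean relative deviation across the two trials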