|
@@ -0,0 +1,175 @@ |
|
|
|
|
|
#!/usr/bin/env python3 |
|
|
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
|
""" |
|
|
|
|
|
Created on Tue Jul 7 11:30:31 2020 |
|
|
|
|
|
|
|
|
|
|
|
@author: ljia |
|
|
|
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import cvxpy as cp |
|
|
|
|
|
import time |
|
|
|
|
|
from gklearn.utils import Timer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CostsLearner(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, parallel, verbose): |
|
|
|
|
|
### To set. |
|
|
|
|
|
self._parallel = parallel |
|
|
|
|
|
self._verbose = verbose |
|
|
|
|
|
# For update(). |
|
|
|
|
|
self._time_limit_in_sec = 0 |
|
|
|
|
|
self._max_itrs = 100 |
|
|
|
|
|
self._max_itrs_without_update = 3 |
|
|
|
|
|
self._epsilon_residual = 0.01 |
|
|
|
|
|
self._epsilon_ec = 0.1 |
|
|
|
|
|
### To compute. |
|
|
|
|
|
self._residual_list = [] |
|
|
|
|
|
self._runtime_list = [] |
|
|
|
|
|
self._cost_list = [] |
|
|
|
|
|
self._nb_eo = None |
|
|
|
|
|
# For update(). |
|
|
|
|
|
self._itrs = 0 |
|
|
|
|
|
self._converged = False |
|
|
|
|
|
self._num_updates_ecs = 0 |
|
|
|
|
|
self._ec_changed = None |
|
|
|
|
|
self._residual_changed = None |
|
|
|
|
|
self._itrs_without_update = 0 |
|
|
|
|
|
### Both set and get. |
|
|
|
|
|
self._ged_options = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fit(self, X, y): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(self): |
|
|
|
|
|
pass # @todo: remove the zero numbers of edit costs. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def postprocess(self): |
|
|
|
|
|
for i in range(len(self._cost_list[-1])): |
|
|
|
|
|
if -1e-9 <= self._cost_list[-1][i] <= 1e-9: |
|
|
|
|
|
self._cost_list[-1][i] = 0 |
|
|
|
|
|
if self._cost_list[-1][i] < 0: |
|
|
|
|
|
raise ValueError('The edit cost is negative.') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_update_params(self, **kwargs): |
|
|
|
|
|
self._time_limit_in_sec = kwargs.get('time_limit_in_sec', self._time_limit_in_sec) |
|
|
|
|
|
self._max_itrs = kwargs.get('max_itrs', self._max_itrs) |
|
|
|
|
|
self._max_itrs_without_update = kwargs.get('max_itrs_without_update', self._max_itrs_without_update) |
|
|
|
|
|
self._epsilon_residual = kwargs.get('epsilon_residual', self._epsilon_residual) |
|
|
|
|
|
self._epsilon_ec = kwargs.get('epsilon_ec', self._epsilon_ec) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, y, graphs, ged_options, **kwargs): |
|
|
|
|
|
# Set parameters. |
|
|
|
|
|
self._ged_options = ged_options |
|
|
|
|
|
if kwargs != {}: |
|
|
|
|
|
self.set_update_params(**kwargs) |
|
|
|
|
|
|
|
|
|
|
|
# The initial iteration. |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('\ninitial:') |
|
|
|
|
|
self.init_geds_and_nb_eo(y, graphs) |
|
|
|
|
|
|
|
|
|
|
|
self._converged = False |
|
|
|
|
|
self._itrs_without_update = 0 |
|
|
|
|
|
self._itrs = 0 |
|
|
|
|
|
self._num_updates_ecs = 0 |
|
|
|
|
|
timer = Timer(self._time_limit_in_sec) |
|
|
|
|
|
# Run iterations from initial edit costs. |
|
|
|
|
|
while not self.termination_criterion_met(self._converged, timer, self._itrs, self._itrs_without_update): |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('\niteration', self._itrs + 1) |
|
|
|
|
|
time0 = time.time() |
|
|
|
|
|
|
|
|
|
|
|
# Fit GED space to the target space. |
|
|
|
|
|
self.preprocess() |
|
|
|
|
|
self.fit(self._nb_eo, y) |
|
|
|
|
|
self.postprocess() |
|
|
|
|
|
|
|
|
|
|
|
# Compute new GEDs and numbers of edit operations. |
|
|
|
|
|
self.update_geds_and_nb_eo(y, graphs, time0) |
|
|
|
|
|
|
|
|
|
|
|
# Check convergency. |
|
|
|
|
|
self.check_convergency() |
|
|
|
|
|
|
|
|
|
|
|
# Print current states. |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
self.print_current_states() |
|
|
|
|
|
|
|
|
|
|
|
self._itrs += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_geds_and_nb_eo(self, y, graphs): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_geds_and_nb_eo(self, y, graphs, time0): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_geds_and_nb_eo(self, graphs): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_convergency(self): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_current_states(self): |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def termination_criterion_met(self, converged, timer, itr, itrs_without_update): |
|
|
|
|
|
if timer.expired() or (itr >= self._max_itrs if self._max_itrs >= 0 else False): |
|
|
|
|
|
# if self.__state == AlgorithmState.TERMINATED: |
|
|
|
|
|
# self.__state = AlgorithmState.INITIALIZED |
|
|
|
|
|
return True |
|
|
|
|
|
return converged or (itrs_without_update > self._max_itrs_without_update if self._max_itrs_without_update >= 0 else False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def execute_cvx(self, prob): |
|
|
|
|
|
try: |
|
|
|
|
|
prob.solve(verbose=(self._verbose>=2)) |
|
|
|
|
|
except MemoryError as error0: |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('\nUsing solver "OSQP" caused a memory error.') |
|
|
|
|
|
print('the original error message is\n', error0) |
|
|
|
|
|
print('solver status: ', prob.status) |
|
|
|
|
|
print('trying solver "CVXOPT" instead...\n') |
|
|
|
|
|
try: |
|
|
|
|
|
prob.solve(solver=cp.CVXOPT, verbose=(self._verbose>=2)) |
|
|
|
|
|
except Exception as error1: |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('\nAn error occured when using solver "CVXOPT".') |
|
|
|
|
|
print('the original error message is\n', error1) |
|
|
|
|
|
print('solver status: ', prob.status) |
|
|
|
|
|
print('trying solver "MOSEK" instead. Notice this solver is commercial and a lisence is required.\n') |
|
|
|
|
|
prob.solve(solver=cp.MOSEK, verbose=(self._verbose>=2)) |
|
|
|
|
|
else: |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('solver status: ', prob.status) |
|
|
|
|
|
else: |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print('solver status: ', prob.status) |
|
|
|
|
|
if self._verbose >= 2: |
|
|
|
|
|
print() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_results(self): |
|
|
|
|
|
results = {} |
|
|
|
|
|
results['residual_list'] = self._residual_list |
|
|
|
|
|
results['runtime_list'] = self._runtime_list |
|
|
|
|
|
results['cost_list'] = self._cost_list |
|
|
|
|
|
results['nb_eo'] = self._nb_eo |
|
|
|
|
|
results['itrs'] = self._itrs |
|
|
|
|
|
results['converged'] = self._converged |
|
|
|
|
|
results['num_updates_ecs'] = self._num_updates_ecs |
|
|
|
|
|
results['ec_changed'] = self._ec_changed |
|
|
|
|
|
results['residual_changed'] = self._residual_changed |
|
|
|
|
|
results['itrs_without_update'] = self._itrs_without_update |
|
|
|
|
|
return results |