#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 7 11:30:31 2020

@author: ljia
"""
import numpy as np
import cvxpy as cp
import time
from gklearn.utils import Timer


class CostsLearner(object):
    """Driver for an iterative edit-cost learning loop.

    ``update()`` repeatedly calls ``preprocess()``, ``fit()``,
    ``postprocess()``, ``update_geds_and_nb_eo()`` and
    ``check_convergency()`` until ``termination_criterion_met()`` returns
    a truthy value.  In this class ``fit``, ``preprocess``,
    ``init_geds_and_nb_eo``, ``update_geds_and_nb_eo``,
    ``compute_geds_and_nb_eo``, ``check_convergency`` and
    ``print_current_states`` are empty placeholders.
    NOTE(review): presumably subclasses override them -- confirm.
    """

    def __init__(self, parallel, verbose):
        """Store the user settings and initialise the bookkeeping state.

        Parameters
        ----------
        parallel
            Stored as ``self._parallel``; not read anywhere in this class.
        verbose
            Verbosity level; progress is printed when it is ``>= 2``.
        """
        # --- Settings supplied by the caller.
        self._parallel = parallel
        self._verbose = verbose
        # Stopping parameters of update(); see set_update_params().
        self._time_limit_in_sec = 0
        self._max_itrs = 100
        self._max_itrs_without_update = 3
        self._epsilon_residual = 0.01
        self._epsilon_ec = 0.1

        # --- Quantities computed while learning.
        self._residual_list = []
        self._runtime_list = []
        self._cost_list = []
        self._nb_eo = None
        # Loop state of update().
        self._itrs = 0
        self._converged = False
        self._num_updates_ecs = 0
        self._ec_changed = None
        self._residual_changed = None
        self._itrs_without_update = 0

        # --- Both set and read back.
        self._ged_options = None

    def fit(self, X, y):
        """Placeholder; does nothing here.

        ``update()`` calls it as ``fit(self._nb_eo, y)``.
        """
        pass

    def preprocess(self):
        """Placeholder executed before every ``fit()`` call in ``update()``."""
        pass  # @todo: remove the zero numbers of edit costs.

    def postprocess(self):
        """Clean the most recent entry of ``self._cost_list`` in place.

        Values within ``[-1e-9, 1e-9]`` are replaced by ``0``.

        Raises
        ------
        ValueError
            If a value is negative beyond that tolerance.
        IndexError
            If ``self._cost_list`` is empty.
        """
        latest_costs = self._cost_list[-1]
        for idx, cost in enumerate(latest_costs):
            if -1e-9 <= cost <= 1e-9:
                # Treat numerical noise around zero as exactly zero.
                latest_costs[idx] = 0
            elif cost < 0:
                raise ValueError('The edit cost is negative.')

    def set_update_params(self, **kwargs):
        """Override the stopping parameters used by ``update()``.

        Recognised keywords are ``time_limit_in_sec``, ``max_itrs``,
        ``max_itrs_without_update``, ``epsilon_residual`` and
        ``epsilon_ec``.  A parameter that is not given keeps its current
        value; other keywords are ignored.
        """
        param_names = (
            'time_limit_in_sec',
            'max_itrs',
            'max_itrs_without_update',
            'epsilon_residual',
            'epsilon_ec',
        )
        for name in param_names:
            attr = '_' + name
            setattr(self, attr, kwargs.get(name, getattr(self, attr)))

    def update(self, y, graphs, ged_options, **kwargs):
        """Run the learning loop until a termination criterion is met.

        Parameters
        ----------
        y
            Forwarded to ``init_geds_and_nb_eo()``, ``fit()`` and
            ``update_geds_and_nb_eo()``.
        graphs
            Forwarded to ``init_geds_and_nb_eo()`` and
            ``update_geds_and_nb_eo()``.
        ged_options
            Stored as ``self._ged_options``.
        **kwargs
            If non-empty, passed on to ``set_update_params()``.
        """
        self._ged_options = ged_options
        if kwargs:
            self.set_update_params(**kwargs)

        chatty = self._verbose >= 2

        # The initial iteration.
        if chatty:
            print('\ninitial:')
        self.init_geds_and_nb_eo(y, graphs)

        # Reset the loop state before iterating from the initial edit costs.
        self._converged = False
        self._itrs_without_update = 0
        self._itrs = 0
        self._num_updates_ecs = 0
        timer = Timer(self._time_limit_in_sec)

        while not self.termination_criterion_met(
                self._converged, timer, self._itrs,
                self._itrs_without_update):
            if self._verbose >= 2:
                print('\niteration', self._itrs + 1)
            # Start time of this iteration, handed to update_geds_and_nb_eo().
            iteration_start = time.time()

            # Fit GED space to the target space.
            self.preprocess()
            self.fit(self._nb_eo, y)
            self.postprocess()

            # Compute new GEDs and numbers of edit operations.
            self.update_geds_and_nb_eo(y, graphs, iteration_start)

            # Check convergence.
            self.check_convergency()

            # Print current states.
            if self._verbose >= 2:
                self.print_current_states()

            self._itrs += 1

    def init_geds_and_nb_eo(self, y, graphs):
        """Placeholder called once by ``update()`` before the loop starts."""
        pass

    def update_geds_and_nb_eo(self, y, graphs, time0):
        """Placeholder called in every iteration after ``postprocess()``.

        ``time0`` is the ``time.time()`` value taken at the start of the
        iteration.
        """
        pass

    def compute_geds_and_nb_eo(self, graphs):
        """Placeholder; not called from within this class."""
        pass

    def check_convergency(self):
        """Placeholder called at the end of every iteration.

        NOTE(review): presumably meant to update ``self._converged`` and
        ``self._itrs_without_update`` -- confirm in the implementations.
        """
        pass

    def print_current_states(self):
        """Placeholder called each iteration when ``self._verbose >= 2``."""
        pass

    def termination_criterion_met(self, converged, timer, itr, itrs_without_update):
        """Tell whether the loop in ``update()`` has to stop.

        The loop stops when ``timer.expired()`` is truthy, when ``itr`` has
        reached ``self._max_itrs``, when ``converged`` is truthy, or when
        ``itrs_without_update`` exceeds ``self._max_itrs_without_update``.
        A negative limit switches the corresponding check off.
        """
        iteration_cap_hit = self._max_itrs >= 0 and itr >= self._max_itrs
        if timer.expired() or iteration_cap_hit:
            return True
        stalled = (self._max_itrs_without_update >= 0
                   and itrs_without_update > self._max_itrs_without_update)
        return converged or stalled

    def execute_cvx(self, prob):
        """Solve ``prob``, falling back to other solvers on failure.

        ``prob.solve()`` is first called without an explicit solver.  If
        that raises ``MemoryError``, ``cp.CVXOPT`` is tried; if that raises
        any ``Exception``, ``cp.MOSEK`` is tried and its errors propagate
        to the caller.  Progress is printed when ``self._verbose >= 2``.
        """
        chatty = self._verbose >= 2
        try:
            prob.solve(verbose=chatty)
        except MemoryError as error0:
            if chatty:
                print('\nUsing solver "OSQP" caused a memory error.')
                print('the original error message is\n', error0)
                print('solver status: ', prob.status)
                print('trying solver "CVXOPT" instead...\n')
            try:
                prob.solve(solver=cp.CVXOPT, verbose=chatty)
            except Exception as error1:
                if chatty:
                    print('\nAn error occured when using solver "CVXOPT".')
                    print('the original error message is\n', error1)
                    print('solver status: ', prob.status)
                    print('trying solver "MOSEK" instead. Notice this solver is commercial and a lisence is required.\n')
                prob.solve(solver=cp.MOSEK, verbose=chatty)
            else:
                if chatty:
                    print('solver status: ', prob.status)
        else:
            if chatty:
                print('solver status: ', prob.status)
        if chatty:
            print()

    def get_results(self):
        """Return the bookkeeping attributes collected in one dictionary.

        The lists and other objects are returned by reference, not copied.
        """
        return {
            'residual_list': self._residual_list,
            'runtime_list': self._runtime_list,
            'cost_list': self._cost_list,
            'nb_eo': self._nb_eo,
            'itrs': self._itrs,
            'converged': self._converged,
            'num_updates_ecs': self._num_updates_ecs,
            'ec_changed': self._ec_changed,
            'residual_changed': self._residual_changed,
            'itrs_without_update': self._itrs_without_update,
        }