You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

costs_learner.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Jul 7 11:30:31 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import cvxpy as cp
  9. import time
  10. from gklearn.utils import Timer
  11. class CostsLearner(object):
  12. def __init__(self, parallel, verbose):
  13. ### To set.
  14. self._parallel = parallel
  15. self._verbose = verbose
  16. # For update().
  17. self._time_limit_in_sec = 0
  18. self._max_itrs = 100
  19. self._max_itrs_without_update = 3
  20. self._epsilon_residual = 0.01
  21. self._epsilon_ec = 0.1
  22. ### To compute.
  23. self._residual_list = []
  24. self._runtime_list = []
  25. self._cost_list = []
  26. self._nb_eo = None
  27. # For update().
  28. self._itrs = 0
  29. self._converged = False
  30. self._num_updates_ecs = 0
  31. self._ec_changed = None
  32. self._residual_changed = None
  33. self._itrs_without_update = 0
  34. ### Both set and get.
  35. self._ged_options = None
  36. def fit(self, X, y):
  37. pass
  38. def preprocess(self):
  39. pass # @todo: remove the zero numbers of edit costs.
  40. def postprocess(self):
  41. for i in range(len(self._cost_list[-1])):
  42. if -1e-9 <= self._cost_list[-1][i] <= 1e-9:
  43. self._cost_list[-1][i] = 0
  44. if self._cost_list[-1][i] < 0:
  45. raise ValueError('The edit cost is negative.')
  46. def set_update_params(self, **kwargs):
  47. self._time_limit_in_sec = kwargs.get('time_limit_in_sec', self._time_limit_in_sec)
  48. self._max_itrs = kwargs.get('max_itrs', self._max_itrs)
  49. self._max_itrs_without_update = kwargs.get('max_itrs_without_update', self._max_itrs_without_update)
  50. self._epsilon_residual = kwargs.get('epsilon_residual', self._epsilon_residual)
  51. self._epsilon_ec = kwargs.get('epsilon_ec', self._epsilon_ec)
  52. def update(self, y, graphs, ged_options, **kwargs):
  53. # Set parameters.
  54. self._ged_options = ged_options
  55. if kwargs != {}:
  56. self.set_update_params(**kwargs)
  57. # The initial iteration.
  58. if self._verbose >= 2:
  59. print('\ninitial:')
  60. self.init_geds_and_nb_eo(y, graphs)
  61. self._converged = False
  62. self._itrs_without_update = 0
  63. self._itrs = 0
  64. self._num_updates_ecs = 0
  65. timer = Timer(self._time_limit_in_sec)
  66. # Run iterations from initial edit costs.
  67. while not self.termination_criterion_met(self._converged, timer, self._itrs, self._itrs_without_update):
  68. if self._verbose >= 2:
  69. print('\niteration', self._itrs + 1)
  70. time0 = time.time()
  71. # Fit GED space to the target space.
  72. self.preprocess()
  73. self.fit(self._nb_eo, y)
  74. self.postprocess()
  75. # Compute new GEDs and numbers of edit operations.
  76. self.update_geds_and_nb_eo(y, graphs, time0)
  77. # Check convergency.
  78. self.check_convergency()
  79. # Print current states.
  80. if self._verbose >= 2:
  81. self.print_current_states()
  82. self._itrs += 1
  83. def init_geds_and_nb_eo(self, y, graphs):
  84. pass
  85. def update_geds_and_nb_eo(self, y, graphs, time0):
  86. pass
  87. def compute_geds_and_nb_eo(self, graphs):
  88. pass
  89. def check_convergency(self):
  90. pass
  91. def print_current_states(self):
  92. pass
  93. def termination_criterion_met(self, converged, timer, itr, itrs_without_update):
  94. if timer.expired() or (itr >= self._max_itrs if self._max_itrs >= 0 else False):
  95. # if self._state == AlgorithmState.TERMINATED:
  96. # self._state = AlgorithmState.INITIALIZED
  97. return True
  98. return converged or (itrs_without_update > self._max_itrs_without_update if self._max_itrs_without_update >= 0 else False)
  99. def execute_cvx(self, prob):
  100. try:
  101. prob.solve(verbose=(self._verbose>=2))
  102. except MemoryError as error0:
  103. if self._verbose >= 2:
  104. print('\nUsing solver "OSQP" caused a memory error.')
  105. print('the original error message is\n', error0)
  106. print('solver status: ', prob.status)
  107. print('trying solver "CVXOPT" instead...\n')
  108. try:
  109. prob.solve(solver=cp.CVXOPT, verbose=(self._verbose>=2))
  110. except Exception as error1:
  111. if self._verbose >= 2:
  112. print('\nAn error occured when using solver "CVXOPT".')
  113. print('the original error message is\n', error1)
  114. print('solver status: ', prob.status)
  115. print('trying solver "MOSEK" instead. Notice this solver is commercial and a lisence is required.\n')
  116. prob.solve(solver=cp.MOSEK, verbose=(self._verbose>=2))
  117. else:
  118. if self._verbose >= 2:
  119. print('solver status: ', prob.status)
  120. else:
  121. if self._verbose >= 2:
  122. print('solver status: ', prob.status)
  123. if self._verbose >= 2:
  124. print()
  125. def get_results(self):
  126. results = {}
  127. results['residual_list'] = self._residual_list
  128. results['runtime_list'] = self._runtime_list
  129. results['cost_list'] = self._cost_list
  130. results['nb_eo'] = self._nb_eo
  131. results['itrs'] = self._itrs
  132. results['converged'] = self._converged
  133. results['num_updates_ecs'] = self._num_updates_ecs
  134. results['ec_changed'] = self._ec_changed
  135. results['residual_changed'] = self._residual_changed
  136. results['itrs_without_update'] = self._itrs_without_update
  137. return results

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.