You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cost_matrices_learner.py 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Tue Jul 7 11:42:48 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import cvxpy as cp
  9. import time
  10. from gklearn.ged.learning.costs_learner import CostsLearner
  11. from gklearn.ged.util import compute_geds_cml
  12. class CostMatricesLearner(CostsLearner):
  13. def __init__(self, edit_cost='CONSTANT', triangle_rule=False, allow_zeros=True, parallel=False, verbose=2):
  14. super().__init__(parallel, verbose)
  15. self._edit_cost = edit_cost
  16. self._triangle_rule = triangle_rule
  17. self._allow_zeros = allow_zeros
  18. def fit(self, X, y):
  19. if self._edit_cost == 'LETTER':
  20. raise Exception('Cannot compute for cost "LETTER".')
  21. elif self._edit_cost == 'LETTER2':
  22. raise Exception('Cannot compute for cost "LETTER2".')
  23. elif self._edit_cost == 'NON_SYMBOLIC':
  24. raise Exception('Cannot compute for cost "NON_SYMBOLIC".')
  25. elif self._edit_cost == 'CONSTANT': # @todo: node/edge may not labeled.
  26. if not self._triangle_rule and self._allow_zeros:
  27. w = cp.Variable(X.shape[1])
  28. cost_fun = cp.sum_squares(X @ w - y)
  29. constraints = [w >= [0.0 for i in range(X.shape[1])]]
  30. prob = cp.Problem(cp.Minimize(cost_fun), constraints)
  31. self.execute_cvx(prob)
  32. edit_costs_new = w.value
  33. residual = np.sqrt(prob.value)
  34. elif self._triangle_rule and self._allow_zeros: # @todo
  35. x = cp.Variable(nb_cost_mat.shape[1])
  36. cost_fun = cp.sum_squares(nb_cost_mat @ x - dis_k_vec)
  37. constraints = [x >= [0.0 for i in range(nb_cost_mat.shape[1])],
  38. np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0]).T@x >= 0.01,
  39. np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0]).T@x >= 0.01,
  40. np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).T@x >= 0.01,
  41. np.array([0.0, 0.0, 0.0, 0.0, 1.0, 0.0]).T@x >= 0.01,
  42. np.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0]).T@x >= 0.0,
  43. np.array([0.0, 0.0, 0.0, 1.0, 1.0, -1.0]).T@x >= 0.0]
  44. prob = cp.Problem(cp.Minimize(cost_fun), constraints)
  45. self._execute_cvx(prob)
  46. edit_costs_new = x.value
  47. residual = np.sqrt(prob.value)
  48. elif not self._triangle_rule and not self._allow_zeros: # @todo
  49. x = cp.Variable(nb_cost_mat.shape[1])
  50. cost_fun = cp.sum_squares(nb_cost_mat @ x - dis_k_vec)
  51. constraints = [x >= [0.01 for i in range(nb_cost_mat.shape[1])]]
  52. prob = cp.Problem(cp.Minimize(cost_fun), constraints)
  53. self._execute_cvx(prob)
  54. edit_costs_new = x.value
  55. residual = np.sqrt(prob.value)
  56. elif self._triangle_rule and not self._allow_zeros: # @todo
  57. x = cp.Variable(nb_cost_mat.shape[1])
  58. cost_fun = cp.sum_squares(nb_cost_mat @ x - dis_k_vec)
  59. constraints = [x >= [0.01 for i in range(nb_cost_mat.shape[1])],
  60. np.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0]).T@x >= 0.0,
  61. np.array([0.0, 0.0, 0.0, 1.0, 1.0, -1.0]).T@x >= 0.0]
  62. prob = cp.Problem(cp.Minimize(cost_fun), constraints)
  63. self._execute_cvx(prob)
  64. edit_costs_new = x.value
  65. residual = np.sqrt(prob.value)
  66. else:
  67. raise Exception('The edit cost "', self._ged_options['edit_cost'], '" is not supported for update progress.')
  68. self._cost_list.append(edit_costs_new)
  69. def init_geds_and_nb_eo(self, y, graphs):
  70. time0 = time.time()
  71. self._cost_list.append(np.concatenate((self._ged_options['node_label_costs'],
  72. self._ged_options['edge_label_costs'])))
  73. ged_vec, self._nb_eo = self.compute_geds_and_nb_eo(graphs)
  74. self._residual_list.append(np.sqrt(np.sum(np.square(np.array(ged_vec) - y))))
  75. self._runtime_list.append(time.time() - time0)
  76. if self._verbose >= 2:
  77. print('Current node label costs:', self._cost_list[-1][0:len(self._ged_options['node_label_costs'])])
  78. print('Current edge label costs:', self._cost_list[-1][len(self._ged_options['node_label_costs']):])
  79. print('Residual list:', self._residual_list)
  80. def update_geds_and_nb_eo(self, y, graphs, time0):
  81. self._ged_options['node_label_costs'] = self._cost_list[-1][0:len(self._ged_options['node_label_costs'])]
  82. self._ged_options['edge_label_costs'] = self._cost_list[-1][len(self._ged_options['node_label_costs']):]
  83. ged_vec, self._nb_eo = self.compute_geds_and_nb_eo(graphs)
  84. self._residual_list.append(np.sqrt(np.sum(np.square(np.array(ged_vec) - y))))
  85. self._runtime_list.append(time.time() - time0)
  86. def compute_geds_and_nb_eo(self, graphs):
  87. ged_vec, ged_mat, n_edit_operations = compute_geds_cml(graphs, options=self._ged_options, parallel=self._parallel, verbose=(self._verbose > 1))
  88. return ged_vec, np.array(n_edit_operations)
  89. def check_convergency(self):
  90. self._ec_changed = False
  91. for i, cost in enumerate(self._cost_list[-1]):
  92. if cost == 0:
  93. if self._cost_list[-2][i] > self._epsilon_ec:
  94. self._ec_changed = True
  95. break
  96. elif abs(cost - self._cost_list[-2][i]) / cost > self._epsilon_ec:
  97. self._ec_changed = True
  98. break
  99. # if abs(cost - edit_cost_list[-2][i]) > self._epsilon_ec:
  100. # ec_changed = True
  101. # break
  102. self._residual_changed = False
  103. if self._residual_list[-1] == 0:
  104. if self._residual_list[-2] > self._epsilon_residual:
  105. self._residual_changed = True
  106. elif abs(self._residual_list[-1] - self._residual_list[-2]) / self._residual_list[-1] > self._epsilon_residual:
  107. self._residual_changed = True
  108. self._converged = not (self._ec_changed or self._residual_changed)
  109. if self._converged:
  110. self._itrs_without_update += 1
  111. else:
  112. self._itrs_without_update = 0
  113. self._num_updates_ecs += 1
  114. def print_current_states(self):
  115. print()
  116. print('-------------------------------------------------------------------------')
  117. print('States of iteration', self._itrs + 1)
  118. print('-------------------------------------------------------------------------')
  119. # print('Time spend:', self._runtime_optimize_ec)
  120. print('Total number of iterations for optimizing:', self._itrs + 1)
  121. print('Total number of updating edit costs:', self._num_updates_ecs)
  122. print('Was optimization of edit costs converged:', self._converged)
  123. print('Did edit costs change:', self._ec_changed)
  124. print('Did residual change:', self._residual_changed)
  125. print('Iterations without update:', self._itrs_without_update)
  126. print('Current node label costs:', self._cost_list[-1][0:len(self._ged_options['node_label_costs'])])
  127. print('Current edge label costs:', self._cost_list[-1][len(self._ged_options['node_label_costs']):])
  128. print('Residual list:', self._residual_list)
  129. print('-------------------------------------------------------------------------')

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.