You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

median_graph_estimator.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 16 18:04:55 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. from gklearn.preimage.common_types import AlgorithmState
  9. from gklearn.preimage import misc
  10. from gklearn.preimage.timer import Timer
  11. from gklearn.utils.utils import graph_isIdentical
  12. import time
  13. from tqdm import tqdm
  14. import sys
  15. import networkx as nx
  class MedianGraphEstimator(object):
      """Estimates a generalized median graph of a collection of graphs.

      The algorithm mirrors ``ged::MedianGraphEstimator`` of the GEDLIB library:
      starting from one or more initial medians, it alternates between updating
      the median (node and edge labels) and updating the node maps from the
      median to all input graphs (block gradient descent). Graphs are exchanged
      with the GED environment as NetworkX graphs.
      """

      def __init__(self, ged_env, constant_node_costs):
          """Constructor.

          Parameters
          ----------
          ged_env : gklearn.gedlib.gedlibpy.GEDEnv
              Initialized GED environment. The edit costs must be set by the user.
          constant_node_costs : Boolean
              Set to True if the node relabeling costs are constant.

          Raises
          ------
          Exception
              If `ged_env` is None or has not been initialized.
          """
          # Validate the environment BEFORE querying it; otherwise a None
          # environment fails with an AttributeError instead of these messages.
          if ged_env is None:
              raise Exception('The GED environment pointer passed to the constructor of MedianGraphEstimator is null.')
          elif not ged_env.is_initialized():
              raise Exception('The GED environment is uninitialized. Call gedlibpy.GEDEnv.init() before passing it to the constructor of MedianGraphEstimator.')

          self.__ged_env = ged_env

          # GED methods used in the three phases of run().
          self.__init_method = 'BRANCH_FAST'
          self.__init_options = ''
          self.__descent_method = 'BRANCH_FAST'
          self.__descent_options = ''
          self.__refine_method = 'IPFP'
          self.__refine_options = ''

          # Cost model information queried once from the environment.
          self.__constant_node_costs = constant_node_costs
          self.__labeled_nodes = (ged_env.get_num_node_labels() > 1)
          self.__node_del_cost = ged_env.get_node_del_cost(ged_env.get_node_label(1))
          self.__node_ins_cost = ged_env.get_node_ins_cost(ged_env.get_node_label(1))
          self.__labeled_edges = (ged_env.get_num_edge_labels() > 1)
          self.__edge_del_cost = ged_env.get_edge_del_cost(ged_env.get_edge_label(1))
          self.__edge_ins_cost = ged_env.get_edge_ins_cost(ged_env.get_edge_label(1))

          # Options that can be changed through set_options().
          self.__set_default_options()

          # Results and statistics of the last call to run().
          self.__median_id = np.inf  # @todo: check
          self.__median_node_id_prefix = ''  # @todo: check
          self.__node_maps_from_median = {}
          self.__sum_of_distances = 0
          self.__best_init_sum_of_distances = np.inf
          self.__converged_sum_of_distances = np.inf
          self.__runtime = None
          self.__runtime_initialized = None
          self.__runtime_converged = None
          self.__itrs = []  # @todo: check: {} ?
          self.__num_decrease_order = 0
          self.__num_increase_order = 0
          self.__num_converged_descents = 0
          self.__state = AlgorithmState.TERMINATED

      def set_options(self, options):
          """Sets the options of the estimator.

          All options not mentioned in `options` are reset to their defaults.

          Parameters
          ----------
          options : string
              String that specifies with which options to run the estimator.

          Raises
          ------
          Exception
              If an option name is unknown or an option value is invalid.
          """
          self.__set_default_options()
          options_map = misc.options_string_to_options_map(options)
          for opt_name, opt_val in options_map.items():
              if opt_name == 'init-type':
                  self.__init_type = opt_val
                  if opt_val != 'MEDOID' and opt_val != 'RANDOM' and opt_val != 'MIN' and opt_val != 'MAX' and opt_val != 'MEAN':
                      raise Exception('Invalid argument ' + opt_val + ' for option init-type. Usage: options = "[--init-type RANDOM|MEDOID|MIN|MAX|MEAN] [...]"')
              elif opt_name == 'random-inits':
                  try:
                      self.__num_random_inits = int(opt_val)
                      self.__desired_num_random_inits = self.__num_random_inits
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option random-inits. Usage: options = "[--random-inits <convertible to int greater 0>]"')
                  if self.__num_random_inits <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option random-inits. Usage: options = "[--random-inits <convertible to int greater 0>]"')
              elif opt_name == 'randomness':
                  if opt_val == 'PSEUDO':
                      self.__use_real_randomness = False
                  elif opt_val == 'REAL':
                      self.__use_real_randomness = True
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option randomness. Usage: options = "[--randomness REAL|PSEUDO] [...]"')
              elif opt_name == 'stdout':
                  if opt_val == '0':
                      self.__print_to_stdout = 0
                  elif opt_val == '1':
                      self.__print_to_stdout = 1
                  elif opt_val == '2':
                      self.__print_to_stdout = 2
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option stdout. Usage: options = "[--stdout 0|1|2] [...]"')
              elif opt_name == 'refine':
                  if opt_val == 'TRUE':
                      self.__refine = True
                  elif opt_val == 'FALSE':
                      self.__refine = False
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option refine. Usage: options = "[--refine TRUE|FALSE] [...]"')
              elif opt_name == 'time-limit':
                  try:
                      self.__time_limit_in_sec = float(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option time-limit.  Usage: options = "[--time-limit <convertible to double>] [...]')
              elif opt_name == 'max-itrs':
                  try:
                      self.__max_itrs = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs. Usage: options = "[--max-itrs <convertible to int>] [...]')
              elif opt_name == 'max-itrs-without-update':
                  try:
                      self.__max_itrs_without_update = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs-without-update. Usage: options = "[--max-itrs-without-update <convertible to int>] [...]')
              elif opt_name == 'seed':
                  try:
                      self.__seed = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option seed. Usage: options = "[--seed <convertible to int greater equal 0>] [...]')
              elif opt_name == 'epsilon':
                  try:
                      self.__epsilon = float(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option epsilon. Usage: options = "[--epsilon <convertible to double greater 0>] [...]')
                  if self.__epsilon <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option epsilon. Usage: options = "[--epsilon <convertible to double greater 0>] [...]')
              elif opt_name == 'inits-increase-order':
                  try:
                      self.__num_inits_increase_order = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option inits-increase-order. Usage: options = "[--inits-increase-order <convertible to int greater 0>]"')
                  if self.__num_inits_increase_order <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option inits-increase-order. Usage: options = "[--inits-increase-order <convertible to int greater 0>]"')
              elif opt_name == 'init-type-increase-order':
                  self.__init_type_increase_order = opt_val
                  if opt_val != 'CLUSTERS' and opt_val != 'K-MEANS++':
                      raise Exception('Invalid argument ' + opt_val + ' for option init-type-increase-order. Usage: options = "[--init-type-increase-order CLUSTERS|K-MEANS++] [...]"')
              elif opt_name == 'max-itrs-increase-order':
                  try:
                      self.__max_itrs_increase_order = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs-increase-order. Usage: options = "[--max-itrs-increase-order <convertible to int>] [...]')
              else:
                  # Keep this list in sync with the branches above.
                  valid_options = '[--init-type <arg>] [--random-inits <arg>] [--randomness <arg>] [--seed <arg>] [--stdout <arg>] [--refine <arg>] '
                  valid_options += '[--time-limit <arg>] [--max-itrs <arg>] [--max-itrs-without-update <arg>] [--epsilon <arg>] '
                  valid_options += '[--inits-increase-order <arg>] [--init-type-increase-order <arg>] [--max-itrs-increase-order <arg>]'
                  raise Exception('Invalid option "' + opt_name + '". Usage: options = "' + valid_options + '"')

      def set_init_method(self, init_method, init_options=''):
          """Selects method to be used for computing the initial medoid graph.

          Parameters
          ----------
          init_method : string
              The selected method. Default (set in the constructor): "BRANCH_FAST".
          init_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          Has no effect unless "--init-type MEDOID" is passed to set_options().
          """
          self.__init_method = init_method
          self.__init_options = init_options

      def set_descent_method(self, descent_method, descent_options=''):
          """Selects method to be used for block gradient descent.

          Parameters
          ----------
          descent_method : string
              The selected method. Default (set in the constructor): "BRANCH_FAST".
          descent_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          The method is set on the GED environment at the start of the block
          gradient descent phase of run(), for every initial median.
          """
          self.__descent_method = descent_method
          self.__descent_options = descent_options

      def set_refine_method(self, refine_method, refine_options=''):
          """Selects method to be used for improving the sum of distances and the node maps for the converged median.

          Parameters
          ----------
          refine_method : string
              The selected method. Default: "IPFP".
          refine_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          Has no effect if "--refine FALSE" is passed to set_options().
          """
          self.__refine_method = refine_method
          self.__refine_options = refine_options

      def run(self, graph_ids, set_median_id, gen_median_id):
          """Computes a generalized median graph.

          Parameters
          ----------
          graph_ids : list[integer]
              The IDs of the graphs for which the median should be computed. Must have been added to the environment passed to the constructor.
          set_median_id : integer
              The ID of the computed set-median. A dummy graph with this ID must have been added to the environment passed to the constructor. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
          gen_median_id : integer
              The ID of the computed generalized median. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().

          Raises
          ------
          Exception
              If `graph_ids` is empty, if all graphs are empty, or if no initial
              median could be constructed (of the init types, only MEDOID
              currently constructs an initial median).
          """
          # Sanity checks.
          if len(graph_ids) == 0:
              raise Exception('Empty vector of graph IDs, unable to compute median.')
          all_graphs_empty = True
          for graph_id in graph_ids:
              if self.__ged_env.get_graph_num_nodes(graph_id) > 0:
                  self.__median_node_id_prefix = self.__ged_env.get_original_node_ids(graph_id)[0]
                  all_graphs_empty = False
                  break
          if all_graphs_empty:
              raise Exception('All graphs in the collection are empty.')

          # Start timer and record start time.
          start = time.time()
          timer = Timer(self.__time_limit_in_sec)
          self.__median_id = gen_median_id
          self.__state = AlgorithmState.TERMINATED

          # Get NetworkX representations of the input graphs.
          graphs = {}
          for graph_id in graph_ids:
              # @todo: get_nx_graph() function may need to be modified according to the coming code.
              graphs[graph_id] = self.__ged_env.get_nx_graph(graph_id, True, True, False)

          # Construct initial medians.
          medians = []
          self.__construct_initial_medians(graph_ids, timer, medians)
          end_init = time.time()
          self.__runtime_initialized = end_init - start
          if len(medians) == 0:
              # Without an initial median there is nothing to descend from;
              # fail with a clear message instead of an UnboundLocalError below.
              self.__median_id = np.inf
              raise Exception('No initial median was constructed for init type "' + str(self.__init_type) + '". Only "--init-type MEDOID" is currently implemented.')

          # Reset information about iterations and number of times the median decreases and increases.
          self.__itrs = [0] * len(medians)
          self.__num_decrease_order = 0
          self.__num_increase_order = 0
          self.__num_converged_descents = 0

          # Initialize the best median.
          best_sum_of_distances = np.inf
          self.__best_init_sum_of_distances = np.inf
          node_maps_from_best_median = {}
          best_median = None

          # Run block gradient descent from all initial medians.
          self.__ged_env.set_method(self.__descent_method, self.__descent_options)
          for median_pos, median in enumerate(medians):

              # Terminate if the timer has expired and at least one SOD has been computed.
              if timer.expired() and median_pos > 0:
                  break

              # Print information about current iteration.
              if self.__print_to_stdout == 2:
                  print('\n===========================================================')
                  print('Block gradient descent for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
                  print('-----------------------------------------------------------')

              # Load initial median into the environment.
              self.__ged_env.load_nx_graph(median, gen_median_id)
              self.__ged_env.init(self.__ged_env.get_init_type())

              # Compute node maps and sum of distances for initial median.
              if self.__print_to_stdout == 2:
                  progress = tqdm(desc='\rComputing initial node maps', total=len(graph_ids), file=sys.stdout)
              self.__sum_of_distances = 0
              self.__node_maps_from_median.clear()  # @todo
              for graph_id in graph_ids:
                  self.__ged_env.run_method(gen_median_id, graph_id)
                  self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(gen_median_id, graph_id)
                  # @todo: the C++ implementation of this function in GedLibBind.ipp calls get_node_map() once more, which is not necessary.
                  self.__sum_of_distances += self.__ged_env.get_induced_cost(gen_median_id, graph_id)
                  if self.__print_to_stdout == 2:
                      progress.update(1)
              if self.__print_to_stdout == 2:
                  progress.close()
              self.__best_init_sum_of_distances = min(self.__best_init_sum_of_distances, self.__sum_of_distances)
              self.__ged_env.load_nx_graph(median, set_median_id)
              if self.__print_to_stdout == 2:
                  print('\n')

              # Run block gradient descent from initial median.
              converged = False
              itrs_without_update = 0
              while not self.__termination_criterion_met(converged, timer, self.__itrs[median_pos], itrs_without_update):

                  # Print information about current iteration.
                  if self.__print_to_stdout == 2:
                      print('\n===========================================================')
                      print('Iteration', str(self.__itrs[median_pos] + 1), 'for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
                      print('-----------------------------------------------------------')

                  # Update the median.
                  median_modified = self.__update_median(graphs, median)
                  # Decreasing/increasing the order of the median is not
                  # implemented yet, so these flags always stay False.
                  decreased_order = False
                  increased_order = False

                  # Update the number of iterations without update of the median.
                  if median_modified or decreased_order or increased_order:
                      itrs_without_update = 0
                  else:
                      itrs_without_update += 1

                  # Load the median into the environment.
                  if self.__print_to_stdout == 2:
                      print('Loading median to environment: ... ', end='')
                  # @todo: should this function use the original node label?
                  self.__ged_env.load_nx_graph(median, gen_median_id)
                  self.__ged_env.init(self.__ged_env.get_init_type())
                  if self.__print_to_stdout == 2:
                      print('done.')

                  # Compute induced costs of the old node maps w.r.t. the updated median.
                  if self.__print_to_stdout == 2:
                      print('Updating induced costs: ... ', end='')
                  for graph_id in graph_ids:
                      # @todo: check that compute_induced_cost is correct; the induced costs read below depend on it.
                      self.__ged_env.compute_induced_cost(gen_median_id, graph_id)
                  if self.__print_to_stdout == 2:
                      print('done.')

                  # Update the node maps.
                  node_maps_modified = self.__update_node_maps()  # @todo

                  # Update the sum of distances.
                  old_sum_of_distances = self.__sum_of_distances
                  self.__sum_of_distances = 0
                  for graph_id in self.__node_maps_from_median:
                      self.__sum_of_distances += self.__ged_env.get_induced_cost(gen_median_id, graph_id)  # @todo: see above.

                  # Print information about current iteration.
                  if self.__print_to_stdout == 2:
                      print('Old local SOD: ', old_sum_of_distances)
                      print('New local SOD: ', self.__sum_of_distances)
                      print('Best converged SOD: ', best_sum_of_distances)
                      print('Modified median: ', median_modified)
                      print('Modified node maps: ', node_maps_modified)
                      print('Decreased order: ', decreased_order)
                      print('Increased order: ', increased_order)
                      print('===========================================================\n')

                  converged = not (median_modified or node_maps_modified or decreased_order or increased_order)
                  self.__itrs[median_pos] += 1

              # Update the best median once the descent from this initial median has terminated.
              if self.__sum_of_distances < best_sum_of_distances:
                  best_sum_of_distances = self.__sum_of_distances
                  # Copy: self.__node_maps_from_median is cleared and refilled for the next initial median.
                  node_maps_from_best_median = self.__node_maps_from_median.copy()
                  best_median = median

              # Update the number of converged descents.
              if converged:
                  self.__num_converged_descents += 1

          if best_median is None:
              # Only reachable if no descent produced a comparable (finite) SOD.
              raise Exception('No median with a finite sum of distances was found.')

          # Store the best encountered median.
          self.__sum_of_distances = best_sum_of_distances
          self.__node_maps_from_median = node_maps_from_best_median
          self.__ged_env.load_nx_graph(best_median, gen_median_id)
          self.__ged_env.init(self.__ged_env.get_init_type())
          end_descent = time.time()
          self.__runtime_converged = end_descent - start

          # Refine the sum of distances and the node maps for the converged median.
          self.__converged_sum_of_distances = self.__sum_of_distances
          if self.__refine:
              self.__improve_sum_of_distances(timer)  # @todo

          # Record end time, set runtime and reset the number of initial medians.
          end = time.time()
          self.__runtime = end - start
          self.__num_random_inits = self.__desired_num_random_inits

          # Print global information.
          if self.__print_to_stdout != 0:
              self.__print_run_summary(graph_ids, medians)

      def get_sum_of_distances(self, state=''):
          """Returns the sum of distances.

          Parameters
          ----------
          state : string
              The state of the estimator. Can be 'initialized' or 'converged'. Default: ""

          Returns
          -------
          float
              The sum of distances (SOD) of the median when the estimator was in the state `state` during the last call to run(). If `state` is not given, the converged SOD (without refinement) or refined SOD (with refinement) is returned.

          Raises
          ------
          Exception
              If no median has been computed yet.
          """
          if not self.__median_available():
              raise Exception('No median has been computed. Call run() before calling get_sum_of_distances().')
          if state == 'initialized':
              return self.__best_init_sum_of_distances
          if state == 'converged':
              return self.__converged_sum_of_distances
          return self.__sum_of_distances

      def __set_default_options(self):
          """Resets all options configurable through set_options() to their defaults."""
          self.__init_type = 'RANDOM'
          self.__num_random_inits = 10
          self.__desired_num_random_inits = 10
          self.__use_real_randomness = True
          self.__seed = 0
          self.__refine = True
          self.__time_limit_in_sec = 0
          self.__epsilon = 0.0001
          self.__max_itrs = 100
          self.__max_itrs_without_update = 3
          self.__num_inits_increase_order = 10
          self.__init_type_increase_order = 'K-MEANS++'
          self.__max_itrs_increase_order = 10
          self.__print_to_stdout = 2

      def __print_run_summary(self, graph_ids, medians):
          """Prints global statistics about the last call to run()."""
          total_itr = sum(self.__itrs)
          num_started_descents = sum(1 for itr in self.__itrs if itr > 0)
          print('\n===========================================================')
          print('Finished computation of generalized median graph.')
          print('-----------------------------------------------------------')
          print('Best SOD after initialization: ', self.__best_init_sum_of_distances)
          print('Converged SOD: ', self.__converged_sum_of_distances)
          if self.__refine:
              print('Refined SOD: ', self.__sum_of_distances)
          print('Overall runtime: ', self.__runtime)
          print('Runtime of initialization: ', self.__runtime_initialized)
          print('Runtime of block gradient descent: ', self.__runtime_converged - self.__runtime_initialized)
          if self.__refine:
              print('Runtime of refinement: ', self.__runtime - self.__runtime_converged)
          print('Number of initial medians: ', len(medians))
          print('Size of graph collection: ', len(graph_ids))
          print('Number of started descents: ', num_started_descents)
          print('Number of converged descents: ', self.__num_converged_descents)
          print('Overall number of iterations: ', total_itr)
          print('Overall number of times the order decreased: ', self.__num_decrease_order)
          print('Overall number of times the order increased: ', self.__num_increase_order)
          print('===========================================================\n')

      def __construct_initial_medians(self, graph_ids, timer, initial_medians):
          """Fills `initial_medians` (cleared first) according to the init type.

          Only MEDOID is implemented; the other init types leave the list empty.
          """
          # Print information about current iteration.
          if self.__print_to_stdout == 2:
              print('\n===========================================================')
              print('Constructing initial median(s).')
              print('-----------------------------------------------------------')

          # Compute or sample the initial median(s).
          initial_medians.clear()
          if self.__init_type == 'MEDOID':
              self.__compute_medoid(graph_ids, timer, initial_medians)
          elif self.__init_type == 'MAX':
              pass  # @todo: compute_max_order_graph_(graph_ids, initial_medians)
          elif self.__init_type == 'MIN':
              pass  # @todo: compute_min_order_graph_(graph_ids, initial_medians)
          elif self.__init_type == 'MEAN':
              pass  # @todo: compute_mean_order_graph_(graph_ids, initial_medians)
          else:
              pass  # @todo: sample_initial_medians_(graph_ids, initial_medians)

          # Print information about current iteration.
          if self.__print_to_stdout == 2:
              print('===========================================================')

      def __compute_medoid(self, graph_ids, timer, initial_medians):
          """Appends the graph with the smallest sum of upper bounds to all graphs.

          If the timer expires, the state is set to CALLED and the best graph
          found so far (initially ``graph_ids[0]``) is used.
          """
          # Use method selected for initialization phase.
          self.__ged_env.set_method(self.__init_method, self.__init_options)

          if self.__print_to_stdout == 2:
              progress = tqdm(desc='\rComputing medoid', total=len(graph_ids), file=sys.stdout)

          # Compute the medoid.
          medoid_id = graph_ids[0]
          best_sum_of_distances = np.inf
          for g_id in graph_ids:
              if timer.expired():
                  self.__state = AlgorithmState.CALLED
                  break
              sum_of_distances = 0
              for h_id in graph_ids:
                  self.__ged_env.run_method(g_id, h_id)
                  sum_of_distances += self.__ged_env.get_upper_bound(g_id, h_id)
              if sum_of_distances < best_sum_of_distances:
                  best_sum_of_distances = sum_of_distances
                  medoid_id = g_id
              if self.__print_to_stdout == 2:
                  progress.update(1)
          if self.__print_to_stdout == 2:
              progress.close()
          initial_medians.append(self.__ged_env.get_nx_graph(medoid_id, True, True, False))  # @todo

          if self.__print_to_stdout == 2:
              print('\n')

      def __termination_criterion_met(self, converged, timer, itr, itrs_without_update):
          """Returns True if the descent from the current initial median should stop.

          A negative `max_itrs` / `max_itrs_without_update` disables that limit.
          Hitting the time limit or the iteration limit moves the state from
          TERMINATED to INITIALIZED.
          """
          if timer.expired() or (self.__max_itrs >= 0 and itr >= self.__max_itrs):
              if self.__state == AlgorithmState.TERMINATED:
                  self.__state = AlgorithmState.INITIALIZED
              return True
          return converged or (self.__max_itrs_without_update >= 0 and itrs_without_update > self.__max_itrs_without_update)

      def __update_median(self, graphs, median):
          """Updates the labels and edges of `median` in place.

          Returns
          -------
          bool
              True if the median differs from its state before the update.
          """
          if self.__print_to_stdout == 2:
              print('Updating median: ', end='')

          # Store copy of the old median.
          old_median = median.copy()  # @todo: this is just a shallow copy.

          # Update the node labels.
          if self.__labeled_nodes:
              self.__update_node_labels(graphs, median)

          # Update the edges and their labels.
          self.__update_edges(graphs, median)

          if self.__print_to_stdout == 2:
              print('done.')

          return not self.__are_graphs_equal(median, old_median)

      def __update_node_labels(self, graphs, median):
          """Replaces each median node label by the median of the labels it is substituted with."""
          if self.__print_to_stdout == 2:
              print('nodes ... ', end='')

          # Iterate through all nodes of the median (assumed to be 0 .. n-1).
          for i in range(median.number_of_nodes()):
              # Collect the labels of the substituted nodes.
              node_labels = []
              for graph_id, graph in graphs.items():
                  k = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], i)
                  if k != np.inf:
                      node_labels.append(graph.nodes[k])

              # Compute the median label and update the median if it differs by more than epsilon.
              if len(node_labels) > 0:
                  median_label = self.__ged_env.get_median_node_label(node_labels)
                  if self.__ged_env.get_node_rel_cost(median.nodes[i], median_label) > self.__epsilon:
                      nx.set_node_attributes(median, {i: median_label})

      def __update_edges(self, graphs, median):
          """Re-decides every edge (i, j) of `median` and its label in place.

          For each node pair, the labels of the edges onto which the node maps
          send (i, j) are collected. The edge is kept/inserted with the (median)
          label if that is cheaper than deleting it, and removed otherwise.
          """
          if self.__print_to_stdout == 2:
              print('edges ... ', end='')

          # The existing edges are deliberately NOT cleared up front: their labels
          # are the reference for the epsilon test below, and every pair (i, j)
          # is re-decided anyway (kept/updated or removed).
          # @todo: what if edge is not labeled?
          num_nodes = median.number_of_nodes()
          for i in range(num_nodes):
              for j in range(i + 1, num_nodes):

                  # Collect the labels of the edges to which (i,j) is mapped by the node maps.
                  edge_labels = []
                  for graph_id, graph in graphs.items():
                      k = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], i)
                      l = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], j)
                      if k != np.inf and l != np.inf:
                          if graph.has_edge(k, l):
                              edge_labels.append(graph.edges[(k, l)])

                  # Compute the median edge label and the overall edge relabeling cost.
                  rel_cost = 0
                  median_label = self.__ged_env.get_edge_label(1)
                  if median.has_edge(i, j):
                      median_label = median.edges[(i, j)]
                  if self.__labeled_edges and len(edge_labels) > 0:
                      # NOTE(review): the node counterpart calls get_median_node_label();
                      # confirm that GEDEnv exposes median_edge_label().
                      new_median_label = self.__ged_env.median_edge_label(edge_labels)
                      if self.__ged_env.get_edge_rel_cost(median_label, new_median_label) > self.__epsilon:
                          median_label = new_median_label
                      for edge_label in edge_labels:
                          rel_cost += self.__ged_env.get_edge_rel_cost(median_label, edge_label)

                  # Update the median: keep the edge only if that beats deleting/inserting it.
                  if rel_cost < (self.__edge_ins_cost + self.__edge_del_cost) * len(edge_labels) - self.__edge_del_cost * len(graphs):
                      median.add_edge(i, j, **median_label)
                  elif median.has_edge(i, j):
                      median.remove_edge(i, j)

      def __update_node_maps(self):
          """Recomputes the node maps from the median and keeps those that improve.

          Returns
          -------
          bool
              True if at least one node map was replaced.
          """
          if self.__print_to_stdout == 2:
              progress = tqdm(desc='\rUpdating node maps', total=len(self.__node_maps_from_median), file=sys.stdout)

          # Update the node maps (only values are replaced, so iterating the dict is safe).
          node_maps_were_modified = False
          for graph_id in self.__node_maps_from_median:
              self.__ged_env.run_method(self.__median_id, graph_id)
              # Replace the stored map only if the new upper bound is lower than the induced cost by more than epsilon.
              if self.__ged_env.get_upper_bound(self.__median_id, graph_id) < self.__ged_env.get_induced_cost(self.__median_id, graph_id) - self.__epsilon:  # @todo: see above.
                  self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(self.__median_id, graph_id)  # @todo: node_map may not assigned.
                  node_maps_were_modified = True
              if self.__print_to_stdout == 2:
                  progress.update(1)

          if self.__print_to_stdout == 2:
              progress.close()
              print('\n')

          return node_maps_were_modified

      def __improve_sum_of_distances(self, timer):
          """Refinement of the converged median; not implemented yet (no-op)."""
          pass

      def __median_available(self):
          """Returns True if run() has set a median ID."""
          return self.__median_id != np.inf

      def __get_node_image_from_map(self, node_map, node):
          """
          Return ID of the node mapping of `node` in `node_map`.

          Parameters
          ----------
          node_map : list[tuple(int, int)]
              List of node maps where the mapping node is found.
          node : int
              The mapping node of this node is returned

          Raises
          ------
          Exception
              If the node with ID `node` is not contained in the source nodes of the node map.

          Returns
          -------
          int
              ID of the mapping of `node`, or ``np.inf`` if it is treated as deleted.

          Notes
          -----
          This function is not implemented in the `ged::MedianGraphEstimator` class of the `GEDLIB` library. Instead it is a Python implementation of the `ged::NodeMap::image` function.
          """
          if node < len(node_map):
              # NOTE(review): any target ID >= len(node_map) is treated as the
              # dummy (deleted) node; confirm this against the node map format
              # returned by GEDEnv.get_node_map().
              return node_map[node][1] if node_map[node][1] < len(node_map) else np.inf
          raise Exception('The node with ID ' + str(node) + ' is not contained in the source nodes of the node map.')

      def __are_graphs_equal(self, g1, g2):
          """
          Check if the two graphs are equal.

          Parameters
          ----------
          g1 : NetworkX graph object
              Graph 1 to be compared.
          g2 : NetworkX graph object
              Graph 2 to be compared.

          Returns
          -------
          bool
              True if the two graph are equal.

          Notes
          -----
          This is not an identical check. Here the two graphs are equal if and only if their original_node_ids, nodes, all node labels, edges and all edge labels are equal. The comparison is order-sensitive. This function is specifically designed for class `MedianGraphEstimator` and should not be used elsewhere.
          """
          # check original node ids.
          if g1.graph['original_node_ids'] != g2.graph['original_node_ids']:
              return False
          # check nodes.
          if list(g1.nodes(data=True)) != list(g2.nodes(data=True)):
              return False
          # check edges.
          if list(g1.edges(data=True)) != list(g2.edges(data=True)):
              return False
          return True
  def compute_my_cost(g, h, node_map):
      """Computes a (placeholder) edit cost between `g` and `h`.

      Parameters
      ----------
      g : NetworkX graph object
          Source graph; its nodes are iterated.
      h : NetworkX graph object
          Target graph (not used yet).
      node_map : object
          Node map from `g` to `h` (not used yet).

      Returns
      -------
      float
          The accumulated cost. This is an unfinished stub: every node of `g`
          currently contributes 0, so the result is always 0.0.
      """
      cost = 0.0
      for node in g.nodes:
          # @todo: add the real cost contribution of `node` under `node_map`.
          cost += 0
      return cost

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.