You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

median_graph_estimator.py 31 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 16 18:04:55 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. from gklearn.ged.env import AlgorithmState
  9. from gklearn.ged.util import misc
  10. from gklearn.utils import Timer
  11. import time
  12. from tqdm import tqdm
  13. import sys
  14. import networkx as nx
  class MedianGraphEstimator(object):
      """Estimator of a generalized median graph for a collection of graphs.

      This is a Python port of the ``ged::MedianGraphEstimator`` class of the
      GEDLIB library. Starting from one or more initial medians, it runs a block
      gradient descent that alternates between updating the median (node labels,
      then edges and edge labels) and updating the node maps from the median to
      every input graph, and keeps the median with the smallest sum of distances
      (SOD) to the input graphs.

      Notes
      -----
      Porting status in this file: only the MEDOID initialization is implemented;
      the decrease-order / increase-order steps and the refinement step
      (`__improve_sum_of_distances`) are placeholders.
      """

      def __init__(self, ged_env, constant_node_costs):
          """Constructor.

          Parameters
          ----------
          ged_env : gklearn.gedlib.gedlibpy.GEDEnv
              Initialized GED environment. The edit costs must be set by the user.
          constant_node_costs : Boolean
              Set to True if the node relabeling costs are constant.

          Raises
          ------
          Exception
              If `ged_env` is None or has not been initialized.
          """
          # Validate the environment before anything below queries it; otherwise a
          # None environment fails with an unrelated AttributeError.
          if ged_env is None:
              raise Exception('The GED environment pointer passed to the constructor of MedianGraphEstimator is null.')
          elif not ged_env.is_initialized():
              raise Exception('The GED environment is uninitialized. Call gedlibpy.GEDEnv.init() before passing it to the constructor of MedianGraphEstimator.')

          self.__ged_env = ged_env

          # GED methods (and their option strings) used for initialization,
          # block gradient descent and refinement.
          self.__init_method = 'BRANCH_FAST'
          self.__init_options = ''
          self.__descent_method = 'BRANCH_FAST'
          self.__descent_options = ''
          self.__refine_method = 'IPFP'
          self.__refine_options = ''

          # Edit-cost information queried once from the environment.
          self.__constant_node_costs = constant_node_costs
          self.__labeled_nodes = (ged_env.get_num_node_labels() > 1)
          self.__node_del_cost = ged_env.get_node_del_cost(ged_env.get_node_label(1))
          self.__node_ins_cost = ged_env.get_node_ins_cost(ged_env.get_node_label(1))
          self.__labeled_edges = (ged_env.get_num_edge_labels() > 1)
          self.__edge_del_cost = ged_env.get_edge_del_cost(ged_env.get_edge_label(1))
          self.__edge_ins_cost = ged_env.get_edge_ins_cost(ged_env.get_edge_label(1))

          # Options configurable through set_options().
          self.__set_default_options()

          # Results and statistics of the last call to run().
          self.__median_id = np.inf  # np.inf means "no median computed yet" (see __median_available()).
          self.__median_node_id_prefix = ''  # @todo: check
          self.__node_maps_from_median = {}
          self.__sum_of_distances = 0
          self.__best_init_sum_of_distances = np.inf
          self.__converged_sum_of_distances = np.inf
          self.__runtime = None
          self.__runtime_initialized = None
          self.__runtime_converged = None
          self.__itrs = []  # @todo: check: {} ?
          self.__num_decrease_order = 0
          self.__num_increase_order = 0
          self.__num_converged_descents = 0
          self.__state = AlgorithmState.TERMINATED

      def set_options(self, options):
          """Sets the options of the estimator.

          Parameters
          ----------
          options : string
              String that specifies with which options to run the estimator.

          Raises
          ------
          Exception
              If an option name is unknown or an option value is invalid.

          Notes
          -----
          All options are first reset to their defaults, so options that do not
          appear in `options` take their default values.
          """
          self.__set_default_options()
          options_map = misc.options_string_to_options_map(options)
          for opt_name, opt_val in options_map.items():
              if opt_name == 'init-type':
                  self.__init_type = opt_val
                  if opt_val not in ('MEDOID', 'RANDOM', 'MIN', 'MAX', 'MEAN'):
                      raise Exception('Invalid argument ' + opt_val + ' for option init-type. Usage: options = "[--init-type RANDOM|MEDOID|EMPTY|MIN|MAX|MEAN] [...]"')
              elif opt_name == 'random-inits':
                  try:
                      self.__num_random_inits = int(opt_val)
                      self.__desired_num_random_inits = self.__num_random_inits
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option random-inits. Usage: options = "[--random-inits <convertible to int greater 0>]"') from None
                  if self.__num_random_inits <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option random-inits. Usage: options = "[--random-inits <convertible to int greater 0>]"')
              elif opt_name == 'randomness':
                  if opt_val == 'PSEUDO':
                      self.__use_real_randomness = False
                  elif opt_val == 'REAL':
                      self.__use_real_randomness = True
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option randomness. Usage: options = "[--randomness REAL|PSEUDO] [...]"')
              elif opt_name == 'stdout':
                  if opt_val == '0':
                      self.__print_to_stdout = 0
                  elif opt_val == '1':
                      self.__print_to_stdout = 1
                  elif opt_val == '2':
                      self.__print_to_stdout = 2
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option stdout. Usage: options = "[--stdout 0|1|2] [...]"')
              elif opt_name == 'refine':
                  if opt_val == 'TRUE':
                      self.__refine = True
                  elif opt_val == 'FALSE':
                      self.__refine = False
                  else:
                      raise Exception('Invalid argument "' + opt_val + '" for option refine. Usage: options = "[--refine TRUE|FALSE] [...]"')
              elif opt_name == 'time-limit':
                  try:
                      self.__time_limit_in_sec = float(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option time-limit. Usage: options = "[--time-limit <convertible to double>] [...]') from None
              elif opt_name == 'max-itrs':
                  try:
                      self.__max_itrs = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs. Usage: options = "[--max-itrs <convertible to int>] [...]') from None
              elif opt_name == 'max-itrs-without-update':
                  try:
                      self.__max_itrs_without_update = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs-without-update. Usage: options = "[--max-itrs-without-update <convertible to int>] [...]') from None
              elif opt_name == 'seed':
                  try:
                      self.__seed = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option seed. Usage: options = "[--seed <convertible to int greater equal 0>] [...]') from None
              elif opt_name == 'epsilon':
                  try:
                      self.__epsilon = float(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option epsilon. Usage: options = "[--epsilon <convertible to double greater 0>] [...]') from None
                  if self.__epsilon <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option epsilon. Usage: options = "[--epsilon <convertible to double greater 0>] [...]')
              elif opt_name == 'inits-increase-order':
                  try:
                      self.__num_inits_increase_order = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option inits-increase-order. Usage: options = "[--inits-increase-order <convertible to int greater 0>]"') from None
                  if self.__num_inits_increase_order <= 0:
                      raise Exception('Invalid argument "' + opt_val + '" for option inits-increase-order. Usage: options = "[--inits-increase-order <convertible to int greater 0>]"')
              elif opt_name == 'init-type-increase-order':
                  self.__init_type_increase_order = opt_val
                  if opt_val not in ('CLUSTERS', 'K-MEANS++'):
                      raise Exception('Invalid argument ' + opt_val + ' for option init-type-increase-order. Usage: options = "[--init-type-increase-order CLUSTERS|K-MEANS++] [...]"')
              elif opt_name == 'max-itrs-increase-order':
                  try:
                      self.__max_itrs_increase_order = int(opt_val)
                  except (TypeError, ValueError):
                      raise Exception('Invalid argument "' + opt_val + '" for option max-itrs-increase-order. Usage: options = "[--max-itrs-increase-order <convertible to int>] [...]') from None
              else:
                  # List every option accepted above (--refine and
                  # --max-itrs-without-update were previously missing).
                  valid_options = '[--init-type <arg>] [--random-inits <arg>] [--randomness <arg>] [--seed <arg>] [--stdout <arg>] [--refine <arg>] '
                  valid_options += '[--time-limit <arg>] [--max-itrs <arg>] [--max-itrs-without-update <arg>] [--epsilon <arg>] '
                  valid_options += '[--inits-increase-order <arg>] [--init-type-increase-order <arg>] [--max-itrs-increase-order <arg>]'
                  raise Exception('Invalid option "' + opt_name + '". Usage: options = "' + valid_options + '"')

      def set_init_method(self, init_method, init_options=''):
          """Selects method to be used for computing the initial medoid graph.

          Parameters
          ----------
          init_method : string
              The selected method. Default: 'BRANCH_FAST'.
          init_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          Has no effect unless "--init-type MEDOID" is passed to set_options().
          """
          self.__init_method = init_method
          self.__init_options = init_options

      def set_descent_method(self, descent_method, descent_options=''):
          """Selects method to be used for block gradient descent.

          Parameters
          ----------
          descent_method : string
              The selected method. Default: 'BRANCH_FAST'.
          descent_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          run() selects this method before computing the node maps of every
          initial median and in every descent iteration.
          """
          self.__descent_method = descent_method
          self.__descent_options = descent_options

      def set_refine_method(self, refine_method, refine_options=''):
          """Selects method to be used for improving the sum of distances and the node maps for the converged median.

          Parameters
          ----------
          refine_method : string
              The selected method. Default: "IPFP".
          refine_options : string
              The options for the selected method. Default: "".

          Notes
          -----
          Has no effect if "--refine FALSE" is passed to set_options().
          """
          self.__refine_method = refine_method
          self.__refine_options = refine_options

      def run(self, graph_ids, set_median_id, gen_median_id):
          """Computes a generalized median graph.

          Parameters
          ----------
          graph_ids : list[integer]
              The IDs of the graphs for which the median should be computed. Must have been added to the environment passed to the constructor.
          set_median_id : integer
              The ID of the computed set-median. A dummy graph with this ID must have been added to the environment passed to the constructor. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
          gen_median_id : integer
              The ID of the computed generalized median. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().

          Raises
          ------
          Exception
              If `graph_ids` is empty or all graphs it refers to are empty.
          NotImplementedError
              If the selected init-type is not implemented yet (only MEDOID is).
          """
          # Sanity checks.
          if len(graph_ids) == 0:
              raise Exception('Empty vector of graph IDs, unable to compute median.')
          all_graphs_empty = True
          for graph_id in graph_ids:
              if self.__ged_env.get_graph_num_nodes(graph_id) > 0:
                  self.__median_node_id_prefix = self.__ged_env.get_original_node_ids(graph_id)[0]
                  all_graphs_empty = False
                  break
          if all_graphs_empty:
              raise Exception('All graphs in the collection are empty.')

          # Start timer and record start time.
          start = time.time()
          timer = Timer(self.__time_limit_in_sec)
          self.__median_id = gen_median_id
          self.__state = AlgorithmState.TERMINATED

          # Get NetworkX representations of the input graphs.
          # @todo: get_nx_graph() function may need to be modified according to the coming code.
          graphs = {}
          for graph_id in graph_ids:
              graphs[graph_id] = self.__ged_env.get_nx_graph(graph_id, True, True, False)

          # Construct initial medians (raises for init types that are not ported yet,
          # so `medians` is non-empty below).
          medians = []
          self.__construct_initial_medians(graph_ids, timer, medians)
          end_init = time.time()
          self.__runtime_initialized = end_init - start

          # Reset information about iterations and number of times the median decreases and increases.
          self.__itrs = [0] * len(medians)
          self.__num_decrease_order = 0
          self.__num_increase_order = 0
          self.__num_converged_descents = 0

          # Initialize the best median.
          best_sum_of_distances = np.inf
          self.__best_init_sum_of_distances = np.inf
          node_maps_from_best_median = {}
          best_median = None

          # Run block gradient descent from all initial medians.
          self.__ged_env.set_method(self.__descent_method, self.__descent_options)
          for median_pos in range(0, len(medians)):

              # Terminate if the timer has expired and at least one SOD has been computed.
              if timer.expired() and median_pos > 0:
                  break

              # Print information about current iteration.
              if self.__print_to_stdout == 2:
                  print('\n===========================================================')
                  print('Block gradient descent for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
                  print('-----------------------------------------------------------')

              # Get reference to the median (it is updated in place below).
              median = medians[median_pos]

              # Load initial median into the environment.
              self.__ged_env.load_nx_graph(median, gen_median_id)
              self.__ged_env.init(self.__ged_env.get_init_type())

              if self.__print_to_stdout == 2:
                  progress = tqdm(desc='\rComputing initial node maps', total=len(graph_ids), file=sys.stdout)

              # Compute node maps and sum of distances for initial median.
              self.__sum_of_distances = 0
              self.__node_maps_from_median.clear()  # @todo
              for graph_id in graph_ids:
                  self.__ged_env.run_method(gen_median_id, graph_id)
                  self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(gen_median_id, graph_id)
                  # @todo: the C++ implementation for this function in GedLibBind.ipp re-call get_node_map() once more, this is not neccessary.
                  self.__sum_of_distances += self.__ged_env.get_induced_cost(gen_median_id, graph_id)
                  if self.__print_to_stdout == 2:
                      progress.update(1)
              self.__best_init_sum_of_distances = min(self.__best_init_sum_of_distances, self.__sum_of_distances)
              self.__ged_env.load_nx_graph(median, set_median_id)
              if self.__print_to_stdout == 2:
                  print('\n')

              # Run block gradient descent from initial median.
              converged = False
              itrs_without_update = 0
              while not self.__termination_criterion_met(converged, timer, self.__itrs[median_pos], itrs_without_update):

                  # Print information about current iteration.
                  if self.__print_to_stdout == 2:
                      print('\n===========================================================')
                      print('Iteration', str(self.__itrs[median_pos] + 1), 'for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
                      print('-----------------------------------------------------------')

                  # Update the median.
                  median_modified = self.__update_median(graphs, median)

                  # @todo: port GEDLIB's decrease_order_() / increase_order_(). Until
                  # then the order (number of nodes) of the median never changes.
                  decreased_order = False
                  increased_order = False

                  # Update the number of iterations without update of the median.
                  if median_modified or decreased_order or increased_order:
                      itrs_without_update = 0
                  else:
                      itrs_without_update += 1

                  # Load the median into the environment.
                  # @todo: should this function use the original node label?
                  if self.__print_to_stdout == 2:
                      print('Loading median to environment: ... ', end='')
                  self.__ged_env.load_nx_graph(median, gen_median_id)
                  self.__ged_env.init(self.__ged_env.get_init_type())
                  if self.__print_to_stdout == 2:
                      print('done.')

                  # Compute induced costs of the old node maps w.r.t. the updated median.
                  # @todo: watch out if compute_induced_cost is correct, this may influence: increase/decrease order, induced_cost() in the following code.
                  if self.__print_to_stdout == 2:
                      print('Updating induced costs: ... ', end='')
                  for graph_id in graph_ids:
                      self.__ged_env.compute_induced_cost(gen_median_id, graph_id)
                  if self.__print_to_stdout == 2:
                      print('done.')

                  # Update the node maps.
                  node_maps_modified = self.__update_node_maps()  # @todo

                  # Update the sum of distances.
                  old_sum_of_distances = self.__sum_of_distances
                  self.__sum_of_distances = 0
                  for graph_id in self.__node_maps_from_median:
                      self.__sum_of_distances += self.__ged_env.get_induced_cost(gen_median_id, graph_id)  # @todo: see above.

                  # Print information about current iteration.
                  if self.__print_to_stdout == 2:
                      print('Old local SOD: ', old_sum_of_distances)
                      print('New local SOD: ', self.__sum_of_distances)
                      print('Best converged SOD: ', best_sum_of_distances)
                      print('Modified median: ', median_modified)
                      print('Modified node maps: ', node_maps_modified)
                      print('Decreased order: ', decreased_order)
                      print('Increased order: ', increased_order)
                      print('===========================================================\n')

                  converged = not (median_modified or node_maps_modified or decreased_order or increased_order)

                  # Increment iteration counter.
                  self.__itrs[median_pos] += 1

              # Update the best median once the descent from this initial median has
              # stopped, comparing against the best SOD found so far (as in GEDLIB).
              if best_median is None or self.__sum_of_distances < best_sum_of_distances:
                  best_sum_of_distances = self.__sum_of_distances
                  # Copy: self.__node_maps_from_median is cleared and refilled for the
                  # next initial median, which would otherwise wipe the stored maps.
                  node_maps_from_best_median = self.__node_maps_from_median.copy()
                  best_median = median

              # Update the number of converged descents.
              if converged:
                  self.__num_converged_descents += 1

          # Store the best encountered median.
          self.__sum_of_distances = best_sum_of_distances
          self.__node_maps_from_median = node_maps_from_best_median
          self.__ged_env.load_nx_graph(best_median, gen_median_id)
          self.__ged_env.init(self.__ged_env.get_init_type())
          end_descent = time.time()
          self.__runtime_converged = end_descent - start

          # Refine the sum of distances and the node maps for the converged median.
          self.__converged_sum_of_distances = self.__sum_of_distances
          if self.__refine:
              self.__improve_sum_of_distances(timer)  # @todo

          # Record end time, set runtime and reset the number of initial medians.
          end = time.time()
          self.__runtime = end - start
          self.__num_random_inits = self.__desired_num_random_inits

          # Print global information.
          if self.__print_to_stdout != 0:
              print('\n===========================================================')
              print('Finished computation of generalized median graph.')
              print('-----------------------------------------------------------')
              print('Best SOD after initialization: ', self.__best_init_sum_of_distances)
              print('Converged SOD: ', self.__converged_sum_of_distances)
              if self.__refine:
                  print('Refined SOD: ', self.__sum_of_distances)
              print('Overall runtime: ', self.__runtime)
              print('Runtime of initialization: ', self.__runtime_initialized)
              print('Runtime of block gradient descent: ', self.__runtime_converged - self.__runtime_initialized)
              if self.__refine:
                  print('Runtime of refinement: ', self.__runtime - self.__runtime_converged)
              print('Number of initial medians: ', len(medians))
              total_itr = sum(self.__itrs)
              num_started_descents = sum(1 for itr in self.__itrs if itr > 0)
              print('Size of graph collection: ', len(graph_ids))
              print('Number of started descents: ', num_started_descents)
              print('Number of converged descents: ', self.__num_converged_descents)
              print('Overall number of iterations: ', total_itr)
              print('Overall number of times the order decreased: ', self.__num_decrease_order)
              print('Overall number of times the order increased: ', self.__num_increase_order)
              print('===========================================================\n')

      def get_sum_of_distances(self, state=''):
          """Returns the sum of distances.

          Parameters
          ----------
          state : string
              The state of the estimator. Can be 'initialized' or 'converged'. Default: ""

          Returns
          -------
          float
              The sum of distances (SOD) of the median when the estimator was in the state `state` during the last call to run(). If `state` is not given, the converged SOD (without refinement) or refined SOD (with refinement) is returned.

          Raises
          ------
          Exception
              If run() has not computed a median yet.
          """
          if not self.__median_available():
              raise Exception('No median has been computed. Call run() before calling get_sum_of_distances().')
          if state == 'initialized':
              return self.__best_init_sum_of_distances
          if state == 'converged':
              return self.__converged_sum_of_distances
          return self.__sum_of_distances

      def __set_default_options(self):
          # Resets every option that set_options() can change to its default value.
          self.__init_type = 'RANDOM'
          self.__num_random_inits = 10
          self.__desired_num_random_inits = 10
          self.__use_real_randomness = True
          self.__seed = 0
          self.__refine = True
          self.__time_limit_in_sec = 0
          self.__epsilon = 0.0001
          self.__max_itrs = 100
          self.__max_itrs_without_update = 3
          self.__num_inits_increase_order = 10
          self.__init_type_increase_order = 'K-MEANS++'
          self.__max_itrs_increase_order = 10
          self.__print_to_stdout = 2

      def __construct_initial_medians(self, graph_ids, timer, initial_medians):
          """Fills `initial_medians` (cleared first) according to the init-type option.

          Raises
          ------
          NotImplementedError
              For every init-type except MEDOID, whose construction is not ported yet.
              Previously these types silently produced no median and run() crashed later.
          """
          # Print information about current iteration.
          if self.__print_to_stdout == 2:
              print('\n===========================================================')
              print('Constructing initial median(s).')
              print('-----------------------------------------------------------')

          # Compute or sample the initial median(s).
          initial_medians.clear()
          if self.__init_type == 'MEDOID':
              self.__compute_medoid(graph_ids, timer, initial_medians)
          else:
              # @todo: port GEDLIB's compute_max_order_graph_ (MAX),
              # compute_min_order_graph_ (MIN), compute_mean_order_graph_ (MEAN) and
              # sample_initial_medians_ (RANDOM).
              raise NotImplementedError('Initialization type "' + str(self.__init_type) + '" is not implemented yet. Use "--init-type MEDOID".')

          # Print information about current iteration.
          if self.__print_to_stdout == 2:
              print('===========================================================')

      def __compute_medoid(self, graph_ids, timer, initial_medians):
          """Appends the medoid of `graph_ids` to `initial_medians`.

          The medoid is the input graph with the smallest sum of GED upper bounds to
          all input graphs, computed with the initialization method. If the timer
          expires, the best graph found so far is used and the state becomes CALLED.
          """
          # Use method selected for initialization phase.
          self.__ged_env.set_method(self.__init_method, self.__init_options)

          if self.__print_to_stdout == 2:
              progress = tqdm(desc='\rComputing medoid', total=len(graph_ids), file=sys.stdout)

          # Compute the medoid.
          medoid_id = graph_ids[0]
          best_sum_of_distances = np.inf
          for g_id in graph_ids:
              if timer.expired():
                  self.__state = AlgorithmState.CALLED
                  break
              sum_of_distances = 0
              for h_id in graph_ids:
                  self.__ged_env.run_method(g_id, h_id)
                  sum_of_distances += self.__ged_env.get_upper_bound(g_id, h_id)
              if sum_of_distances < best_sum_of_distances:
                  best_sum_of_distances = sum_of_distances
                  medoid_id = g_id
              if self.__print_to_stdout == 2:
                  progress.update(1)
          initial_medians.append(self.__ged_env.get_nx_graph(medoid_id, True, True, False))  # @todo

          if self.__print_to_stdout == 2:
              print('\n')

      def __termination_criterion_met(self, converged, timer, itr, itrs_without_update):
          """Returns True if the current descent should stop.

          It stops when the timer expired, `itr` reached max-itrs (if non-negative),
          the descent converged, or `itrs_without_update` exceeded
          max-itrs-without-update (if non-negative). Hitting the first two limits
          while TERMINATED moves the state to INITIALIZED.
          """
          if timer.expired() or (itr >= self.__max_itrs if self.__max_itrs >= 0 else False):
              if self.__state == AlgorithmState.TERMINATED:
                  self.__state = AlgorithmState.INITIALIZED
              return True
          return converged or (itrs_without_update > self.__max_itrs_without_update if self.__max_itrs_without_update >= 0 else False)

      def __update_median(self, graphs, median):
          """Updates node labels (if nodes are labeled) and edges of `median` in place.

          Returns
          -------
          bool
              True if the median differs from its state before the update.
          """
          if self.__print_to_stdout == 2:
              print('Updating median: ', end='')

          # Store copy of the old median. NetworkX's copy() creates new attribute
          # dicts but shares their values; that suffices because the updates below
          # replace attribute values instead of mutating them.
          old_median = median.copy()

          # Update the node labels.
          if self.__labeled_nodes:
              self.__update_node_labels(graphs, median)

          # Update the edges and their labels.
          self.__update_edges(graphs, median)

          if self.__print_to_stdout == 2:
              print('done.')

          return not self.__are_graphs_equal(median, old_median)

      def __update_node_labels(self, graphs, median):
          """Replaces each median node label by the median of the labels it is mapped to.

          The label is only replaced when its relabeling cost to the new median label
          exceeds epsilon. Median nodes are assumed to be 0, ..., n-1.
          """
          if self.__print_to_stdout == 2:
              print('nodes ... ', end='')

          # Iterate through all nodes of the median.
          for i in range(0, nx.number_of_nodes(median)):
              # Collect the labels of the substituted nodes.
              node_labels = []
              for graph_id, graph in graphs.items():
                  k = self.__get_node_image_from_map(self.__node_maps_from_median[graph_id], i)
                  if k != np.inf:
                      node_labels.append(graph.nodes[k])

              # Compute the median label and update the median.
              if len(node_labels) > 0:
                  median_label = self.__ged_env.get_median_node_label(node_labels)
                  if self.__ged_env.get_node_rel_cost(median.nodes[i], median_label) > self.__epsilon:
                      nx.set_node_attributes(median, {i: median_label})

      def __update_edges(self, graphs, median):
          """Recomputes the edges and edge labels of `median` in place from the node maps.

          For every node pair (i, j) with i < j, the labels of the edges that (i, j) is
          mapped to in the input graphs are collected. The edge is inserted iff its
          relabeling cost is smaller than
          ``(edge_ins_cost + edge_del_cost) * #mapped_edges - edge_del_cost * #graphs``.
          Median nodes are assumed to be 0, ..., n-1.
          """
          if self.__print_to_stdout == 2:
              print('edges ... ', end='')

          # Remember the current edge labels before clearing the edges. GEDLIB only
          # clears the adjacency lists and keeps the old labels, so an existing edge
          # keeps its label unless the new median label differs by more than epsilon.
          old_edge_labels = {}
          for head, tail, label in median.edges(data=True):
              old_edge_labels[(head, tail)] = label
              old_edge_labels[(tail, head)] = label
          for head, tail in list(median.edges):
              median.remove_edge(head, tail)

          # @todo: what if edge is not labeled?
          default_label = self.__ged_env.get_edge_label(1)
          num_nodes = nx.number_of_nodes(median)

          # Iterate through all possible edges (i,j) of the median.
          for i in range(0, num_nodes):
              for j in range(i + 1, num_nodes):

                  # Collect the labels of the edges to which (i,j) is mapped by the node maps.
                  edge_labels = []
                  for graph_id, graph in graphs.items():
                      node_map = self.__node_maps_from_median[graph_id]
                      image_i = self.__get_node_image_from_map(node_map, i)
                      image_j = self.__get_node_image_from_map(node_map, j)
                      if image_i != np.inf and image_j != np.inf and graph.has_edge(image_i, image_j):
                          edge_labels.append(graph.edges[(image_i, image_j)])

                  # Compute the median edge label and the overall edge relabeling cost.
                  rel_cost = 0
                  median_label = old_edge_labels.get((i, j), default_label)
                  if self.__labeled_edges and len(edge_labels) > 0:
                      new_median_label = self.__ged_env.get_median_edge_label(edge_labels)
                      if self.__ged_env.get_edge_rel_cost(median_label, new_median_label) > self.__epsilon:
                          median_label = new_median_label
                      for edge_label in edge_labels:
                          rel_cost += self.__ged_env.get_edge_rel_cost(median_label, edge_label)

                  # Update the median. All edges were removed above, so (i, j) only
                  # ever needs to be added here, never removed.
                  if rel_cost < (self.__edge_ins_cost + self.__edge_del_cost) * len(edge_labels) - self.__edge_del_cost * len(graphs):
                      median.add_edge(i, j, **median_label)

      def __update_node_maps(self):
          """Reruns the descent method from the median to every graph and keeps better maps.

          A node map is replaced only when the new upper bound is smaller than the
          induced cost of the current map by more than epsilon.

          Returns
          -------
          bool
              True if at least one node map was replaced.
          """
          if self.__print_to_stdout == 2:
              progress = tqdm(desc='\rUpdating node maps', total=len(self.__node_maps_from_median), file=sys.stdout)

          # Update the node maps.
          node_maps_were_modified = False
          for graph_id in self.__node_maps_from_median:
              self.__ged_env.run_method(self.__median_id, graph_id)
              if self.__ged_env.get_upper_bound(self.__median_id, graph_id) < self.__ged_env.get_induced_cost(self.__median_id, graph_id) - self.__epsilon:  # @todo: see above.
                  self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(self.__median_id, graph_id)  # @todo: node_map may not assigned.
                  node_maps_were_modified = True
              if self.__print_to_stdout == 2:
                  progress.update(1)

          if self.__print_to_stdout == 2:
              print('\n')

          return node_maps_were_modified

      def __improve_sum_of_distances(self, timer):
          # @todo: port GEDLIB's improve_sum_of_distances_(); currently a no-op, so the
          # refined SOD equals the converged SOD.
          pass

      def __median_available(self):
          # run() sets __median_id to gen_median_id; np.inf means no median yet.
          return self.__median_id != np.inf

      def __get_node_image_from_map(self, node_map, node):
          """
          Return ID of the node mapping of `node` in `node_map`.

          Parameters
          ----------
          node_map : list[tuple(int, int)]
              List of node maps where the mapping node is found.
          node : int
              The mapping node of this node is returned

          Raises
          ------
          Exception
              If the node with ID `node` is not contained in the source nodes of the node map.

          Returns
          -------
          int
              ID of the mapping of `node`, or np.inf if the stored target is not smaller
              than len(node_map) (treated as a deletion / dummy node).

          Notes
          -----
          This function is not implemented in the `ged::MedianGraphEstimator` class of the `GEDLIB` library. Instead it is a Python implementation of the `ged::NodeMap::image` function.
          """
          if node < len(node_map):
              target = node_map[node][1]
              return target if target < len(node_map) else np.inf
          raise Exception('The node with ID ' + str(node) + ' is not contained in the source nodes of the node map.')

      def __are_graphs_equal(self, g1, g2):
          """
          Check if the two graphs are equal.

          Parameters
          ----------
          g1 : NetworkX graph object
              Graph 1 to be compared.
          g2 : NetworkX graph object
              Graph 2 to be compared.

          Returns
          -------
          bool
              True if the two graph are equal.

          Notes
          -----
          This is not an identical check. Here the two graphs are equal if and only if their original_node_ids, nodes, all node labels, edges and all edge labels are equal, regardless of insertion order. This function is specifically designed for class `MedianGraphEstimator` and should not be used elsewhere.
          """
          # check original node ids.
          if g1.graph['original_node_ids'] != g2.graph['original_node_ids']:
              return False
          # Compare nodes and edges as mappings so the result does not depend on
          # insertion order: __update_edges() re-inserts every edge. Both graphs share
          # their node order here (one is a copy of the other), so each edge is reported
          # with the same (u, v) orientation in both.
          if dict(g1.nodes(data=True)) != dict(g2.nodes(data=True)):
              return False
          edges1 = {(u, v): label for u, v, label in g1.edges(data=True)}
          edges2 = {(u, v): label for u, v, label in g2.edges(data=True)}
          return edges1 == edges2
  def compute_my_cost(g, h, node_map):
      """Placeholder for a Python-side edit-cost computation.

      Parameters
      ----------
      g : graph object
          Source graph; only ``g.nodes`` is iterated (presumably a NetworkX graph).
      h : graph object
          Target graph. Currently unused.
      node_map : object
          Node map from `g` to `h`. Currently unused.

      Returns
      -------
      float
          The accumulated cost. Every node currently contributes 0, so the result
          is always 0.0. Previously the accumulated value was discarded and the
          function returned None.
      """
      cost = 0.0
      for node in g.nodes:
          cost += 0  # @todo: add the actual per-node (and per-edge) edit costs.
      return cost

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.