You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

median_graph_estimator_cml.py 63 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 16 18:04:55 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. import time
  9. from tqdm import tqdm
  10. import sys
  11. import networkx as nx
  12. import multiprocessing
  13. from multiprocessing import Pool
  14. from functools import partial
  15. from gklearn.ged.env import AlgorithmState, NodeMap
  16. from gklearn.ged.util import misc
  17. from gklearn.utils import Timer, SpecialLabel
  18. class MedianGraphEstimatorCML(object): # @todo: differ dummy_node from undifined node?
  19. """Estimate median graphs using the pure Python version of GEDEnv.
  20. """
  21. def __init__(self, ged_env, constant_node_costs):
  22. """Constructor.
  23. Parameters
  24. ----------
  25. ged_env : gklearn.gedlib.gedlibpy.GEDEnv
  26. Initialized GED environment. The edit costs must be set by the user.
  27. constant_node_costs : Boolean
  28. Set to True if the node relabeling costs are constant.
  29. """
  30. self._ged_env = ged_env
  31. self._init_method = 'BRANCH_FAST'
  32. self._init_options = ''
  33. self._descent_method = 'BRANCH_FAST'
  34. self._descent_options = ''
  35. self._refine_method = 'IPFP'
  36. self._refine_options = ''
  37. self._constant_node_costs = constant_node_costs
  38. self._labeled_nodes = (ged_env.get_num_node_labels() > 1)
  39. self._node_del_cost = ged_env.get_node_del_cost(ged_env.get_node_label(1, to_dict=False))
  40. self._node_ins_cost = ged_env.get_node_ins_cost(ged_env.get_node_label(1, to_dict=False))
  41. self._labeled_edges = (ged_env.get_num_edge_labels() > 1)
  42. self._edge_del_cost = ged_env.get_edge_del_cost(ged_env.get_edge_label(1, to_dict=False))
  43. self._edge_ins_cost = ged_env.get_edge_ins_cost(ged_env.get_edge_label(1, to_dict=False))
  44. self._init_type = 'RANDOM'
  45. self._num_random_inits = 10
  46. self._desired_num_random_inits = 10
  47. self._use_real_randomness = True
  48. self._seed = 0
  49. self._parallel = True
  50. self._update_order = True
  51. self._sort_graphs = True # sort graphs by size when computing GEDs.
  52. self._refine = True
  53. self._time_limit_in_sec = 0
  54. self._epsilon = 0.0001
  55. self._max_itrs = 100
  56. self._max_itrs_without_update = 3
  57. self._num_inits_increase_order = 10
  58. self._init_type_increase_order = 'K-MEANS++'
  59. self._max_itrs_increase_order = 10
  60. self._print_to_stdout = 2
  61. self._median_id = np.inf # @todo: check
  62. self._node_maps_from_median = {}
  63. self._sum_of_distances = 0
  64. self._best_init_sum_of_distances = np.inf
  65. self._converged_sum_of_distances = np.inf
  66. self._runtime = None
  67. self._runtime_initialized = None
  68. self._runtime_converged = None
  69. self._itrs = [] # @todo: check: {} ?
  70. self._num_decrease_order = 0
  71. self._num_increase_order = 0
  72. self._num_converged_descents = 0
  73. self._state = AlgorithmState.TERMINATED
  74. self._label_names = {}
  75. if ged_env is None:
  76. raise Exception('The GED environment pointer passed to the constructor of MedianGraphEstimator is null.')
  77. elif not ged_env.is_initialized():
  78. raise Exception('The GED environment is uninitialized. Call gedlibpy.GEDEnv.init() before passing it to the constructor of MedianGraphEstimator.')
  def set_options(self, options):
      """Sets the options of the estimator.

      All options are first reset to their defaults with
      _set_default_options(); the options given in `options` are then
      applied on top.

      Parameters
      ----------
      options : string
          String that specifies with which options to run the estimator,
          e.g. '--init-type MEDOID --refine FALSE'. It is parsed by
          misc.options_string_to_options_map().

      Raises
      ------
      Exception
          If an option name is unknown or an option value is invalid.
      """
      self._set_default_options()
      options_map = misc.options_string_to_options_map(options)

      def invalid_argument(opt_name, opt_val, usage):
          # Uniform error for an invalid value of option `opt_name`.
          return Exception('Invalid argument "' + opt_val + '" for option ' + opt_name + '. Usage: options = "' + usage + '"')

      def to_number(cast, opt_name, opt_val, usage):
          # Convert `opt_val` with `cast` (int or float); only conversion errors are caught.
          try:
              return cast(opt_val)
          except (TypeError, ValueError) as err:
              raise invalid_argument(opt_name, opt_val, usage) from err

      def to_bool(opt_name, opt_val):
          # Boolean options accept exactly the strings 'TRUE' and 'FALSE'.
          if opt_val == 'TRUE':
              return True
          if opt_val == 'FALSE':
              return False
          raise invalid_argument(opt_name, opt_val, '[--' + opt_name + ' TRUE|FALSE] [...]')

      for opt_name, opt_val in options_map.items():
          if opt_name == 'init-type':
              # EMPTY is rejected by this check, so it is not listed in the usage message.
              if opt_val not in ('MEDOID', 'RANDOM', 'MIN', 'MAX', 'MEAN'):
                  raise Exception('Invalid argument ' + opt_val + ' for option init-type. Usage: options = "[--init-type RANDOM|MEDOID|MIN|MAX|MEAN] [...]"')
              self._init_type = opt_val
          elif opt_name == 'random-inits':
              usage = '[--random-inits <convertible to int greater 0>]'
              num_random_inits = to_number(int, opt_name, opt_val, usage)
              if num_random_inits <= 0:
                  raise invalid_argument(opt_name, opt_val, usage)
              self._num_random_inits = num_random_inits
              self._desired_num_random_inits = num_random_inits
          elif opt_name == 'randomness':
              if opt_val == 'PSEUDO':
                  self._use_real_randomness = False
              elif opt_val == 'REAL':
                  self._use_real_randomness = True
              else:
                  raise invalid_argument(opt_name, opt_val, '[--randomness REAL|PSEUDO] [...]')
          elif opt_name == 'stdout':
              if opt_val not in ('0', '1', '2'):
                  raise invalid_argument(opt_name, opt_val, '[--stdout 0|1|2] [...]')
              self._print_to_stdout = int(opt_val)
          elif opt_name == 'parallel':
              self._parallel = to_bool(opt_name, opt_val)
          elif opt_name == 'update-order':
              self._update_order = to_bool(opt_name, opt_val)
          elif opt_name == 'sort-graphs':
              self._sort_graphs = to_bool(opt_name, opt_val)
          elif opt_name == 'refine':
              self._refine = to_bool(opt_name, opt_val)
          elif opt_name == 'time-limit':
              self._time_limit_in_sec = to_number(float, opt_name, opt_val, '[--time-limit <convertible to double>] [...]')
          elif opt_name == 'max-itrs':
              self._max_itrs = to_number(int, opt_name, opt_val, '[--max-itrs <convertible to int>] [...]')
          elif opt_name == 'max-itrs-without-update':
              self._max_itrs_without_update = to_number(int, opt_name, opt_val, '[--max-itrs-without-update <convertible to int>] [...]')
          elif opt_name == 'seed':
              self._seed = to_number(int, opt_name, opt_val, '[--seed <convertible to int greater equal 0>] [...]')
          elif opt_name == 'epsilon':
              usage = '[--epsilon <convertible to double greater 0>] [...]'
              epsilon = to_number(float, opt_name, opt_val, usage)
              # `not epsilon > 0` also rejects NaN, which `epsilon <= 0` would let through.
              if not epsilon > 0:
                  raise invalid_argument(opt_name, opt_val, usage)
              self._epsilon = epsilon
          elif opt_name == 'inits-increase-order':
              usage = '[--inits-increase-order <convertible to int greater 0>]'
              num_inits_increase_order = to_number(int, opt_name, opt_val, usage)
              if num_inits_increase_order <= 0:
                  raise invalid_argument(opt_name, opt_val, usage)
              self._num_inits_increase_order = num_inits_increase_order
          elif opt_name == 'init-type-increase-order':
              if opt_val not in ('CLUSTERS', 'K-MEANS++'):
                  raise Exception('Invalid argument ' + opt_val + ' for option init-type-increase-order. Usage: options = "[--init-type-increase-order CLUSTERS|K-MEANS++] [...]"')
              self._init_type_increase_order = opt_val
          elif opt_name == 'max-itrs-increase-order':
              self._max_itrs_increase_order = to_number(int, opt_name, opt_val, '[--max-itrs-increase-order <convertible to int>] [...]')
          else:
              # List every option handled above so the usage message is complete.
              valid_options = '[--init-type <arg>] [--random-inits <arg>] [--randomness <arg>] [--seed <arg>] [--stdout <arg>] '
              valid_options += '[--parallel <arg>] [--update-order <arg>] [--sort-graphs <arg>] [--refine <arg>] '
              valid_options += '[--time-limit <arg>] [--max-itrs <arg>] [--max-itrs-without-update <arg>] [--epsilon <arg>] '
              valid_options += '[--inits-increase-order <arg>] [--init-type-increase-order <arg>] [--max-itrs-increase-order <arg>]'
              raise Exception('Invalid option "' + opt_name + '". Usage: options = "' + valid_options + '"')
  193. def set_init_method(self, init_method, init_options={}):
  194. """Selects method to be used for computing the initial medoid graph.
  195. Parameters
  196. ----------
  197. init_method : string
  198. The selected method. Default: ged::Options::GEDMethod::BRANCH_UNIFORM.
  199. init_options : string
  200. The options for the selected method. Default: "".
  201. Notes
  202. -----
  203. Has no effect unless "--init-type MEDOID" is passed to set_options().
  204. """
  205. self._init_method = init_method;
  206. self._init_options = init_options;
  207. def set_descent_method(self, descent_method, descent_options=''):
  208. """Selects method to be used for block gradient descent..
  209. Parameters
  210. ----------
  211. descent_method : string
  212. The selected method. Default: ged::Options::GEDMethod::BRANCH_FAST.
  213. descent_options : string
  214. The options for the selected method. Default: "".
  215. Notes
  216. -----
  217. Has no effect unless "--init-type MEDOID" is passed to set_options().
  218. """
  219. self._descent_method = descent_method;
  220. self._descent_options = descent_options;
  221. def set_refine_method(self, refine_method, refine_options):
  222. """Selects method to be used for improving the sum of distances and the node maps for the converged median.
  223. Parameters
  224. ----------
  225. refine_method : string
  226. The selected method. Default: "IPFP".
  227. refine_options : string
  228. The options for the selected method. Default: "".
  229. Notes
  230. -----
  231. Has no effect if "--refine FALSE" is passed to set_options().
  232. """
  233. self._refine_method = refine_method
  234. self._refine_options = refine_options
  235. def run(self, graph_ids, set_median_id, gen_median_id):
  236. """Computes a generalized median graph.
  237. Parameters
  238. ----------
  239. graph_ids : list[integer]
  240. The IDs of the graphs for which the median should be computed. Must have been added to the environment passed to the constructor.
  241. set_median_id : integer
  242. The ID of the computed set-median. A dummy graph with this ID must have been added to the environment passed to the constructor. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
  243. gen_median_id : integer
  244. The ID of the computed generalized median. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
  245. """
  246. # Sanity checks.
  247. if len(graph_ids) == 0:
  248. raise Exception('Empty vector of graph IDs, unable to compute median.')
  249. all_graphs_empty = True
  250. for graph_id in graph_ids:
  251. if self._ged_env.get_graph_num_nodes(graph_id) > 0:
  252. all_graphs_empty = False
  253. break
  254. if all_graphs_empty:
  255. raise Exception('All graphs in the collection are empty.')
  256. # Start timer and record start time.
  257. start = time.time()
  258. timer = Timer(self._time_limit_in_sec)
  259. self._median_id = gen_median_id
  260. self._state = AlgorithmState.TERMINATED
  261. # Get NetworkX graph representations of the input graphs.
  262. graphs = {}
  263. for graph_id in graph_ids:
  264. # @todo: get_nx_graph() function may need to be modified according to the coming code.
  265. graphs[graph_id] = self._ged_env.get_nx_graph(graph_id)
  266. # print(self._ged_env.get_graph_internal_id(0))
  267. # print(graphs[0].graph)
  268. # print(graphs[0].nodes(data=True))
  269. # print(graphs[0].edges(data=True))
  270. # print(nx.adjacency_matrix(graphs[0]))
  271. # Construct initial medians.
  272. medians = []
  273. self._construct_initial_medians(graph_ids, timer, medians)
  274. end_init = time.time()
  275. self._runtime_initialized = end_init - start
  276. # print(medians[0].graph)
  277. # print(medians[0].nodes(data=True))
  278. # print(medians[0].edges(data=True))
  279. # print(nx.adjacency_matrix(medians[0]))
  280. # Reset information about iterations and number of times the median decreases and increases.
  281. self._itrs = [0] * len(medians)
  282. self._num_decrease_order = 0
  283. self._num_increase_order = 0
  284. self._num_converged_descents = 0
  285. # Initialize the best median.
  286. best_sum_of_distances = np.inf
  287. self._best_init_sum_of_distances = np.inf
  288. node_maps_from_best_median = {}
  289. # Run block gradient descent from all initial medians.
  290. self._ged_env.set_method(self._descent_method, self._descent_options)
  291. for median_pos in range(0, len(medians)):
  292. # Terminate if the timer has expired and at least one SOD has been computed.
  293. if timer.expired() and median_pos > 0:
  294. break
  295. # Print information about current iteration.
  296. if self._print_to_stdout == 2:
  297. print('\n===========================================================')
  298. print('Block gradient descent for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
  299. print('-----------------------------------------------------------')
  300. # Get reference to the median.
  301. median = medians[median_pos]
  302. # Load initial median into the environment.
  303. self._ged_env.load_nx_graph(median, gen_median_id)
  304. self._ged_env.init(self._ged_env.get_init_type())
  305. # Compute node maps and sum of distances for initial median.
  306. # xxx = self._node_maps_from_median
  307. self._compute_init_node_maps(graph_ids, gen_median_id)
  308. # yyy = self._node_maps_from_median
  309. self._best_init_sum_of_distances = min(self._best_init_sum_of_distances, self._sum_of_distances)
  310. self._ged_env.load_nx_graph(median, set_median_id)
  311. # print(self._best_init_sum_of_distances)
  312. # Run block gradient descent from initial median.
  313. converged = False
  314. itrs_without_update = 0
  315. while not self._termination_criterion_met(converged, timer, self._itrs[median_pos], itrs_without_update):
  316. # Print information about current iteration.
  317. if self._print_to_stdout == 2:
  318. print('\n===========================================================')
  319. print('Iteration', str(self._itrs[median_pos] + 1), 'for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
  320. print('-----------------------------------------------------------')
  321. # Initialize flags that tell us what happened in the iteration.
  322. median_modified = False
  323. node_maps_modified = False
  324. decreased_order = False
  325. increased_order = False
  326. # Update the median.
  327. median_modified = self._update_median(graphs, median)
  328. if self._update_order:
  329. pass # @todo:
  330. # if not median_modified or self._itrs[median_pos] == 0:
  331. # decreased_order = self._decrease_order(graphs, median)
  332. # if not decreased_order or self._itrs[median_pos] == 0:
  333. # increased_order = self._increase_order(graphs, median)
  334. # Update the number of iterations without update of the median.
  335. if median_modified or decreased_order or increased_order:
  336. itrs_without_update = 0
  337. else:
  338. itrs_without_update += 1
  339. # Print information about current iteration.
  340. if self._print_to_stdout == 2:
  341. print('Loading median to environment: ... ', end='')
  342. # Load the median into the environment.
  343. # @todo: should this function use the original node label?
  344. self._ged_env.load_nx_graph(median, gen_median_id)
  345. self._ged_env.init(self._ged_env.get_init_type())
  346. # Print information about current iteration.
  347. if self._print_to_stdout == 2:
  348. print('done.')
  349. # Print information about current iteration.
  350. if self._print_to_stdout == 2:
  351. print('Updating induced costs: ... ', end='')
  352. # Compute induced costs of the old node maps w.r.t. the updated median.
  353. for graph_id in graph_ids:
  354. # print(self._node_maps_from_median[graph_id].induced_cost())
  355. # xxx = self._node_maps_from_median[graph_id]
  356. self._ged_env.compute_induced_cost(gen_median_id, graph_id, self._node_maps_from_median[graph_id])
  357. # print('---------------------------------------')
  358. # print(self._node_maps_from_median[graph_id].induced_cost())
  359. # @todo:!!!!!!!!!!!!!!!!!!!!!!!!!!!!This value is a slight different from the c++ program, which might be a bug! Use it very carefully!
  360. # Print information about current iteration.
  361. if self._print_to_stdout == 2:
  362. print('done.')
  363. # Update the node maps.
  364. node_maps_modified = self._update_node_maps()
  365. # Update the order of the median if no improvement can be found with the current order.
  366. # Update the sum of distances.
  367. old_sum_of_distances = self._sum_of_distances
  368. self._sum_of_distances = 0
  369. for graph_id, node_map in self._node_maps_from_median.items():
  370. self._sum_of_distances += node_map.induced_cost()
  371. # print(self._sum_of_distances)
  372. # Print information about current iteration.
  373. if self._print_to_stdout == 2:
  374. print('Old local SOD: ', old_sum_of_distances)
  375. print('New local SOD: ', self._sum_of_distances)
  376. print('Best converged SOD: ', best_sum_of_distances)
  377. print('Modified median: ', median_modified)
  378. print('Modified node maps: ', node_maps_modified)
  379. print('Decreased order: ', decreased_order)
  380. print('Increased order: ', increased_order)
  381. print('===========================================================\n')
  382. converged = not (median_modified or node_maps_modified or decreased_order or increased_order)
  383. self._itrs[median_pos] += 1
  384. # Update the best median.
  385. if self._sum_of_distances < best_sum_of_distances:
  386. best_sum_of_distances = self._sum_of_distances
  387. node_maps_from_best_median = self._node_maps_from_median.copy() # @todo: this is a shallow copy, not sure if it is enough.
  388. best_median = median
  389. # Update the number of converged descents.
  390. if converged:
  391. self._num_converged_descents += 1
  392. # Store the best encountered median.
  393. self._sum_of_distances = best_sum_of_distances
  394. self._node_maps_from_median = node_maps_from_best_median
  395. self._ged_env.load_nx_graph(best_median, gen_median_id)
  396. self._ged_env.init(self._ged_env.get_init_type())
  397. end_descent = time.time()
  398. self._runtime_converged = end_descent - start
  399. # Refine the sum of distances and the node maps for the converged median.
  400. self._converged_sum_of_distances = self._sum_of_distances
  401. if self._refine:
  402. self._improve_sum_of_distances(timer)
  403. # Record end time, set runtime and reset the number of initial medians.
  404. end = time.time()
  405. self._runtime = end - start
  406. self._num_random_inits = self._desired_num_random_inits
  407. # Print global information.
  408. if self._print_to_stdout != 0:
  409. print('\n===========================================================')
  410. print('Finished computation of generalized median graph.')
  411. print('-----------------------------------------------------------')
  412. print('Best SOD after initialization: ', self._best_init_sum_of_distances)
  413. print('Converged SOD: ', self._converged_sum_of_distances)
  414. if self._refine:
  415. print('Refined SOD: ', self._sum_of_distances)
  416. print('Overall runtime: ', self._runtime)
  417. print('Runtime of initialization: ', self._runtime_initialized)
  418. print('Runtime of block gradient descent: ', self._runtime_converged - self._runtime_initialized)
  419. if self._refine:
  420. print('Runtime of refinement: ', self._runtime - self._runtime_converged)
  421. print('Number of initial medians: ', len(medians))
  422. total_itr = 0
  423. num_started_descents = 0
  424. for itr in self._itrs:
  425. total_itr += itr
  426. if itr > 0:
  427. num_started_descents += 1
  428. print('Size of graph collection: ', len(graph_ids))
  429. print('Number of started descents: ', num_started_descents)
  430. print('Number of converged descents: ', self._num_converged_descents)
  431. print('Overall number of iterations: ', total_itr)
  432. print('Overall number of times the order decreased: ', self._num_decrease_order)
  433. print('Overall number of times the order increased: ', self._num_increase_order)
  434. print('===========================================================\n')
  435. def _improve_sum_of_distances(self, timer): # @todo: go through and test
  436. # Use method selected for refinement phase.
  437. self._ged_env.set_method(self._refine_method, self._refine_options)
  438. # Print information about current iteration.
  439. if self._print_to_stdout == 2:
  440. progress = tqdm(desc='Improving node maps', total=len(self._node_maps_from_median), file=sys.stdout)
  441. print('\n===========================================================')
  442. print('Improving node maps and SOD for converged median.')
  443. print('-----------------------------------------------------------')
  444. progress.update(1)
  445. # Improving the node maps.
  446. nb_nodes_median = self._ged_env.get_graph_num_nodes(self._gen_median_id)
  447. for graph_id, node_map in self._node_maps_from_median.items():
  448. if time.expired():
  449. if self._state == AlgorithmState.TERMINATED:
  450. self._state = AlgorithmState.CONVERGED
  451. break
  452. nb_nodes_g = self._ged_env.get_graph_num_nodes(graph_id)
  453. if nb_nodes_median <= nb_nodes_g or not self._sort_graphs:
  454. self._ged_env.run_method(self._gen_median_id, graph_id)
  455. if self._ged_env.get_upper_bound(self._gen_median_id, graph_id) < node_map.induced_cost():
  456. self._node_maps_from_median[graph_id] = self._ged_env.get_node_map(self._gen_median_id, graph_id)
  457. else:
  458. self._ged_env.run_method(graph_id, self._gen_median_id)
  459. if self._ged_env.get_upper_bound(graph_id, self._gen_median_id) < node_map.induced_cost():
  460. node_map_tmp = self._ged_env.get_node_map(graph_id, self._gen_median_id)
  461. node_map_tmp.forward_map, node_map_tmp.backward_map = node_map_tmp.backward_map, node_map_tmp.forward_map
  462. self._node_maps_from_median[graph_id] = node_map_tmp
  463. self._sum_of_distances += self._node_maps_from_median[graph_id].induced_cost()
  464. # Print information.
  465. if self._print_to_stdout == 2:
  466. progress.update(1)
  467. self._sum_of_distances = 0.0
  468. for key, val in self._node_maps_from_median.items():
  469. self._sum_of_distances += val.induced_cost()
  470. # Print information.
  471. if self._print_to_stdout == 2:
  472. print('===========================================================\n')
  473. def _median_available(self):
  474. return self._median_id != np.inf
  475. def get_state(self):
  476. if not self._median_available():
  477. raise Exception('No median has been computed. Call run() before calling get_state().')
  478. return self._state
  479. def get_sum_of_distances(self, state=''):
  480. """Returns the sum of distances.
  481. Parameters
  482. ----------
  483. state : string
  484. The state of the estimator. Can be 'initialized' or 'converged'. Default: ""
  485. Returns
  486. -------
  487. float
  488. The sum of distances (SOD) of the median when the estimator was in the state `state` during the last call to run(). If `state` is not given, the converged SOD (without refinement) or refined SOD (with refinement) is returned.
  489. """
  490. if not self._median_available():
  491. raise Exception('No median has been computed. Call run() before calling get_sum_of_distances().')
  492. if state == 'initialized':
  493. return self._best_init_sum_of_distances
  494. if state == 'converged':
  495. return self._converged_sum_of_distances
  496. return self._sum_of_distances
  497. def get_runtime(self, state):
  498. if not self._median_available():
  499. raise Exception('No median has been computed. Call run() before calling get_runtime().')
  500. if state == AlgorithmState.INITIALIZED:
  501. return self._runtime_initialized
  502. if state == AlgorithmState.CONVERGED:
  503. return self._runtime_converged
  504. return self._runtime
  505. def get_num_itrs(self):
  506. if not self._median_available():
  507. raise Exception('No median has been computed. Call run() before calling get_num_itrs().')
  508. return self._itrs
  509. def get_num_times_order_decreased(self):
  510. if not self._median_available():
  511. raise Exception('No median has been computed. Call run() before calling get_num_times_order_decreased().')
  512. return self._num_decrease_order
  513. def get_num_times_order_increased(self):
  514. if not self._median_available():
  515. raise Exception('No median has been computed. Call run() before calling get_num_times_order_increased().')
  516. return self._num_increase_order
  517. def get_num_converged_descents(self):
  518. if not self._median_available():
  519. raise Exception('No median has been computed. Call run() before calling get_num_converged_descents().')
  520. return self._num_converged_descents
  521. def get_ged_env(self):
  522. return self._ged_env
  523. def _set_default_options(self):
  524. self._init_type = 'RANDOM'
  525. self._num_random_inits = 10
  526. self._desired_num_random_inits = 10
  527. self._use_real_randomness = True
  528. self._seed = 0
  529. self._parallel = True
  530. self._update_order = True
  531. self._sort_graphs = True
  532. self._refine = True
  533. self._time_limit_in_sec = 0
  534. self._epsilon = 0.0001
  535. self._max_itrs = 100
  536. self._max_itrs_without_update = 3
  537. self._num_inits_increase_order = 10
  538. self._init_type_increase_order = 'K-MEANS++'
  539. self._max_itrs_increase_order = 10
  540. self._print_to_stdout = 2
  541. self._label_names = {}
  542. def _construct_initial_medians(self, graph_ids, timer, initial_medians):
  543. # Print information about current iteration.
  544. if self._print_to_stdout == 2:
  545. print('\n===========================================================')
  546. print('Constructing initial median(s).')
  547. print('-----------------------------------------------------------')
  548. # Compute or sample the initial median(s).
  549. initial_medians.clear()
  550. if self._init_type == 'MEDOID':
  551. self._compute_medoid(graph_ids, timer, initial_medians)
  552. elif self._init_type == 'MAX':
  553. pass # @todo
  554. # compute_max_order_graph_(graph_ids, initial_medians)
  555. elif self._init_type == 'MIN':
  556. pass # @todo
  557. # compute_min_order_graph_(graph_ids, initial_medians)
  558. elif self._init_type == 'MEAN':
  559. pass # @todo
  560. # compute_mean_order_graph_(graph_ids, initial_medians)
  561. else:
  562. pass # @todo
  563. # sample_initial_medians_(graph_ids, initial_medians)
  564. # Print information about current iteration.
  565. if self._print_to_stdout == 2:
  566. print('===========================================================')
  567. def _compute_medoid(self, graph_ids, timer, initial_medians):
  568. # Use method selected for initialization phase.
  569. self._ged_env.set_method(self._init_method, self._init_options)
  570. # Compute the medoid.
  571. if self._parallel:
  572. # @todo: notice when parallel self._ged_env is not modified.
  573. sum_of_distances_list = [np.inf] * len(graph_ids)
  574. len_itr = len(graph_ids)
  575. itr = zip(graph_ids, range(0, len(graph_ids)))
  576. n_jobs = multiprocessing.cpu_count()
  577. if len_itr < 100 * n_jobs:
  578. chunksize = int(len_itr / n_jobs) + 1
  579. else:
  580. chunksize = 100
  581. def init_worker(ged_env_toshare):
  582. global G_ged_env
  583. G_ged_env = ged_env_toshare
  584. do_fun = partial(_compute_medoid_parallel, graph_ids, self._sort_graphs)
  585. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self._ged_env,))
  586. if self._print_to_stdout == 2:
  587. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  588. desc='Computing medoid', file=sys.stdout)
  589. else:
  590. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  591. for i, dis in iterator:
  592. sum_of_distances_list[i] = dis
  593. pool.close()
  594. pool.join()
  595. medoid_id = np.argmin(sum_of_distances_list)
  596. best_sum_of_distances = sum_of_distances_list[medoid_id]
  597. initial_medians.append(self._ged_env.get_nx_graph(medoid_id)) # @todo
  598. else:
  599. # Print information about current iteration.
  600. if self._print_to_stdout == 2:
  601. progress = tqdm(desc='Computing medoid', total=len(graph_ids), file=sys.stdout)
  602. medoid_id = graph_ids[0]
  603. best_sum_of_distances = np.inf
  604. for g_id in graph_ids:
  605. if timer.expired():
  606. self._state = AlgorithmState.CALLED
  607. break
  608. nb_nodes_g = self._ged_env.get_graph_num_nodes(g_id)
  609. sum_of_distances = 0
  610. for h_id in graph_ids: # @todo: this can be faster, only a half is needed.
  611. nb_nodes_h = self._ged_env.get_graph_num_nodes(h_id)
  612. if nb_nodes_g <= nb_nodes_h or not self._sort_graphs:
  613. self._ged_env.run_method(g_id, h_id) # @todo
  614. sum_of_distances += self._ged_env.get_upper_bound(g_id, h_id)
  615. else:
  616. self._ged_env.run_method(h_id, g_id)
  617. sum_of_distances += self._ged_env.get_upper_bound(h_id, g_id)
  618. if sum_of_distances < best_sum_of_distances:
  619. best_sum_of_distances = sum_of_distances
  620. medoid_id = g_id
  621. # Print information about current iteration.
  622. if self._print_to_stdout == 2:
  623. progress.update(1)
  624. initial_medians.append(self._ged_env.get_nx_graph(medoid_id)) # @todo
  625. # Print information about current iteration.
  626. if self._print_to_stdout == 2:
  627. print('\n')
  628. def _compute_init_node_maps(self, graph_ids, gen_median_id):
  629. # Compute node maps and sum of distances for initial median.
  630. if self._parallel:
  631. # @todo: notice when parallel self._ged_env is not modified.
  632. self._sum_of_distances = 0
  633. self._node_maps_from_median.clear()
  634. sum_of_distances_list = [0] * len(graph_ids)
  635. len_itr = len(graph_ids)
  636. itr = graph_ids
  637. n_jobs = multiprocessing.cpu_count()
  638. if len_itr < 100 * n_jobs:
  639. chunksize = int(len_itr / n_jobs) + 1
  640. else:
  641. chunksize = 100
  642. def init_worker(ged_env_toshare):
  643. global G_ged_env
  644. G_ged_env = ged_env_toshare
  645. nb_nodes_median = self._ged_env.get_graph_num_nodes(gen_median_id)
  646. do_fun = partial(_compute_init_node_maps_parallel, gen_median_id, self._sort_graphs, nb_nodes_median)
  647. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self._ged_env,))
  648. if self._print_to_stdout == 2:
  649. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  650. desc='Computing initial node maps', file=sys.stdout)
  651. else:
  652. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  653. for g_id, sod, node_maps in iterator:
  654. sum_of_distances_list[g_id] = sod
  655. self._node_maps_from_median[g_id] = node_maps
  656. pool.close()
  657. pool.join()
  658. self._sum_of_distances = np.sum(sum_of_distances_list)
  659. # xxx = self._node_maps_from_median
  660. else:
  661. # Print information about current iteration.
  662. if self._print_to_stdout == 2:
  663. progress = tqdm(desc='Computing initial node maps', total=len(graph_ids), file=sys.stdout)
  664. self._sum_of_distances = 0
  665. self._node_maps_from_median.clear()
  666. nb_nodes_median = self._ged_env.get_graph_num_nodes(gen_median_id)
  667. for graph_id in graph_ids:
  668. nb_nodes_g = self._ged_env.get_graph_num_nodes(graph_id)
  669. if nb_nodes_median <= nb_nodes_g or not self._sort_graphs:
  670. self._ged_env.run_method(gen_median_id, graph_id)
  671. self._node_maps_from_median[graph_id] = self._ged_env.get_node_map(gen_median_id, graph_id)
  672. else:
  673. self._ged_env.run_method(graph_id, gen_median_id)
  674. node_map_tmp = self._ged_env.get_node_map(graph_id, gen_median_id)
  675. node_map_tmp.forward_map, node_map_tmp.backward_map = node_map_tmp.backward_map, node_map_tmp.forward_map
  676. self._node_maps_from_median[graph_id] = node_map_tmp
  677. # print(self._node_maps_from_median[graph_id])
  678. self._sum_of_distances += self._node_maps_from_median[graph_id].induced_cost()
  679. # print(self._sum_of_distances)
  680. # Print information about current iteration.
  681. if self._print_to_stdout == 2:
  682. progress.update(1)
  683. # Print information about current iteration.
  684. if self._print_to_stdout == 2:
  685. print('\n')
  686. def _termination_criterion_met(self, converged, timer, itr, itrs_without_update):
  687. if timer.expired() or (itr >= self._max_itrs if self._max_itrs >= 0 else False):
  688. if self._state == AlgorithmState.TERMINATED:
  689. self._state = AlgorithmState.INITIALIZED
  690. return True
  691. return converged or (itrs_without_update > self._max_itrs_without_update if self._max_itrs_without_update >= 0 else False)
  692. def _update_median(self, graphs, median):
  693. # Print information about current iteration.
  694. if self._print_to_stdout == 2:
  695. print('Updating median: ', end='')
  696. # Store copy of the old median.
  697. old_median = median.copy() # @todo: this is just a shallow copy.
  698. # Update the node labels.
  699. if self._labeled_nodes:
  700. self._update_node_labels(graphs, median)
  701. # Update the edges and their labels.
  702. self._update_edges(graphs, median)
  703. # Print information about current iteration.
  704. if self._print_to_stdout == 2:
  705. print('done.')
  706. return not self._are_graphs_equal(median, old_median)
  707. def _update_node_labels(self, graphs, median):
  708. # print('----------------------------')
  709. # Print information about current iteration.
  710. if self._print_to_stdout == 2:
  711. print('nodes ... ', end='')
  712. # Collect all possible node labels.
  713. all_labels = self._ged_env.get_all_node_labels()
  714. # Iterate through all nodes of the median.
  715. for i in range(0, nx.number_of_nodes(median)):
  716. # print('i: ', i)
  717. # Collect the labels of the substituted nodes.
  718. node_labels = []
  719. for graph_id, graph in graphs.items():
  720. k = self._node_maps_from_median[graph_id].image(i)
  721. if k != np.inf:
  722. node_labels.append(tuple(graph.nodes[k].items())) # @todo: sort
  723. else:
  724. node_labels.append(SpecialLabel.DUMMY)
  725. # Compute the median label and update the median.
  726. if len(node_labels) > 0:
  727. fi_min = np.inf
  728. median_label = tuple()
  729. for label1 in all_labels:
  730. fi = 0
  731. for label2 in node_labels:
  732. fi += self._ged_env.get_node_cost(label1, label2) # @todo: check inside, this might be slow
  733. if fi < fi_min: # @todo: fi is too easy to be zero. use <= or consider multiple optimal labels.
  734. fi_min = fi
  735. median_label = label1
  736. median_label = {kv[0]: kv[1] for kv in median_label}
  737. nx.set_node_attributes(median, {i: median_label})
  738. # median_label = self._get_median_node_label(node_labels)
  739. # if self._ged_env.get_node_rel_cost(median.nodes[i], median_label) > self._epsilon:
  740. # nx.set_node_attributes(median, {i: median_label})
  741. def _update_edges(self, graphs, median):
  742. # Print information about current iteration.
  743. if self._print_to_stdout == 2:
  744. print('edges ... ', end='')
  745. # Collect all possible edge labels.
  746. all_labels = self._ged_env.get_all_edge_labels()
  747. # @todo: what if edge is not labeled?
  748. # Iterate through all possible edges (i,j) of the median.
  749. for i in range(0, nx.number_of_nodes(median)):
  750. for j in range(i + 1, nx.number_of_nodes(median)):
  751. # Collect the labels of the edges to which (i,j) is mapped by the node maps.
  752. edge_labels = []
  753. for graph_id, graph in graphs.items():
  754. k = self._node_maps_from_median[graph_id].image(i)
  755. l = self._node_maps_from_median[graph_id].image(j)
  756. if k != np.inf and l != np.inf and graph.has_edge(k, l):
  757. edge_labels.append(tuple(graph.edges[(k, l)].items())) # @todo: sort
  758. else:
  759. edge_labels.append(SpecialLabel.DUMMY)
  760. # Compute the median edge label and the overall edge relabeling cost.
  761. if self._labeled_edges and len(edge_labels) > 0:
  762. fij1_min = np.inf
  763. median_label = tuple()
  764. # Compute f_ij^0.
  765. fij0 = 0
  766. for label2 in edge_labels:
  767. fij0 += self._ged_env.get_edge_cost(SpecialLabel.DUMMY, label2)
  768. for label1 in all_labels:
  769. fij1 = 0
  770. for label2 in edge_labels:
  771. fij1 += self._ged_env.get_edge_cost(label1, label2)
  772. if fij1 < fij1_min:
  773. fij1_min = fij1
  774. median_label = label1
  775. # Update the median.
  776. if median.has_edge(i, j):
  777. median.remove_edge(i, j)
  778. if fij1_min < fij0: # @todo: this never happens.
  779. median_label = {kv[0]: kv[1] for kv in median_label}
  780. median.add_edge(i, j, **median_label)
  781. # if self._ged_env.get_edge_rel_cost(median_label, new_median_label) > self._epsilon:
  782. # median_label = new_median_label
  783. def _update_node_maps(self):
  784. # Update the node maps.
  785. if self._parallel:
  786. # @todo: notice when parallel self._ged_env is not modified.
  787. node_maps_were_modified = False
  788. # xxx = self._node_maps_from_median.copy()
  789. len_itr = len(self._node_maps_from_median)
  790. itr = [item for item in self._node_maps_from_median.items()]
  791. n_jobs = multiprocessing.cpu_count()
  792. if len_itr < 100 * n_jobs:
  793. chunksize = int(len_itr / n_jobs) + 1
  794. else:
  795. chunksize = 100
  796. def init_worker(ged_env_toshare):
  797. global G_ged_env
  798. G_ged_env = ged_env_toshare
  799. nb_nodes_median = self._ged_env.get_graph_num_nodes(self._median_id)
  800. do_fun = partial(_update_node_maps_parallel, self._median_id, self._epsilon, self._sort_graphs, nb_nodes_median)
  801. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self._ged_env,))
  802. if self._print_to_stdout == 2:
  803. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  804. desc='Updating node maps', file=sys.stdout)
  805. else:
  806. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  807. for g_id, node_map, nm_modified in iterator:
  808. self._node_maps_from_median[g_id] = node_map
  809. if nm_modified:
  810. node_maps_were_modified = True
  811. pool.close()
  812. pool.join()
  813. # yyy = self._node_maps_from_median.copy()
  814. else:
  815. # Print information about current iteration.
  816. if self._print_to_stdout == 2:
  817. progress = tqdm(desc='Updating node maps', total=len(self._node_maps_from_median), file=sys.stdout)
  818. node_maps_were_modified = False
  819. nb_nodes_median = self._ged_env.get_graph_num_nodes(self._median_id)
  820. for graph_id, node_map in self._node_maps_from_median.items():
  821. nb_nodes_g = self._ged_env.get_graph_num_nodes(graph_id)
  822. if nb_nodes_median <= nb_nodes_g or not self._sort_graphs:
  823. self._ged_env.run_method(self._median_id, graph_id)
  824. if self._ged_env.get_upper_bound(self._median_id, graph_id) < node_map.induced_cost() - self._epsilon:
  825. # xxx = self._node_maps_from_median[graph_id]
  826. self._node_maps_from_median[graph_id] = self._ged_env.get_node_map(self._median_id, graph_id)
  827. node_maps_were_modified = True
  828. else:
  829. self._ged_env.run_method(graph_id, self._median_id)
  830. if self._ged_env.get_upper_bound(graph_id, self._median_id) < node_map.induced_cost() - self._epsilon:
  831. node_map_tmp = self._ged_env.get_node_map(graph_id, self._median_id)
  832. node_map_tmp.forward_map, node_map_tmp.backward_map = node_map_tmp.backward_map, node_map_tmp.forward_map
  833. self._node_maps_from_median[graph_id] = node_map_tmp
  834. node_maps_were_modified = True
  835. # Print information about current iteration.
  836. if self._print_to_stdout == 2:
  837. progress.update(1)
  838. # Print information about current iteration.
  839. if self._print_to_stdout == 2:
  840. print('\n')
  841. # Return true if the node maps were modified.
  842. return node_maps_were_modified
  843. def _decrease_order(self, graphs, median):
  844. # Print information about current iteration
  845. if self._print_to_stdout == 2:
  846. print('Trying to decrease order: ... ', end='')
  847. if nx.number_of_nodes(median) <= 1:
  848. if self._print_to_stdout == 2:
  849. print('median graph has only 1 node, skip decrease.')
  850. return False
  851. # Initialize ID of the node that is to be deleted.
  852. id_deleted_node = [None] # @todo: or np.inf
  853. decreased_order = False
  854. # Decrease the order as long as the best deletion delta is negative.
  855. while self._compute_best_deletion_delta(graphs, median, id_deleted_node) < -self._epsilon:
  856. decreased_order = True
  857. self._delete_node_from_median(id_deleted_node[0], median)
  858. if nx.number_of_nodes(median) <= 1:
  859. if self._print_to_stdout == 2:
  860. print('decrease stopped because median graph remains only 1 node. ', end='')
  861. break
  862. # Print information about current iteration.
  863. if self._print_to_stdout == 2:
  864. print('done.')
  865. # Return true iff the order was decreased.
  866. return decreased_order
  867. def _compute_best_deletion_delta(self, graphs, median, id_deleted_node):
  868. best_delta = 0.0
  869. # Determine node that should be deleted (if any).
  870. for i in range(0, nx.number_of_nodes(median)):
  871. # Compute cost delta.
  872. delta = 0.0
  873. for graph_id, graph in graphs.items():
  874. k = self._node_maps_from_median[graph_id].image(i)
  875. if k == np.inf:
  876. delta -= self._node_del_cost
  877. else:
  878. delta += self._node_ins_cost - self._ged_env.get_node_rel_cost(median.nodes[i], graph.nodes[k])
  879. for j, j_label in median[i].items():
  880. l = self._node_maps_from_median[graph_id].image(j)
  881. if k == np.inf or l == np.inf:
  882. delta -= self._edge_del_cost
  883. elif not graph.has_edge(k, l):
  884. delta -= self._edge_del_cost
  885. else:
  886. delta += self._edge_ins_cost - self._ged_env.get_edge_rel_cost(j_label, graph.edges[(k, l)])
  887. # Update best deletion delta.
  888. if delta < best_delta - self._epsilon:
  889. best_delta = delta
  890. id_deleted_node[0] = i
  891. # id_deleted_node[0] = 3 # @todo:
  892. return best_delta
  893. def _delete_node_from_median(self, id_deleted_node, median):
  894. # Update the median.
  895. mapping = {}
  896. for i in range(0, nx.number_of_nodes(median)):
  897. if i != id_deleted_node:
  898. new_i = (i if i < id_deleted_node else (i - 1))
  899. mapping[i] = new_i
  900. median.remove_node(id_deleted_node)
  901. nx.relabel_nodes(median, mapping, copy=False)
  902. # Update the node maps.
  903. # xxx = self._node_maps_from_median
  904. for key, node_map in self._node_maps_from_median.items():
  905. new_node_map = NodeMap(nx.number_of_nodes(median), node_map.num_target_nodes())
  906. is_unassigned_target_node = [True] * node_map.num_target_nodes()
  907. for i in range(0, nx.number_of_nodes(median) + 1):
  908. if i != id_deleted_node:
  909. new_i = (i if i < id_deleted_node else (i - 1))
  910. k = node_map.image(i)
  911. new_node_map.add_assignment(new_i, k)
  912. if k != np.inf:
  913. is_unassigned_target_node[k] = False
  914. for k in range(0, node_map.num_target_nodes()):
  915. if is_unassigned_target_node[k]:
  916. new_node_map.add_assignment(np.inf, k)
  917. # print(self._node_maps_from_median[key].forward_map, self._node_maps_from_median[key].backward_map)
  918. # print(new_node_map.forward_map, new_node_map.backward_map
  919. self._node_maps_from_median[key] = new_node_map
  920. # Increase overall number of decreases.
  921. self._num_decrease_order += 1
  922. def _increase_order(self, graphs, median):
  923. # Print information about current iteration.
  924. if self._print_to_stdout == 2:
  925. print('Trying to increase order: ... ', end='')
  926. # Initialize the best configuration and the best label of the node that is to be inserted.
  927. best_config = {}
  928. best_label = self._ged_env.get_node_label(1, to_dict=True)
  929. increased_order = False
  930. # Increase the order as long as the best insertion delta is negative.
  931. while self._compute_best_insertion_delta(graphs, best_config, best_label) < - self._epsilon:
  932. increased_order = True
  933. self._add_node_to_median(best_config, best_label, median)
  934. # Print information about current iteration.
  935. if self._print_to_stdout == 2:
  936. print('done.')
  937. # Return true iff the order was increased.
  938. return increased_order
  939. def _compute_best_insertion_delta(self, graphs, best_config, best_label):
  940. # Construct sets of inserted nodes.
  941. no_inserted_node = True
  942. inserted_nodes = {}
  943. for graph_id, graph in graphs.items():
  944. inserted_nodes[graph_id] = []
  945. best_config[graph_id] = np.inf
  946. for k in range(nx.number_of_nodes(graph)):
  947. if self._node_maps_from_median[graph_id].pre_image(k) == np.inf:
  948. no_inserted_node = False
  949. inserted_nodes[graph_id].append((k, tuple(item for item in graph.nodes[k].items()))) # @todo: can order of label names be garantteed?
  950. # Return 0.0 if no node is inserted in any of the graphs.
  951. if no_inserted_node:
  952. return 0.0
  953. # Compute insertion configuration, label, and delta.
  954. best_delta = 0.0 # @todo
  955. if len(self._label_names['node_labels']) == 0 and len(self._label_names['node_attrs']) == 0: # @todo
  956. best_delta = self._compute_insertion_delta_unlabeled(inserted_nodes, best_config, best_label)
  957. elif len(self._label_names['node_labels']) > 0: # self._constant_node_costs:
  958. best_delta = self._compute_insertion_delta_constant(inserted_nodes, best_config, best_label)
  959. else:
  960. best_delta = self._compute_insertion_delta_generic(inserted_nodes, best_config, best_label)
  961. # Return the best delta.
  962. return best_delta
  963. def _compute_insertion_delta_unlabeled(self, inserted_nodes, best_config, best_label): # @todo: go through and test.
  964. # Construct the nest configuration and compute its insertion delta.
  965. best_delta = 0.0
  966. best_config.clear()
  967. for graph_id, node_set in inserted_nodes.items():
  968. if len(node_set) == 0:
  969. best_config[graph_id] = np.inf
  970. best_delta += self._node_del_cost
  971. else:
  972. best_config[graph_id] = node_set[0][0]
  973. best_delta -= self._node_ins_cost
  974. # Return the best insertion delta.
  975. return best_delta
  976. def _compute_insertion_delta_constant(self, inserted_nodes, best_config, best_label):
  977. # Construct histogram and inverse label maps.
  978. hist = {}
  979. inverse_label_maps = {}
  980. for graph_id, node_set in inserted_nodes.items():
  981. inverse_label_maps[graph_id] = {}
  982. for node in node_set:
  983. k = node[0]
  984. label = node[1]
  985. if label not in inverse_label_maps[graph_id]:
  986. inverse_label_maps[graph_id][label] = k
  987. if label not in hist:
  988. hist[label] = 1
  989. else:
  990. hist[label] += 1
  991. # Determine the best label.
  992. best_count = 0
  993. for key, val in hist.items():
  994. if val > best_count:
  995. best_count = val
  996. best_label_tuple = key
  997. # get best label.
  998. best_label.clear()
  999. for key, val in best_label_tuple:
  1000. best_label[key] = val
  1001. # Construct the best configuration and compute its insertion delta.
  1002. best_config.clear()
  1003. best_delta = 0.0
  1004. node_rel_cost = self._ged_env.get_node_rel_cost(self._ged_env.get_node_label(1, to_dict=False), self._ged_env.get_node_label(2, to_dict=False))
  1005. triangle_ineq_holds = (node_rel_cost <= self._node_del_cost + self._node_ins_cost)
  1006. for graph_id, _ in inserted_nodes.items():
  1007. if best_label_tuple in inverse_label_maps[graph_id]:
  1008. best_config[graph_id] = inverse_label_maps[graph_id][best_label_tuple]
  1009. best_delta -= self._node_ins_cost
  1010. elif triangle_ineq_holds and not len(inserted_nodes[graph_id]) == 0:
  1011. best_config[graph_id] = inserted_nodes[graph_id][0][0]
  1012. best_delta += node_rel_cost - self._node_ins_cost
  1013. else:
  1014. best_config[graph_id] = np.inf
  1015. best_delta += self._node_del_cost
  1016. # Return the best insertion delta.
  1017. return best_delta
    def _compute_insertion_delta_generic(self, inserted_nodes, best_config, best_label):
        """Insertion delta for generic (non-constant) node costs.

        Starts an alternating descent from every initial label produced by
        ``_compute_initial_node_labels``. ``_update_config`` picks, per graph, the
        inserted node (or the dummy ``np.inf``) for the new median node, and
        ``_update_node_label`` recomputes the label as the median of the assigned
        labels. The run with the lowest insertion delta is kept.

        Parameters
        ----------
        inserted_nodes : dict
            Graph ID -> list of ``(node_id, label_tuple)`` pairs for the nodes of
            that graph that have no pre-image in the median.
        best_config : dict
            Output: cleared and refilled with graph ID -> chosen node ID (or
            ``np.inf``) whenever an improving delta is found.
        best_label : dict
            Output: cleared and refilled with the best label whenever an
            improving delta is found.

        Returns
        -------
        float
            The best delta; 0.0 if no run beat 0.0 by more than ``self._epsilon``.
        """
        # Collect all node labels of inserted nodes.
        # NOTE: _update_config below clears and refills this same list with the
        # labels of the current configuration, so after the first descent step
        # it no longer holds all inserted labels.
        node_labels = []
        for _, node_set in inserted_nodes.items():
            for node in node_set:
                node_labels.append(node[1])
        # Compute node label medians that serve as initial solutions for block gradient descent.
        initial_node_labels = []
        self._compute_initial_node_labels(node_labels, initial_node_labels)
        # Determine best insertion configuration, label, and delta via parallel block gradient descent from all initial node labels.
        # NOTE(review): this Python loop runs sequentially.
        best_delta = 0.0
        for node_label in initial_node_labels:
            # Construct local configuration.
            # Every graph starts mapped to the dummy node; the label part is a
            # placeholder taken from get_node_label(1, to_dict=False).
            config = {}
            for graph_id, _ in inserted_nodes.items():
                config[graph_id] = tuple((np.inf, self._ged_env.get_node_label(1, to_dict=False)))
            # Run block gradient descent.
            converged = False
            itr = 0
            while not self._insertion_termination_criterion_met(converged, itr):
                converged = not self._update_config(node_label, inserted_nodes, config, node_labels)
                node_label_dict = dict(node_label)
                # Short-circuit: the label is only updated once the configuration
                # did not change in this step.
                converged = converged and (not self._update_node_label([dict(item) for item in node_labels], node_label_dict)) # @todo: the dict is tupled again in the function, can be better.
                node_label = tuple(item for item in node_label_dict.items()) # @todo: watch out: initial_node_labels[i] is not modified here.
                itr += 1
            # Compute insertion delta of converged solution.
            delta = 0.0
            for _, node in config.items():
                if node[0] == np.inf:
                    delta += self._node_del_cost
                else:
                    delta += self._ged_env.get_node_rel_cost(dict(node_label), dict(node[1])) - self._node_ins_cost
            # Update best delta and global configuration if improvement has been found.
            if delta < best_delta - self._epsilon:
                best_delta = delta
                best_label.clear()
                for key, val in node_label:
                    best_label[key] = val
                best_config.clear()
                for graph_id, val in config.items():
                    best_config[graph_id] = val[0]
        # Return the best delta.
        return best_delta
  1061. def _compute_initial_node_labels(self, node_labels, median_labels):
  1062. median_labels.clear()
  1063. if self._use_real_randomness: # @todo: may not work if parallelized.
  1064. rng = np.random.randint(0, high=2**32 - 1, size=1)
  1065. urng = np.random.RandomState(seed=rng[0])
  1066. else:
  1067. urng = np.random.RandomState(seed=self._seed)
  1068. # Generate the initial node label medians.
  1069. if self._init_type_increase_order == 'K-MEANS++':
  1070. # Use k-means++ heuristic to generate the initial node label medians.
  1071. already_selected = [False] * len(node_labels)
  1072. selected_label_id = urng.randint(low=0, high=len(node_labels), size=1)[0] # c++ test: 23
  1073. median_labels.append(node_labels[selected_label_id])
  1074. already_selected[selected_label_id] = True
  1075. # xxx = [41, 0, 18, 9, 6, 14, 21, 25, 33] for c++ test
  1076. # iii = 0 for c++ test
  1077. while len(median_labels) < self._num_inits_increase_order:
  1078. weights = [np.inf] * len(node_labels)
  1079. for label_id in range(0, len(node_labels)):
  1080. if already_selected[label_id]:
  1081. weights[label_id] = 0
  1082. continue
  1083. for label in median_labels:
  1084. weights[label_id] = min(weights[label_id], self._ged_env.get_node_rel_cost(dict(label), dict(node_labels[label_id])))
  1085. # get non-zero weights.
  1086. weights_p, idx_p = [], []
  1087. for i, w in enumerate(weights):
  1088. if w != 0:
  1089. weights_p.append(w)
  1090. idx_p.append(i)
  1091. if len(weights_p) > 0:
  1092. p = np.array(weights_p) / np.sum(weights_p)
  1093. selected_label_id = urng.choice(range(0, len(weights_p)), size=1, p=p)[0] # for c++ test: xxx[iii]
  1094. selected_label_id = idx_p[selected_label_id]
  1095. # iii += 1 for c++ test
  1096. median_labels.append(node_labels[selected_label_id])
  1097. already_selected[selected_label_id] = True
  1098. else: # skip the loop when all node_labels are selected. This happens when len(node_labels) <= self._num_inits_increase_order.
  1099. break
  1100. else:
  1101. # Compute the initial node medians as the medians of randomly generated clusters of (roughly) equal size.
  1102. # @todo: go through and test.
  1103. shuffled_node_labels = [np.inf] * len(node_labels) #@todo: random?
  1104. # @todo: std::shuffle(shuffled_node_labels.begin(), shuffled_node_labels.end(), urng);?
  1105. cluster_size = len(node_labels) / self._num_inits_increase_order
  1106. pos = 0.0
  1107. cluster = []
  1108. while len(median_labels) < self._num_inits_increase_order - 1:
  1109. while pos < (len(median_labels) + 1) * cluster_size:
  1110. cluster.append(shuffled_node_labels[pos])
  1111. pos += 1
  1112. median_labels.append(self._get_median_node_label(cluster))
  1113. cluster.clear()
  1114. while pos < len(shuffled_node_labels):
  1115. pos += 1
  1116. cluster.append(shuffled_node_labels[pos])
  1117. median_labels.append(self._get_median_node_label(cluster))
  1118. cluster.clear()
  1119. # Run Lloyd's Algorithm.
  1120. converged = False
  1121. closest_median_ids = [np.inf] * len(node_labels)
  1122. clusters = [[] for _ in range(len(median_labels))]
  1123. itr = 1
  1124. while not self._insertion_termination_criterion_met(converged, itr):
  1125. converged = not self._update_clusters(node_labels, median_labels, closest_median_ids)
  1126. if not converged:
  1127. for cluster in clusters:
  1128. cluster.clear()
  1129. for label_id in range(0, len(node_labels)):
  1130. clusters[closest_median_ids[label_id]].append(node_labels[label_id])
  1131. for cluster_id in range(0, len(clusters)):
  1132. node_label = dict(median_labels[cluster_id])
  1133. self._update_node_label([dict(item) for item in clusters[cluster_id]], node_label) # @todo: the dict is tupled again in the function, can be better.
  1134. median_labels[cluster_id] = tuple(item for item in node_label.items())
  1135. itr += 1
  1136. def _insertion_termination_criterion_met(self, converged, itr):
  1137. return converged or (itr >= self._max_itrs_increase_order if self._max_itrs_increase_order > 0 else False)
  1138. def _update_config(self, node_label, inserted_nodes, config, node_labels):
  1139. # Determine the best configuration.
  1140. config_modified = False
  1141. for graph_id, node_set in inserted_nodes.items():
  1142. best_assignment = config[graph_id]
  1143. best_cost = 0.0
  1144. if best_assignment[0] == np.inf:
  1145. best_cost = self._node_del_cost
  1146. else:
  1147. best_cost = self._ged_env.get_node_rel_cost(dict(node_label), dict(best_assignment[1])) - self._node_ins_cost
  1148. for node in node_set:
  1149. cost = self._ged_env.get_node_rel_cost(dict(node_label), dict(node[1])) - self._node_ins_cost
  1150. if cost < best_cost - self._epsilon:
  1151. best_cost = cost
  1152. best_assignment = node
  1153. config_modified = True
  1154. if self._node_del_cost < best_cost - self._epsilon:
  1155. best_cost = self._node_del_cost
  1156. best_assignment = tuple((np.inf, best_assignment[1]))
  1157. config_modified = True
  1158. config[graph_id] = best_assignment
  1159. # Collect the node labels contained in the best configuration.
  1160. node_labels.clear()
  1161. for key, val in config.items():
  1162. if val[0] != np.inf:
  1163. node_labels.append(val[1])
  1164. # Return true if the configuration was modified.
  1165. return config_modified
  1166. def _update_node_label(self, node_labels, node_label):
  1167. if len(node_labels) == 0: # @todo: check if this is the correct solution. Especially after calling _update_config().
  1168. return False
  1169. new_node_label = self._get_median_node_label(node_labels)
  1170. if self._ged_env.get_node_rel_cost(new_node_label, node_label) > self._epsilon:
  1171. node_label.clear()
  1172. for key, val in new_node_label.items():
  1173. node_label[key] = val
  1174. return True
  1175. return False
  1176. def _update_clusters(self, node_labels, median_labels, closest_median_ids):
  1177. # Determine the closest median for each node label.
  1178. clusters_modified = False
  1179. for label_id in range(0, len(node_labels)):
  1180. closest_median_id = np.inf
  1181. dist_to_closest_median = np.inf
  1182. for median_id in range(0, len(median_labels)):
  1183. dist_to_median = self._ged_env.get_node_rel_cost(dict(median_labels[median_id]), dict(node_labels[label_id]))
  1184. if dist_to_median < dist_to_closest_median - self._epsilon:
  1185. dist_to_closest_median = dist_to_median
  1186. closest_median_id = median_id
  1187. if closest_median_id != closest_median_ids[label_id]:
  1188. closest_median_ids[label_id] = closest_median_id
  1189. clusters_modified = True
  1190. # Return true if the clusters were modified.
  1191. return clusters_modified
  1192. def _add_node_to_median(self, best_config, best_label, median):
  1193. # Update the median.
  1194. nb_nodes_median = nx.number_of_nodes(median)
  1195. median.add_node(nb_nodes_median, **best_label)
  1196. # Update the node maps.
  1197. for graph_id, node_map in self._node_maps_from_median.items():
  1198. node_map_as_rel = []
  1199. node_map.as_relation(node_map_as_rel)
  1200. new_node_map = NodeMap(nx.number_of_nodes(median), node_map.num_target_nodes())
  1201. for assignment in node_map_as_rel:
  1202. new_node_map.add_assignment(assignment[0], assignment[1])
  1203. new_node_map.add_assignment(nx.number_of_nodes(median) - 1, best_config[graph_id])
  1204. self._node_maps_from_median[graph_id] = new_node_map
  1205. # Increase overall number of increases.
  1206. self._num_increase_order += 1
  1207. def _are_graphs_equal(self, g1, g2):
  1208. """
  1209. Check if the two graphs are equal.
  1210. Parameters
  1211. ----------
  1212. g1 : NetworkX graph object
  1213. Graph 1 to be compared.
  1214. g2 : NetworkX graph object
  1215. Graph 2 to be compared.
  1216. Returns
  1217. -------
  1218. bool
  1219. True if the two graph are equal.
  1220. Notes
  1221. -----
  1222. This is not an identical check. Here the two graphs are equal if and only if their original_node_ids, nodes, all node labels, edges and all edge labels are equal. This function is specifically designed for class `MedianGraphEstimator` and should not be used elsewhere.
  1223. """
  1224. # check original node ids.
  1225. if not g1.graph['original_node_ids'] == g2.graph['original_node_ids']:
  1226. return False # @todo: why check this?
  1227. # check nodes.
  1228. nlist1 = [n for n in g1.nodes(data=True)] # @todo: shallow?
  1229. nlist2 = [n for n in g2.nodes(data=True)]
  1230. if not nlist1 == nlist2:
  1231. return False
  1232. # check edges.
  1233. elist1 = [n for n in g1.edges(data=True)]
  1234. elist2 = [n for n in g2.edges(data=True)]
  1235. if not elist1 == elist2:
  1236. return False
  1237. return True
  1238. def compute_my_cost(g, h, node_map):
  1239. cost = 0.0
  1240. for node in g.nodes:
  1241. cost += 0
  1242. def set_label_names(self, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
  1243. self._label_names = {'node_labels': node_labels, 'edge_labels': edge_labels,
  1244. 'node_attrs': node_attrs, 'edge_attrs': edge_attrs}
  1245. # def _get_median_node_label(self, node_labels):
  1246. # if len(self._label_names['node_labels']) > 0:
  1247. # return self._get_median_label_symbolic(node_labels)
  1248. # elif len(self._label_names['node_attrs']) > 0:
  1249. # return self._get_median_label_nonsymbolic(node_labels)
  1250. # else:
  1251. # raise Exception('Node label names are not given.')
  1252. #
  1253. #
  1254. # def _get_median_edge_label(self, edge_labels):
  1255. # if len(self._label_names['edge_labels']) > 0:
  1256. # return self._get_median_label_symbolic(edge_labels)
  1257. # elif len(self._label_names['edge_attrs']) > 0:
  1258. # return self._get_median_label_nonsymbolic(edge_labels)
  1259. # else:
  1260. # raise Exception('Edge label names are not given.')
  1261. #
  1262. #
  1263. # def _get_median_label_symbolic(self, labels):
  1264. # f_i = np.inf
  1265. #
  1266. # for label in labels:
  1267. # pass
  1268. #
  1269. # # Construct histogram.
  1270. # hist = {}
  1271. # for label in labels:
  1272. # label = tuple([kv for kv in label.items()]) # @todo: this may be slow.
  1273. # if label not in hist:
  1274. # hist[label] = 1
  1275. # else:
  1276. # hist[label] += 1
  1277. #
  1278. # # Return the label that appears most frequently.
  1279. # best_count = 0
  1280. # median_label = {}
  1281. # for label, count in hist.items():
  1282. # if count > best_count:
  1283. # best_count = count
  1284. # median_label = {kv[0]: kv[1] for kv in label}
  1285. #
  1286. # return median_label
  1287. #
  1288. #
  1289. # def _get_median_label_nonsymbolic(self, labels):
  1290. # if len(labels) == 0:
  1291. # return {} # @todo
  1292. # else:
  1293. # # Transform the labels into coordinates and compute mean label as initial solution.
  1294. # labels_as_coords = []
  1295. # sums = {}
  1296. # for key, val in labels[0].items():
  1297. # sums[key] = 0
  1298. # for label in labels:
  1299. # coords = {}
  1300. # for key, val in label.items():
  1301. # label_f = float(val)
  1302. # sums[key] += label_f
  1303. # coords[key] = label_f
  1304. # labels_as_coords.append(coords)
  1305. # median = {}
  1306. # for key, val in sums.items():
  1307. # median[key] = val / len(labels)
  1308. #
  1309. # # Run main loop of Weiszfeld's Algorithm.
  1310. # epsilon = 0.0001
  1311. # delta = 1.0
  1312. # num_itrs = 0
  1313. # all_equal = False
  1314. # while ((delta > epsilon) and (num_itrs < 100) and (not all_equal)):
  1315. # numerator = {}
  1316. # for key, val in sums.items():
  1317. # numerator[key] = 0
  1318. # denominator = 0
  1319. # for label_as_coord in labels_as_coords:
  1320. # norm = 0
  1321. # for key, val in label_as_coord.items():
  1322. # norm += (val - median[key]) ** 2
  1323. # norm = np.sqrt(norm)
  1324. # if norm > 0:
  1325. # for key, val in label_as_coord.items():
  1326. # numerator[key] += val / norm
  1327. # denominator += 1.0 / norm
  1328. # if denominator == 0:
  1329. # all_equal = True
  1330. # else:
  1331. # new_median = {}
  1332. # delta = 0.0
  1333. # for key, val in numerator.items():
  1334. # this_median = val / denominator
  1335. # new_median[key] = this_median
  1336. # delta += np.abs(median[key] - this_median)
  1337. # median = new_median
  1338. #
  1339. # num_itrs += 1
  1340. #
  1341. # # Transform the solution to strings and return it.
  1342. # median_label = {}
  1343. # for key, val in median.items():
  1344. # median_label[key] = str(val)
  1345. # return median_label
  1346. def _compute_medoid_parallel(graph_ids, sort, itr):
  1347. g_id = itr[0]
  1348. i = itr[1]
  1349. # @todo: timer not considered here.
  1350. # if timer.expired():
  1351. # self._state = AlgorithmState.CALLED
  1352. # break
  1353. nb_nodes_g = G_ged_env.get_graph_num_nodes(g_id)
  1354. sum_of_distances = 0
  1355. for h_id in graph_ids:
  1356. nb_nodes_h = G_ged_env.get_graph_num_nodes(h_id)
  1357. if nb_nodes_g <= nb_nodes_h or not sort:
  1358. G_ged_env.run_method(g_id, h_id)
  1359. sum_of_distances += G_ged_env.get_upper_bound(g_id, h_id)
  1360. else:
  1361. G_ged_env.run_method(h_id, g_id)
  1362. sum_of_distances += G_ged_env.get_upper_bound(h_id, g_id)
  1363. return i, sum_of_distances
  1364. def _compute_init_node_maps_parallel(gen_median_id, sort, nb_nodes_median, itr):
  1365. graph_id = itr
  1366. nb_nodes_g = G_ged_env.get_graph_num_nodes(graph_id)
  1367. if nb_nodes_median <= nb_nodes_g or not sort:
  1368. G_ged_env.run_method(gen_median_id, graph_id)
  1369. node_map = G_ged_env.get_node_map(gen_median_id, graph_id)
  1370. # print(self._node_maps_from_median[graph_id])
  1371. else:
  1372. G_ged_env.run_method(graph_id, gen_median_id)
  1373. node_map = G_ged_env.get_node_map(graph_id, gen_median_id)
  1374. node_map.forward_map, node_map.backward_map = node_map.backward_map, node_map.forward_map
  1375. sum_of_distance = node_map.induced_cost()
  1376. # print(self._sum_of_distances)
  1377. return graph_id, sum_of_distance, node_map
  1378. def _update_node_maps_parallel(median_id, epsilon, sort, nb_nodes_median, itr):
  1379. graph_id = itr[0]
  1380. node_map = itr[1]
  1381. node_maps_were_modified = False
  1382. nb_nodes_g = G_ged_env.get_graph_num_nodes(graph_id)
  1383. if nb_nodes_median <= nb_nodes_g or not sort:
  1384. G_ged_env.run_method(median_id, graph_id)
  1385. if G_ged_env.get_upper_bound(median_id, graph_id) < node_map.induced_cost() - epsilon:
  1386. node_map = G_ged_env.get_node_map(median_id, graph_id)
  1387. node_maps_were_modified = True
  1388. else:
  1389. G_ged_env.run_method(graph_id, median_id)
  1390. if G_ged_env.get_upper_bound(graph_id, median_id) < node_map.induced_cost() - epsilon:
  1391. node_map = G_ged_env.get_node_map(graph_id, median_id)
  1392. node_map.forward_map, node_map.backward_map = node_map.backward_map, node_map.forward_map
  1393. node_maps_were_modified = True
  1394. return graph_id, node_map, node_maps_were_modified

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.