You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

median_graph_estimator.py 60 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Mon Mar 16 18:04:55 2020
  5. @author: ljia
  6. """
  7. import numpy as np
  8. from gklearn.ged.env import AlgorithmState, NodeMap
  9. from gklearn.ged.util import misc
  10. from gklearn.utils import Timer
  11. import time
  12. from tqdm import tqdm
  13. import sys
  14. import networkx as nx
  15. import multiprocessing
  16. from multiprocessing import Pool
  17. from functools import partial
  18. class MedianGraphEstimator(object): # @todo: differ dummy_node from undifined node?
  19. def __init__(self, ged_env, constant_node_costs):
  20. """Constructor.
  21. Parameters
  22. ----------
  23. ged_env : gklearn.gedlib.gedlibpy.GEDEnv
  24. Initialized GED environment. The edit costs must be set by the user.
  25. constant_node_costs : Boolean
  26. Set to True if the node relabeling costs are constant.
  27. """
  28. self.__ged_env = ged_env
  29. self.__init_method = 'BRANCH_FAST'
  30. self.__init_options = ''
  31. self.__descent_method = 'BRANCH_FAST'
  32. self.__descent_options = ''
  33. self.__refine_method = 'IPFP'
  34. self.__refine_options = ''
  35. self.__constant_node_costs = constant_node_costs
  36. self.__labeled_nodes = (ged_env.get_num_node_labels() > 1)
  37. self.__node_del_cost = ged_env.get_node_del_cost(ged_env.get_node_label(1))
  38. self.__node_ins_cost = ged_env.get_node_ins_cost(ged_env.get_node_label(1))
  39. self.__labeled_edges = (ged_env.get_num_edge_labels() > 1)
  40. self.__edge_del_cost = ged_env.get_edge_del_cost(ged_env.get_edge_label(1))
  41. self.__edge_ins_cost = ged_env.get_edge_ins_cost(ged_env.get_edge_label(1))
  42. self.__init_type = 'RANDOM'
  43. self.__num_random_inits = 10
  44. self.__desired_num_random_inits = 10
  45. self.__use_real_randomness = True
  46. self.__seed = 0
  47. self.__parallel = True
  48. self.__update_order = True
  49. self.__refine = True
  50. self.__time_limit_in_sec = 0
  51. self.__epsilon = 0.0001
  52. self.__max_itrs = 100
  53. self.__max_itrs_without_update = 3
  54. self.__num_inits_increase_order = 10
  55. self.__init_type_increase_order = 'K-MEANS++'
  56. self.__max_itrs_increase_order = 10
  57. self.__print_to_stdout = 2
  58. self.__median_id = np.inf # @todo: check
  59. self.__node_maps_from_median = {}
  60. self.__sum_of_distances = 0
  61. self.__best_init_sum_of_distances = np.inf
  62. self.__converged_sum_of_distances = np.inf
  63. self.__runtime = None
  64. self.__runtime_initialized = None
  65. self.__runtime_converged = None
  66. self.__itrs = [] # @todo: check: {} ?
  67. self.__num_decrease_order = 0
  68. self.__num_increase_order = 0
  69. self.__num_converged_descents = 0
  70. self.__state = AlgorithmState.TERMINATED
  71. self.__label_names = {}
  72. if ged_env is None:
  73. raise Exception('The GED environment pointer passed to the constructor of MedianGraphEstimator is null.')
  74. elif not ged_env.is_initialized():
  75. raise Exception('The GED environment is uninitialized. Call gedlibpy.GEDEnv.init() before passing it to the constructor of MedianGraphEstimator.')
  def set_options(self, options):
      """Sets the options of the estimator.

      All options are first reset to their defaults, then every option found
      in `options` is validated and stored.

      Parameters
      ----------
      options : string
          String that specifies with which options to run the estimator.

      Raises
      ------
      Exception
          If an option name is unknown or an option value is invalid.
      """
      def invalid(name, text, usage):
          # Builds the error raised for a rejected option value.
          return Exception('Invalid argument "' + text + '" for option ' + name + '. Usage: options = ' + usage)

      def parse_number(convert, name, text, usage, positive=False):
          # Converts `text` with `convert`; optionally requires a value > 0.
          try:
              value = convert(text)
          except (TypeError, ValueError):
              raise invalid(name, text, usage)
          if positive and value <= 0:
              raise invalid(name, text, usage)
          return value

      def parse_flag(name, text):
          # Maps 'TRUE' / 'FALSE' to the corresponding boolean.
          if text == 'TRUE':
              return True
          if text == 'FALSE':
              return False
          raise invalid(name, text, '"[--' + name + ' TRUE|FALSE] [...]"')

      self.__set_default_options()
      options_map = misc.options_string_to_options_map(options)
      for opt_name, opt_val in options_map.items():
          if opt_name == 'init-type':
              # Validate before storing so a rejected value is never kept.
              if opt_val not in ('MEDOID', 'RANDOM', 'MIN', 'MAX', 'MEAN'):
                  raise Exception('Invalid argument ' + opt_val + ' for option init-type. Usage: options = "[--init-type RANDOM|MEDOID|MIN|MAX|MEAN] [...]"')
              self.__init_type = opt_val
          elif opt_name == 'random-inits':
              num_random_inits = parse_number(int, opt_name, opt_val, '"[--random-inits <convertible to int greater 0>]"', positive=True)
              self.__num_random_inits = num_random_inits
              self.__desired_num_random_inits = num_random_inits
          elif opt_name == 'randomness':
              if opt_val == 'PSEUDO':
                  self.__use_real_randomness = False
              elif opt_val == 'REAL':
                  self.__use_real_randomness = True
              else:
                  raise invalid(opt_name, opt_val, '"[--randomness REAL|PSEUDO] [...]"')
          elif opt_name == 'stdout':
              if opt_val not in ('0', '1', '2'):
                  raise invalid(opt_name, opt_val, '"[--stdout 0|1|2] [...]"')
              self.__print_to_stdout = int(opt_val)
          elif opt_name == 'parallel':
              self.__parallel = parse_flag(opt_name, opt_val)
          elif opt_name == 'update-order':
              self.__update_order = parse_flag(opt_name, opt_val)
          elif opt_name == 'refine':
              self.__refine = parse_flag(opt_name, opt_val)
          elif opt_name == 'time-limit':
              self.__time_limit_in_sec = parse_number(float, opt_name, opt_val, '"[--time-limit <convertible to double>] [...]')
          elif opt_name == 'max-itrs':
              self.__max_itrs = parse_number(int, opt_name, opt_val, '"[--max-itrs <convertible to int>] [...]')
          elif opt_name == 'max-itrs-without-update':
              self.__max_itrs_without_update = parse_number(int, opt_name, opt_val, '"[--max-itrs-without-update <convertible to int>] [...]')
          elif opt_name == 'seed':
              self.__seed = parse_number(int, opt_name, opt_val, '"[--seed <convertible to int greater equal 0>] [...]')
          elif opt_name == 'epsilon':
              self.__epsilon = parse_number(float, opt_name, opt_val, '"[--epsilon <convertible to double greater 0>] [...]', positive=True)
          elif opt_name == 'inits-increase-order':
              self.__num_inits_increase_order = parse_number(int, opt_name, opt_val, '"[--inits-increase-order <convertible to int greater 0>]"', positive=True)
          elif opt_name == 'init-type-increase-order':
              if opt_val != 'CLUSTERS' and opt_val != 'K-MEANS++':
                  raise Exception('Invalid argument ' + opt_val + ' for option init-type-increase-order. Usage: options = "[--init-type-increase-order CLUSTERS|K-MEANS++] [...]"')
              self.__init_type_increase_order = opt_val
          elif opt_name == 'max-itrs-increase-order':
              self.__max_itrs_increase_order = parse_number(int, opt_name, opt_val, '"[--max-itrs-increase-order <convertible to int>] [...]')
          else:
              # List every option accepted above so the hint is complete.
              valid_options = '[--init-type <arg>] [--random-inits <arg>] [--randomness <arg>] [--seed <arg>] [--stdout <arg>] '
              valid_options += '[--parallel <arg>] [--update-order <arg>] [--refine <arg>] '
              valid_options += '[--time-limit <arg>] [--max-itrs <arg>] [--max-itrs-without-update <arg>] [--epsilon <arg>] '
              valid_options += '[--inits-increase-order <arg>] [--init-type-increase-order <arg>] [--max-itrs-increase-order <arg>]'
              raise Exception('Invalid option "' + opt_name + '". Usage: options = "' + valid_options + '"')
  183. def set_init_method(self, init_method, init_options=''):
  184. """Selects method to be used for computing the initial medoid graph.
  185. Parameters
  186. ----------
  187. init_method : string
  188. The selected method. Default: ged::Options::GEDMethod::BRANCH_UNIFORM.
  189. init_options : string
  190. The options for the selected method. Default: "".
  191. Notes
  192. -----
  193. Has no effect unless "--init-type MEDOID" is passed to set_options().
  194. """
  195. self.__init_method = init_method;
  196. self.__init_options = init_options;
  197. def set_descent_method(self, descent_method, descent_options=''):
  198. """Selects method to be used for block gradient descent..
  199. Parameters
  200. ----------
  201. descent_method : string
  202. The selected method. Default: ged::Options::GEDMethod::BRANCH_FAST.
  203. descent_options : string
  204. The options for the selected method. Default: "".
  205. Notes
  206. -----
  207. Has no effect unless "--init-type MEDOID" is passed to set_options().
  208. """
  209. self.__descent_method = descent_method;
  210. self.__descent_options = descent_options;
  211. def set_refine_method(self, refine_method, refine_options):
  212. """Selects method to be used for improving the sum of distances and the node maps for the converged median.
  213. Parameters
  214. ----------
  215. refine_method : string
  216. The selected method. Default: "IPFP".
  217. refine_options : string
  218. The options for the selected method. Default: "".
  219. Notes
  220. -----
  221. Has no effect if "--refine FALSE" is passed to set_options().
  222. """
  223. self.__refine_method = refine_method
  224. self.__refine_options = refine_options
  225. def run(self, graph_ids, set_median_id, gen_median_id):
  226. """Computes a generalized median graph.
  227. Parameters
  228. ----------
  229. graph_ids : list[integer]
  230. The IDs of the graphs for which the median should be computed. Must have been added to the environment passed to the constructor.
  231. set_median_id : integer
  232. The ID of the computed set-median. A dummy graph with this ID must have been added to the environment passed to the constructor. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
  233. gen_median_id : integer
  234. The ID of the computed generalized median. Upon termination, the computed median can be obtained via gklearn.gedlib.gedlibpy.GEDEnv.get_graph().
  235. """
  236. # Sanity checks.
  237. if len(graph_ids) == 0:
  238. raise Exception('Empty vector of graph IDs, unable to compute median.')
  239. all_graphs_empty = True
  240. for graph_id in graph_ids:
  241. if self.__ged_env.get_graph_num_nodes(graph_id) > 0:
  242. all_graphs_empty = False
  243. break
  244. if all_graphs_empty:
  245. raise Exception('All graphs in the collection are empty.')
  246. # Start timer and record start time.
  247. start = time.time()
  248. timer = Timer(self.__time_limit_in_sec)
  249. self.__median_id = gen_median_id
  250. self.__state = AlgorithmState.TERMINATED
  251. # Get NetworkX graph representations of the input graphs.
  252. graphs = {}
  253. for graph_id in graph_ids:
  254. # @todo: get_nx_graph() function may need to be modified according to the coming code.
  255. graphs[graph_id] = self.__ged_env.get_nx_graph(graph_id, True, True, False)
  256. # print(self.__ged_env.get_graph_internal_id(0))
  257. # print(graphs[0].graph)
  258. # print(graphs[0].nodes(data=True))
  259. # print(graphs[0].edges(data=True))
  260. # print(nx.adjacency_matrix(graphs[0]))
  261. # Construct initial medians.
  262. medians = []
  263. self.__construct_initial_medians(graph_ids, timer, medians)
  264. end_init = time.time()
  265. self.__runtime_initialized = end_init - start
  266. # print(medians[0].graph)
  267. # print(medians[0].nodes(data=True))
  268. # print(medians[0].edges(data=True))
  269. # print(nx.adjacency_matrix(medians[0]))
  270. # Reset information about iterations and number of times the median decreases and increases.
  271. self.__itrs = [0] * len(medians)
  272. self.__num_decrease_order = 0
  273. self.__num_increase_order = 0
  274. self.__num_converged_descents = 0
  275. # Initialize the best median.
  276. best_sum_of_distances = np.inf
  277. self.__best_init_sum_of_distances = np.inf
  278. node_maps_from_best_median = {}
  279. # Run block gradient descent from all initial medians.
  280. self.__ged_env.set_method(self.__descent_method, self.__descent_options)
  281. for median_pos in range(0, len(medians)):
  282. # Terminate if the timer has expired and at least one SOD has been computed.
  283. if timer.expired() and median_pos > 0:
  284. break
  285. # Print information about current iteration.
  286. if self.__print_to_stdout == 2:
  287. print('\n===========================================================')
  288. print('Block gradient descent for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
  289. print('-----------------------------------------------------------')
  290. # Get reference to the median.
  291. median = medians[median_pos]
  292. # Load initial median into the environment.
  293. self.__ged_env.load_nx_graph(median, gen_median_id)
  294. self.__ged_env.init(self.__ged_env.get_init_type())
  295. # Compute node maps and sum of distances for initial median.
  296. self.__compute_init_node_maps(graph_ids, gen_median_id)
  297. self.__best_init_sum_of_distances = min(self.__best_init_sum_of_distances, self.__sum_of_distances)
  298. self.__ged_env.load_nx_graph(median, set_median_id)
  299. # print(self.__best_init_sum_of_distances)
  300. # Run block gradient descent from initial median.
  301. converged = False
  302. itrs_without_update = 0
  303. while not self.__termination_criterion_met(converged, timer, self.__itrs[median_pos], itrs_without_update):
  304. # Print information about current iteration.
  305. if self.__print_to_stdout == 2:
  306. print('\n===========================================================')
  307. print('Iteration', str(self.__itrs[median_pos] + 1), 'for initial median', str(median_pos + 1), 'of', str(len(medians)), '.')
  308. print('-----------------------------------------------------------')
  309. # Initialize flags that tell us what happened in the iteration.
  310. median_modified = False
  311. node_maps_modified = False
  312. decreased_order = False
  313. increased_order = False
  314. # Update the median.
  315. median_modified = self.__update_median(graphs, median)
  316. if self.__update_order:
  317. if not median_modified or self.__itrs[median_pos] == 0:
  318. decreased_order = self.__decrease_order(graphs, median)
  319. if not decreased_order or self.__itrs[median_pos] == 0:
  320. increased_order = self.__increase_order(graphs, median)
  321. # Update the number of iterations without update of the median.
  322. if median_modified or decreased_order or increased_order:
  323. itrs_without_update = 0
  324. else:
  325. itrs_without_update += 1
  326. # Print information about current iteration.
  327. if self.__print_to_stdout == 2:
  328. print('Loading median to environment: ... ', end='')
  329. # Load the median into the environment.
  330. # @todo: should this function use the original node label?
  331. self.__ged_env.load_nx_graph(median, gen_median_id)
  332. self.__ged_env.init(self.__ged_env.get_init_type())
  333. # Print information about current iteration.
  334. if self.__print_to_stdout == 2:
  335. print('done.')
  336. # Print information about current iteration.
  337. if self.__print_to_stdout == 2:
  338. print('Updating induced costs: ... ', end='')
  339. # Compute induced costs of the old node maps w.r.t. the updated median.
  340. for graph_id in graph_ids:
  341. # print(self.__node_maps_from_median[graph_id].induced_cost())
  342. # xxx = self.__node_maps_from_median[graph_id]
  343. self.__ged_env.compute_induced_cost(gen_median_id, graph_id, self.__node_maps_from_median[graph_id])
  344. # print('---------------------------------------')
  345. # print(self.__node_maps_from_median[graph_id].induced_cost())
  346. # @todo:!!!!!!!!!!!!!!!!!!!!!!!!!!!!This value is a slight different from the c++ program, which might be a bug! Use it very carefully!
  347. # Print information about current iteration.
  348. if self.__print_to_stdout == 2:
  349. print('done.')
  350. # Update the node maps.
  351. node_maps_modified = self.__update_node_maps()
  352. # Update the order of the median if no improvement can be found with the current order.
  353. # Update the sum of distances.
  354. old_sum_of_distances = self.__sum_of_distances
  355. self.__sum_of_distances = 0
  356. for graph_id, node_map in self.__node_maps_from_median.items():
  357. self.__sum_of_distances += node_map.induced_cost()
  358. # print(self.__sum_of_distances)
  359. # Print information about current iteration.
  360. if self.__print_to_stdout == 2:
  361. print('Old local SOD: ', old_sum_of_distances)
  362. print('New local SOD: ', self.__sum_of_distances)
  363. print('Best converged SOD: ', best_sum_of_distances)
  364. print('Modified median: ', median_modified)
  365. print('Modified node maps: ', node_maps_modified)
  366. print('Decreased order: ', decreased_order)
  367. print('Increased order: ', increased_order)
  368. print('===========================================================\n')
  369. converged = not (median_modified or node_maps_modified or decreased_order or increased_order)
  370. self.__itrs[median_pos] += 1
  371. # Update the best median.
  372. if self.__sum_of_distances < best_sum_of_distances:
  373. best_sum_of_distances = self.__sum_of_distances
  374. node_maps_from_best_median = self.__node_maps_from_median.copy() # @todo: this is a shallow copy, not sure if it is enough.
  375. best_median = median
  376. # Update the number of converged descents.
  377. if converged:
  378. self.__num_converged_descents += 1
  379. # Store the best encountered median.
  380. self.__sum_of_distances = best_sum_of_distances
  381. self.__node_maps_from_median = node_maps_from_best_median
  382. self.__ged_env.load_nx_graph(best_median, gen_median_id)
  383. self.__ged_env.init(self.__ged_env.get_init_type())
  384. end_descent = time.time()
  385. self.__runtime_converged = end_descent - start
  386. # Refine the sum of distances and the node maps for the converged median.
  387. self.__converged_sum_of_distances = self.__sum_of_distances
  388. if self.__refine:
  389. self.__improve_sum_of_distances(timer)
  390. # Record end time, set runtime and reset the number of initial medians.
  391. end = time.time()
  392. self.__runtime = end - start
  393. self.__num_random_inits = self.__desired_num_random_inits
  394. # Print global information.
  395. if self.__print_to_stdout != 0:
  396. print('\n===========================================================')
  397. print('Finished computation of generalized median graph.')
  398. print('-----------------------------------------------------------')
  399. print('Best SOD after initialization: ', self.__best_init_sum_of_distances)
  400. print('Converged SOD: ', self.__converged_sum_of_distances)
  401. if self.__refine:
  402. print('Refined SOD: ', self.__sum_of_distances)
  403. print('Overall runtime: ', self.__runtime)
  404. print('Runtime of initialization: ', self.__runtime_initialized)
  405. print('Runtime of block gradient descent: ', self.__runtime_converged - self.__runtime_initialized)
  406. if self.__refine:
  407. print('Runtime of refinement: ', self.__runtime - self.__runtime_converged)
  408. print('Number of initial medians: ', len(medians))
  409. total_itr = 0
  410. num_started_descents = 0
  411. for itr in self.__itrs:
  412. total_itr += itr
  413. if itr > 0:
  414. num_started_descents += 1
  415. print('Size of graph collection: ', len(graph_ids))
  416. print('Number of started descents: ', num_started_descents)
  417. print('Number of converged descents: ', self.__num_converged_descents)
  418. print('Overall number of iterations: ', total_itr)
  419. print('Overall number of times the order decreased: ', self.__num_decrease_order)
  420. print('Overall number of times the order increased: ', self.__num_increase_order)
  421. print('===========================================================\n')
  422. def __improve_sum_of_distances(self, timer): # @todo: go through and test
  423. # Use method selected for refinement phase.
  424. self.__ged_env.set_method(self.__refine_method, self.__refine_options)
  425. # Print information about current iteration.
  426. if self.__print_to_stdout == 2:
  427. progress = tqdm(desc='Improving node maps', total=len(self.__node_maps_from_median), file=sys.stdout)
  428. print('\n===========================================================')
  429. print('Improving node maps and SOD for converged median.')
  430. print('-----------------------------------------------------------')
  431. progress.update(1)
  432. # Improving the node maps.
  433. for graph_id, node_map in self.__node_maps_from_median.items():
  434. if time.expired():
  435. if self.__state == AlgorithmState.TERMINATED:
  436. self.__state = AlgorithmState.CONVERGED
  437. break
  438. self.__ged_env.run_method(self.__gen_median_id, graph_id)
  439. if self.__ged_env.get_upper_bound(self.__gen_median_id, graph_id) < node_map.induced_cost():
  440. self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(self.__gen_median_id, graph_id)
  441. self.__sum_of_distances += self.__node_maps_from_median[graph_id].induced_cost()
  442. # Print information.
  443. if self.__print_to_stdout == 2:
  444. progress.update(1)
  445. self.__sum_of_distances = 0.0
  446. for key, val in self.__node_maps_from_median.items():
  447. self.__sum_of_distances += val.induced_cost()
  448. # Print information.
  449. if self.__print_to_stdout == 2:
  450. print('===========================================================\n')
  451. def __median_available(self):
  452. return self.__gen_median_id != np.inf
  453. def get_state(self):
  454. if not self.__median_available():
  455. raise Exception('No median has been computed. Call run() before calling get_state().')
  456. return self.__state
  457. def get_sum_of_distances(self, state=''):
  458. """Returns the sum of distances.
  459. Parameters
  460. ----------
  461. state : string
  462. The state of the estimator. Can be 'initialized' or 'converged'. Default: ""
  463. Returns
  464. -------
  465. float
  466. The sum of distances (SOD) of the median when the estimator was in the state `state` during the last call to run(). If `state` is not given, the converged SOD (without refinement) or refined SOD (with refinement) is returned.
  467. """
  468. if not self.__median_available():
  469. raise Exception('No median has been computed. Call run() before calling get_sum_of_distances().')
  470. if state == 'initialized':
  471. return self.__best_init_sum_of_distances
  472. if state == 'converged':
  473. return self.__converged_sum_of_distances
  474. return self.__sum_of_distances
  475. def get_runtime(self, state):
  476. if not self.__median_available():
  477. raise Exception('No median has been computed. Call run() before calling get_runtime().')
  478. if state == AlgorithmState.INITIALIZED:
  479. return self.__runtime_initialized
  480. if state == AlgorithmState.CONVERGED:
  481. return self.__runtime_converged
  482. return self.__runtime
  483. def get_num_itrs(self):
  484. if not self.__median_available():
  485. raise Exception('No median has been computed. Call run() before calling get_num_itrs().')
  486. return self.__itrs
  487. def get_num_times_order_decreased(self):
  488. if not self.__median_available():
  489. raise Exception('No median has been computed. Call run() before calling get_num_times_order_decreased().')
  490. return self.__num_decrease_order
  491. def get_num_times_order_increased(self):
  492. if not self.__median_available():
  493. raise Exception('No median has been computed. Call run() before calling get_num_times_order_increased().')
  494. return self.__num_increase_order
  495. def get_num_converged_descents(self):
  496. if not self.__median_available():
  497. raise Exception('No median has been computed. Call run() before calling get_num_converged_descents().')
  498. return self.__num_converged_descents
  499. def get_ged_env(self):
  500. return self.__ged_env
  501. def __set_default_options(self):
  502. self.__init_type = 'RANDOM'
  503. self.__num_random_inits = 10
  504. self.__desired_num_random_inits = 10
  505. self.__use_real_randomness = True
  506. self.__seed = 0
  507. self.__parallel = True
  508. self.__update_order = True
  509. self.__refine = True
  510. self.__time_limit_in_sec = 0
  511. self.__epsilon = 0.0001
  512. self.__max_itrs = 100
  513. self.__max_itrs_without_update = 3
  514. self.__num_inits_increase_order = 10
  515. self.__init_type_increase_order = 'K-MEANS++'
  516. self.__max_itrs_increase_order = 10
  517. self.__print_to_stdout = 2
  518. self.__label_names = {}
  519. def __construct_initial_medians(self, graph_ids, timer, initial_medians):
  520. # Print information about current iteration.
  521. if self.__print_to_stdout == 2:
  522. print('\n===========================================================')
  523. print('Constructing initial median(s).')
  524. print('-----------------------------------------------------------')
  525. # Compute or sample the initial median(s).
  526. initial_medians.clear()
  527. if self.__init_type == 'MEDOID':
  528. self.__compute_medoid(graph_ids, timer, initial_medians)
  529. elif self.__init_type == 'MAX':
  530. pass # @todo
  531. # compute_max_order_graph_(graph_ids, initial_medians)
  532. elif self.__init_type == 'MIN':
  533. pass # @todo
  534. # compute_min_order_graph_(graph_ids, initial_medians)
  535. elif self.__init_type == 'MEAN':
  536. pass # @todo
  537. # compute_mean_order_graph_(graph_ids, initial_medians)
  538. else:
  539. pass # @todo
  540. # sample_initial_medians_(graph_ids, initial_medians)
  541. # Print information about current iteration.
  542. if self.__print_to_stdout == 2:
  543. print('===========================================================')
  544. def __compute_medoid(self, graph_ids, timer, initial_medians):
  545. # Use method selected for initialization phase.
  546. self.__ged_env.set_method(self.__init_method, self.__init_options)
  547. # Compute the medoid.
  548. if self.__parallel:
  549. # @todo: notice when parallel self.__ged_env is not modified.
  550. sum_of_distances_list = [np.inf] * len(graph_ids)
  551. len_itr = len(graph_ids)
  552. itr = zip(graph_ids, range(0, len(graph_ids)))
  553. n_jobs = multiprocessing.cpu_count()
  554. if len_itr < 100 * n_jobs:
  555. chunksize = int(len_itr / n_jobs) + 1
  556. else:
  557. chunksize = 100
  558. def init_worker(ged_env_toshare):
  559. global G_ged_env
  560. G_ged_env = ged_env_toshare
  561. do_fun = partial(_compute_medoid_parallel, graph_ids)
  562. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self.__ged_env,))
  563. if self.__print_to_stdout == 2:
  564. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  565. desc='Computing medoid', file=sys.stdout)
  566. else:
  567. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  568. for i, dis in iterator:
  569. sum_of_distances_list[i] = dis
  570. pool.close()
  571. pool.join()
  572. medoid_id = np.argmin(sum_of_distances_list)
  573. best_sum_of_distances = sum_of_distances_list[medoid_id]
  574. initial_medians.append(self.__ged_env.get_nx_graph(medoid_id, True, True, False)) # @todo
  575. else:
  576. # Print information about current iteration.
  577. if self.__print_to_stdout == 2:
  578. progress = tqdm(desc='Computing medoid', total=len(graph_ids), file=sys.stdout)
  579. medoid_id = graph_ids[0]
  580. best_sum_of_distances = np.inf
  581. for g_id in graph_ids:
  582. if timer.expired():
  583. self.__state = AlgorithmState.CALLED
  584. break
  585. sum_of_distances = 0
  586. for h_id in graph_ids:
  587. self.__ged_env.run_method(g_id, h_id)
  588. sum_of_distances += self.__ged_env.get_upper_bound(g_id, h_id)
  589. if sum_of_distances < best_sum_of_distances:
  590. best_sum_of_distances = sum_of_distances
  591. medoid_id = g_id
  592. # Print information about current iteration.
  593. if self.__print_to_stdout == 2:
  594. progress.update(1)
  595. initial_medians.append(self.__ged_env.get_nx_graph(medoid_id, True, True, False)) # @todo
  596. # Print information about current iteration.
  597. if self.__print_to_stdout == 2:
  598. print('\n')
  599. def __compute_init_node_maps(self, graph_ids, gen_median_id):
  600. # Compute node maps and sum of distances for initial median.
  601. if self.__parallel:
  602. # @todo: notice when parallel self.__ged_env is not modified.
  603. self.__sum_of_distances = 0
  604. self.__node_maps_from_median.clear()
  605. sum_of_distances_list = [0] * len(graph_ids)
  606. len_itr = len(graph_ids)
  607. itr = graph_ids
  608. n_jobs = multiprocessing.cpu_count()
  609. if len_itr < 100 * n_jobs:
  610. chunksize = int(len_itr / n_jobs) + 1
  611. else:
  612. chunksize = 100
  613. def init_worker(ged_env_toshare):
  614. global G_ged_env
  615. G_ged_env = ged_env_toshare
  616. do_fun = partial(_compute_init_node_maps_parallel, gen_median_id)
  617. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self.__ged_env,))
  618. if self.__print_to_stdout == 2:
  619. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  620. desc='Computing initial node maps', file=sys.stdout)
  621. else:
  622. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  623. for g_id, sod, node_maps in iterator:
  624. sum_of_distances_list[g_id] = sod
  625. self.__node_maps_from_median[g_id] = node_maps
  626. pool.close()
  627. pool.join()
  628. self.__sum_of_distances = np.sum(sum_of_distances_list)
  629. # xxx = self.__node_maps_from_median
  630. else:
  631. # Print information about current iteration.
  632. if self.__print_to_stdout == 2:
  633. progress = tqdm(desc='Computing initial node maps', total=len(graph_ids), file=sys.stdout)
  634. self.__sum_of_distances = 0
  635. self.__node_maps_from_median.clear()
  636. for graph_id in graph_ids:
  637. self.__ged_env.run_method(gen_median_id, graph_id)
  638. self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(gen_median_id, graph_id)
  639. # print(self.__node_maps_from_median[graph_id])
  640. self.__sum_of_distances += self.__node_maps_from_median[graph_id].induced_cost()
  641. # print(self.__sum_of_distances)
  642. # Print information about current iteration.
  643. if self.__print_to_stdout == 2:
  644. progress.update(1)
  645. # Print information about current iteration.
  646. if self.__print_to_stdout == 2:
  647. print('\n')
  648. def __termination_criterion_met(self, converged, timer, itr, itrs_without_update):
  649. if timer.expired() or (itr >= self.__max_itrs if self.__max_itrs >= 0 else False):
  650. if self.__state == AlgorithmState.TERMINATED:
  651. self.__state = AlgorithmState.INITIALIZED
  652. return True
  653. return converged or (itrs_without_update > self.__max_itrs_without_update if self.__max_itrs_without_update >= 0 else False)
  654. def __update_median(self, graphs, median):
  655. # Print information about current iteration.
  656. if self.__print_to_stdout == 2:
  657. print('Updating median: ', end='')
  658. # Store copy of the old median.
  659. old_median = median.copy() # @todo: this is just a shallow copy.
  660. # Update the node labels.
  661. if self.__labeled_nodes:
  662. self.__update_node_labels(graphs, median)
  663. # Update the edges and their labels.
  664. self.__update_edges(graphs, median)
  665. # Print information about current iteration.
  666. if self.__print_to_stdout == 2:
  667. print('done.')
  668. return not self.__are_graphs_equal(median, old_median)
  def __update_node_labels(self, graphs, median):
      """Replace each median node label by the median of the labels it is mapped to.

      For every median node ``i`` the labels of its images under the stored
      node maps are collected (images equal to ``np.inf`` are skipped). The
      median of those labels replaces the current label only if their relabeling
      cost exceeds ``self.__epsilon``.

      Parameters
      ----------
      graphs : dict
          Maps graph IDs to NetworkX graphs.
      median : NetworkX graph
          Median graph, modified in place.
      """
      # Print information about current iteration.
      if self.__print_to_stdout == 2:
          print('nodes ... ', end='')

      # Iterate through all nodes of the median.
      for i in range(nx.number_of_nodes(median)):
          # Collect the labels of the substituted nodes.
          images = ((graph, self.__node_maps_from_median[graph_id].image(i))
                    for graph_id, graph in graphs.items())
          node_labels = [graph.nodes[k] for graph, k in images if k != np.inf]
          if not node_labels:
              continue

          # Compute the median label and update the median.
          median_label = self.__get_median_node_label(node_labels)
          if self.__ged_env.get_node_rel_cost(median.nodes[i], median_label) > self.__epsilon:
              nx.set_node_attributes(median, {i: median_label})
  def __update_edges(self, graphs, median):
      """Recompute, for every node pair of ``median``, whether an edge exists and its label.

      For each pair ``(i, j)`` with ``i < j`` the labels of the edges the pair
      is mapped to are collected. The existing edge (if any) is removed and
      re-inserted only when the accumulated relabeling cost is below
      ``(ins + del) * len(edge_labels) - del * len(graphs)``.

      Parameters
      ----------
      graphs : dict
          Maps graph IDs to NetworkX graphs.
      median : NetworkX graph
          Median graph, modified in place.
      """
      # Print information about current iteration.
      if self.__print_to_stdout == 2:
          print('edges ... ', end='')

      # @todo: what if edge is not labeled?
      # Iterate through all possible edges (i,j) of the median.
      num_nodes = nx.number_of_nodes(median)
      for i in range(num_nodes):
          for j in range(i + 1, num_nodes):

              # Collect the labels of the edges to which (i,j) is mapped by the node maps.
              edge_labels = []
              for graph_id, graph in graphs.items():
                  node_map = self.__node_maps_from_median[graph_id]
                  k = node_map.image(i)
                  m = node_map.image(j)
                  if k != np.inf and m != np.inf and graph.has_edge(k, m):
                      edge_labels.append(graph.edges[(k, m)])

              # Compute the median edge label and the overall edge relabeling cost.
              rel_cost = 0
              median_label = self.__ged_env.get_edge_label(1)
              edge_exists = median.has_edge(i, j)
              if edge_exists:
                  median_label = median.edges[(i, j)]
              if self.__labeled_edges and edge_labels:
                  candidate_label = self.__get_median_edge_label(edge_labels)
                  if self.__ged_env.get_edge_rel_cost(median_label, candidate_label) > self.__epsilon:
                      median_label = candidate_label
                  for edge_label in edge_labels:
                      rel_cost += self.__ged_env.get_edge_rel_cost(median_label, edge_label)

              # Update the median: drop the old edge, re-insert it only if that is cheaper.
              if edge_exists:
                  median.remove_edge(i, j)
              threshold = ((self.__edge_ins_cost + self.__edge_del_cost) * len(edge_labels)
                           - self.__edge_del_cost * len(graphs))
              if rel_cost < threshold:
                  median.add_edge(i, j, **median_label)
  730. def __update_node_maps(self):
  731. # Update the node maps.
  732. if self.__parallel:
  733. # @todo: notice when parallel self.__ged_env is not modified.
  734. node_maps_were_modified = False
  735. # xxx = self.__node_maps_from_median.copy()
  736. len_itr = len(self.__node_maps_from_median)
  737. itr = [item for item in self.__node_maps_from_median.items()]
  738. n_jobs = multiprocessing.cpu_count()
  739. if len_itr < 100 * n_jobs:
  740. chunksize = int(len_itr / n_jobs) + 1
  741. else:
  742. chunksize = 100
  743. def init_worker(ged_env_toshare):
  744. global G_ged_env
  745. G_ged_env = ged_env_toshare
  746. do_fun = partial(_update_node_maps_parallel, self.__median_id, self.__epsilon)
  747. pool = Pool(processes=n_jobs, initializer=init_worker, initargs=(self.__ged_env,))
  748. if self.__print_to_stdout == 2:
  749. iterator = tqdm(pool.imap_unordered(do_fun, itr, chunksize),
  750. desc='Updating node maps', file=sys.stdout)
  751. else:
  752. iterator = pool.imap_unordered(do_fun, itr, chunksize)
  753. for g_id, node_map, nm_modified in iterator:
  754. self.__node_maps_from_median[g_id] = node_map
  755. if nm_modified:
  756. node_maps_were_modified = True
  757. pool.close()
  758. pool.join()
  759. # yyy = self.__node_maps_from_median.copy()
  760. else:
  761. # Print information about current iteration.
  762. if self.__print_to_stdout == 2:
  763. progress = tqdm(desc='Updating node maps', total=len(self.__node_maps_from_median), file=sys.stdout)
  764. node_maps_were_modified = False
  765. for graph_id, node_map in self.__node_maps_from_median.items():
  766. self.__ged_env.run_method(self.__median_id, graph_id)
  767. if self.__ged_env.get_upper_bound(self.__median_id, graph_id) < node_map.induced_cost() - self.__epsilon:
  768. # xxx = self.__node_maps_from_median[graph_id]
  769. self.__node_maps_from_median[graph_id] = self.__ged_env.get_node_map(self.__median_id, graph_id)
  770. # yyy = self.__node_maps_from_median[graph_id]
  771. node_maps_were_modified = True
  772. # Print information about current iteration.
  773. if self.__print_to_stdout == 2:
  774. progress.update(1)
  775. # Print information about current iteration.
  776. if self.__print_to_stdout == 2:
  777. print('\n')
  778. # Return true if the node maps were modified.
  779. return node_maps_were_modified
  780. def __decrease_order(self, graphs, median):
  781. # Print information about current iteration
  782. if self.__print_to_stdout == 2:
  783. print('Trying to decrease order: ... ', end='')
  784. # Initialize ID of the node that is to be deleted.
  785. id_deleted_node = [None] # @todo: or np.inf
  786. decreased_order = False
  787. # Decrease the order as long as the best deletion delta is negative.
  788. while self.__compute_best_deletion_delta(graphs, median, id_deleted_node) < -self.__epsilon:
  789. decreased_order = True
  790. median = self.__delete_node_from_median(id_deleted_node[0], median)
  791. # Print information about current iteration.
  792. if self.__print_to_stdout == 2:
  793. print('done.')
  794. # Return true iff the order was decreased.
  795. return decreased_order
  def __compute_best_deletion_delta(self, graphs, median, id_deleted_node):
      """Find the median node whose deletion changes the cost the most favourably.

      Parameters
      ----------
      graphs : dict
          Maps graph IDs to NetworkX graphs.
      median : NetworkX graph
          Current median.
      id_deleted_node : list
          One-element output list; slot 0 receives the ID of the best node
          whenever a delta better than the current best is found.

      Returns
      -------
      float
          The best (most negative) delta found, or 0.0 if no deletion helps.
      """
      best_delta = 0.0

      # Determine node that should be deleted (if any).
      for i in range(nx.number_of_nodes(median)):
          # Compute cost delta.
          delta = 0.0
          for graph_id, graph in graphs.items():
              node_map = self.__node_maps_from_median[graph_id]
              k = node_map.image(i)
              if k == np.inf:
                  delta -= self.__node_del_cost
              else:
                  delta += self.__node_ins_cost - self.__ged_env.get_node_rel_cost(median.nodes[i], graph.nodes[k])
              # Account for every median edge incident to i.
              for j, j_label in median[i].items():
                  m = node_map.image(j)
                  if k == np.inf or m == np.inf or not graph.has_edge(k, m):
                      delta -= self.__edge_del_cost
                  else:
                      delta += self.__edge_ins_cost - self.__ged_env.get_edge_rel_cost(j_label, graph.edges[(k, m)])

          # Update best deletion delta.
          if delta < best_delta - self.__epsilon:
              best_delta = delta
              id_deleted_node[0] = i

      return best_delta
  def __delete_node_from_median(self, id_deleted_node, median):
      """Remove one node from the median and renumber nodes and node maps.

      The graph is modified in place: nodes with an ID above
      ``id_deleted_node`` are shifted down by one, which is the same rule
      applied to the source side of the node maps below. Relabeling in place
      (rather than building a relabelled copy) ensures that callers holding a
      reference to ``median`` see the renumbered graph.

      Parameters
      ----------
      id_deleted_node : int
          ID of the median node to remove.
      median : NetworkX graph
          Median graph; assumed to use the integer IDs 0..n-1.

      Returns
      -------
      NetworkX graph
          The same ``median`` object, after the update.
      """
      # Update the median.
      old_order = nx.number_of_nodes(median)
      median.remove_node(id_deleted_node)
      relabeling = {i: i - 1 for i in range(id_deleted_node + 1, old_order) if median.has_node(i)}
      if relabeling:
          nx.relabel_nodes(median, relabeling, copy=False)

      # Update the node maps.
      for key, node_map in self.__node_maps_from_median.items():
          new_node_map = NodeMap(nx.number_of_nodes(median), node_map.num_target_nodes())
          is_unassigned_target_node = [True] * node_map.num_target_nodes()
          for i in range(0, nx.number_of_nodes(median) + 1):
              if i != id_deleted_node:
                  new_i = (i if i < id_deleted_node else i - 1)
                  k = node_map.image(i)
                  new_node_map.add_assignment(new_i, k)
                  if k != np.inf:
                      is_unassigned_target_node[k] = False
          # Target nodes that lost their pre-image are recorded as insertions.
          for k in range(0, node_map.num_target_nodes()):
              if is_unassigned_target_node[k]:
                  new_node_map.add_assignment(np.inf, k)
          self.__node_maps_from_median[key] = new_node_map

      # Increase overall number of decreases.
      self.__num_decrease_order += 1

      return median
  845. def __increase_order(self, graphs, median):
  846. # Print information about current iteration.
  847. if self.__print_to_stdout == 2:
  848. print('Trying to increase order: ... ', end='')
  849. # Initialize the best configuration and the best label of the node that is to be inserted.
  850. best_config = {}
  851. best_label = self.__ged_env.get_node_label(1)
  852. increased_order = False
  853. # Increase the order as long as the best insertion delta is negative.
  854. while self.__compute_best_insertion_delta(graphs, best_config, best_label) < - self.__epsilon:
  855. increased_order = True
  856. self.__add_node_to_median(best_config, best_label, median)
  857. # Print information about current iteration.
  858. if self.__print_to_stdout == 2:
  859. print('done.')
  860. # Return true iff the order was increased.
  861. return increased_order
  def __compute_best_insertion_delta(self, graphs, best_config, best_label):
      """Compute the best delta obtainable by inserting one node into the median.

      Parameters
      ----------
      graphs : dict
          Maps graph IDs to NetworkX graphs.
      best_config : dict
          Output; reset to ``np.inf`` for every graph and then filled by the
          selected strategy.
      best_label : dict
          Output; label of the node to insert, filled by the selected strategy.

      Returns
      -------
      float
          0.0 if no graph has a node without pre-image, otherwise the delta
          returned by the strategy chosen from ``self.__label_names``.
      """
      # Construct sets of inserted nodes: target nodes whose pre-image is np.inf.
      inserted_nodes = {}
      found_inserted_node = False
      for graph_id, graph in graphs.items():
          candidates = inserted_nodes[graph_id] = []
          best_config[graph_id] = np.inf
          node_map = self.__node_maps_from_median[graph_id]
          for k in range(nx.number_of_nodes(graph)):
              if node_map.pre_image(k) == np.inf:
                  found_inserted_node = True
                  candidates.append((k, tuple(graph.nodes[k].items()))) # @todo: can order of label names be garantteed?

      # Return 0.0 if no node is inserted in any of the graphs.
      if not found_inserted_node:
          return 0.0

      # Compute insertion configuration, label, and delta.
      has_symbolic_labels = len(self.__label_names['node_labels']) > 0
      if not has_symbolic_labels and len(self.__label_names['node_attrs']) == 0: # @todo
          return self.__compute_insertion_delta_unlabeled(inserted_nodes, best_config, best_label)
      if has_symbolic_labels: # self.__constant_node_costs:
          return self.__compute_insertion_delta_constant(inserted_nodes, best_config, best_label)
      return self.__compute_insertion_delta_generic(inserted_nodes, best_config, best_label)
  886. def __compute_insertion_delta_unlabeled(self, inserted_nodes, best_config, best_label): # @todo: go through and test.
  887. # Construct the nest configuration and compute its insertion delta.
  888. best_delta = 0.0
  889. best_config.clear()
  890. for graph_id, node_set in inserted_nodes.items():
  891. if len(node_set) == 0:
  892. best_config[graph_id] = np.inf
  893. best_delta += self.__node_del_cost
  894. else:
  895. best_config[graph_id] = node_set[0][0]
  896. best_delta -= self.__node_ins_cost
  897. # Return the best insertion delta.
  898. return best_delta
  899. def __compute_insertion_delta_constant(self, inserted_nodes, best_config, best_label):
  900. # Construct histogram and inverse label maps.
  901. hist = {}
  902. inverse_label_maps = {}
  903. for graph_id, node_set in inserted_nodes.items():
  904. inverse_label_maps[graph_id] = {}
  905. for node in node_set:
  906. k = node[0]
  907. label = node[1]
  908. if label not in inverse_label_maps[graph_id]:
  909. inverse_label_maps[graph_id][label] = k
  910. if label not in hist:
  911. hist[label] = 1
  912. else:
  913. hist[label] += 1
  914. # Determine the best label.
  915. best_count = 0
  916. for key, val in hist.items():
  917. if val > best_count:
  918. best_count = val
  919. best_label_tuple = key
  920. # get best label.
  921. best_label.clear()
  922. for key, val in best_label_tuple:
  923. best_label[key] = val
  924. # Construct the best configuration and compute its insertion delta.
  925. best_config.clear()
  926. best_delta = 0.0
  927. node_rel_cost = self.__ged_env.get_node_rel_cost(self.__ged_env.get_node_label(1), self.__ged_env.get_node_label(2))
  928. triangle_ineq_holds = (node_rel_cost <= self.__node_del_cost + self.__node_ins_cost)
  929. for graph_id, _ in inserted_nodes.items():
  930. if best_label_tuple in inverse_label_maps[graph_id]:
  931. best_config[graph_id] = inverse_label_maps[graph_id][best_label_tuple]
  932. best_delta -= self.__node_ins_cost
  933. elif triangle_ineq_holds and not len(inserted_nodes[graph_id]) == 0:
  934. best_config[graph_id] = inserted_nodes[graph_id][0][0]
  935. best_delta += node_rel_cost - self.__node_ins_cost
  936. else:
  937. best_config[graph_id] = np.inf
  938. best_delta += self.__node_del_cost
  939. # Return the best insertion delta.
  940. return best_delta
  941. def __compute_insertion_delta_generic(self, inserted_nodes, best_config, best_label):
  942. # Collect all node labels of inserted nodes.
  943. node_labels = []
  944. for _, node_set in inserted_nodes.items():
  945. for node in node_set:
  946. node_labels.append(node[1])
  947. # Compute node label medians that serve as initial solutions for block gradient descent.
  948. initial_node_labels = []
  949. self.__compute_initial_node_labels(node_labels, initial_node_labels)
  950. # Determine best insertion configuration, label, and delta via parallel block gradient descent from all initial node labels.
  951. best_delta = 0.0
  952. for node_label in initial_node_labels:
  953. # Construct local configuration.
  954. config = {}
  955. for graph_id, _ in inserted_nodes.items():
  956. config[graph_id] = tuple((np.inf, tuple(item for item in self.__ged_env.get_node_label(1).items())))
  957. # Run block gradient descent.
  958. converged = False
  959. itr = 0
  960. while not self.__insertion_termination_criterion_met(converged, itr):
  961. converged = not self.__update_config(node_label, inserted_nodes, config, node_labels)
  962. node_label_dict = dict(node_label)
  963. converged = converged and (not self.__update_node_label([dict(item) for item in node_labels], node_label_dict)) # @todo: the dict is tupled again in the function, can be better.
  964. node_label = tuple(item for item in node_label_dict.items()) # @todo: watch out: initial_node_labels[i] is not modified here.
  965. itr += 1
  966. # Compute insertion delta of converged solution.
  967. delta = 0.0
  968. for _, node in config.items():
  969. if node[0] == np.inf:
  970. delta += self.__node_del_cost
  971. else:
  972. delta += self.__ged_env.get_node_rel_cost(dict(node_label), dict(node[1])) - self.__node_ins_cost
  973. # Update best delta and global configuration if improvement has been found.
  974. if delta < best_delta - self.__epsilon:
  975. best_delta = delta
  976. best_label.clear()
  977. for key, val in node_label:
  978. best_label[key] = val
  979. best_config.clear()
  980. for graph_id, val in config.items():
  981. best_config[graph_id] = val[0]
  982. # Return the best delta.
  983. return best_delta
  984. def __compute_initial_node_labels(self, node_labels, median_labels):
  985. median_labels.clear()
  986. if self.__use_real_randomness: # @todo: may not work if parallelized.
  987. rng = np.random.randint(0, high=2**32 - 1, size=1)
  988. urng = np.random.RandomState(seed=rng[0])
  989. else:
  990. urng = np.random.RandomState(seed=self.__seed)
  991. # Generate the initial node label medians.
  992. if self.__init_type_increase_order == 'K-MEANS++':
  993. # Use k-means++ heuristic to generate the initial node label medians.
  994. already_selected = [False] * len(node_labels)
  995. selected_label_id = urng.randint(low=0, high=len(node_labels), size=1)[0] # c++ test: 23
  996. median_labels.append(node_labels[selected_label_id])
  997. already_selected[selected_label_id] = True
  998. # xxx = [41, 0, 18, 9, 6, 14, 21, 25, 33] for c++ test
  999. # iii = 0 for c++ test
  1000. while len(median_labels) < self.__num_inits_increase_order:
  1001. weights = [np.inf] * len(node_labels)
  1002. for label_id in range(0, len(node_labels)):
  1003. if already_selected[label_id]:
  1004. weights[label_id] = 0
  1005. continue
  1006. for label in median_labels:
  1007. weights[label_id] = min(weights[label_id], self.__ged_env.get_node_rel_cost(dict(label), dict(node_labels[label_id])))
  1008. sum_weight = np.sum(weights)
  1009. if sum_weight == 0:
  1010. p = np.array([1 / len(weights)] * len(weights))
  1011. else:
  1012. p = np.array(weights) / np.sum(weights)
  1013. selected_label_id = urng.choice(range(0, len(weights)), size=1, p=p)[0] # for c++ test: xxx[iii]
  1014. # iii += 1 for c++ test
  1015. median_labels.append(node_labels[selected_label_id])
  1016. already_selected[selected_label_id] = True
  1017. else:
  1018. # Compute the initial node medians as the medians of randomly generated clusters of (roughly) equal size.
  1019. # @todo: go through and test.
  1020. shuffled_node_labels = [np.inf] * len(node_labels) #@todo: random?
  1021. # @todo: std::shuffle(shuffled_node_labels.begin(), shuffled_node_labels.end(), urng);?
  1022. cluster_size = len(node_labels) / self.__num_inits_increase_order
  1023. pos = 0.0
  1024. cluster = []
  1025. while len(median_labels) < self.__num_inits_increase_order - 1:
  1026. while pos < (len(median_labels) + 1) * cluster_size:
  1027. cluster.append(shuffled_node_labels[pos])
  1028. pos += 1
  1029. median_labels.append(self.__get_median_node_label(cluster))
  1030. cluster.clear()
  1031. while pos < len(shuffled_node_labels):
  1032. pos += 1
  1033. cluster.append(shuffled_node_labels[pos])
  1034. median_labels.append(self.__get_median_node_label(cluster))
  1035. cluster.clear()
  1036. # Run Lloyd's Algorithm.
  1037. converged = False
  1038. closest_median_ids = [np.inf] * len(node_labels)
  1039. clusters = [[] for _ in range(len(median_labels))]
  1040. itr = 1
  1041. while not self.__insertion_termination_criterion_met(converged, itr):
  1042. converged = not self.__update_clusters(node_labels, median_labels, closest_median_ids)
  1043. if not converged:
  1044. for cluster in clusters:
  1045. cluster.clear()
  1046. for label_id in range(0, len(node_labels)):
  1047. clusters[closest_median_ids[label_id]].append(node_labels[label_id])
  1048. for cluster_id in range(0, len(clusters)):
  1049. node_label = dict(median_labels[cluster_id])
  1050. self.__update_node_label([dict(item) for item in clusters[cluster_id]], node_label) # @todo: the dict is tupled again in the function, can be better.
  1051. median_labels[cluster_id] = tuple(item for item in node_label.items())
  1052. itr += 1
  1053. def __insertion_termination_criterion_met(self, converged, itr):
  1054. return converged or (itr >= self.__max_itrs_increase_order if self.__max_itrs_increase_order > 0 else False)
  1055. def __update_config(self, node_label, inserted_nodes, config, node_labels):
  1056. # Determine the best configuration.
  1057. config_modified = False
  1058. for graph_id, node_set in inserted_nodes.items():
  1059. best_assignment = config[graph_id]
  1060. best_cost = 0.0
  1061. if best_assignment[0] == np.inf:
  1062. best_cost = self.__node_del_cost
  1063. else:
  1064. best_cost = self.__ged_env.get_node_rel_cost(dict(node_label), dict(best_assignment[1])) - self.__node_ins_cost
  1065. for node in node_set:
  1066. cost = self.__ged_env.get_node_rel_cost(dict(node_label), dict(node[1])) - self.__node_ins_cost
  1067. if cost < best_cost - self.__epsilon:
  1068. best_cost = cost
  1069. best_assignment = node
  1070. config_modified = True
  1071. if self.__node_del_cost < best_cost - self.__epsilon:
  1072. best_cost = self.__node_del_cost
  1073. best_assignment = tuple((np.inf, best_assignment[1]))
  1074. config_modified = True
  1075. config[graph_id] = best_assignment
  1076. # Collect the node labels contained in the best configuration.
  1077. node_labels.clear()
  1078. for key, val in config.items():
  1079. if val[0] != np.inf:
  1080. node_labels.append(val[1])
  1081. # Return true if the configuration was modified.
  1082. return config_modified
  1083. def __update_node_label(self, node_labels, node_label):
  1084. new_node_label = self.__get_median_node_label(node_labels)
  1085. if self.__ged_env.get_node_rel_cost(new_node_label, node_label) > self.__epsilon:
  1086. node_label.clear()
  1087. for key, val in new_node_label.items():
  1088. node_label[key] = val
  1089. return True
  1090. return False
  1091. def __update_clusters(self, node_labels, median_labels, closest_median_ids):
  1092. # Determine the closest median for each node label.
  1093. clusters_modified = False
  1094. for label_id in range(0, len(node_labels)):
  1095. closest_median_id = np.inf
  1096. dist_to_closest_median = np.inf
  1097. for median_id in range(0, len(median_labels)):
  1098. dist_to_median = self.__ged_env.get_node_rel_cost(dict(median_labels[median_id]), dict(node_labels[label_id]))
  1099. if dist_to_median < dist_to_closest_median - self.__epsilon:
  1100. dist_to_closest_median = dist_to_median
  1101. closest_median_id = median_id
  1102. if closest_median_id != closest_median_ids[label_id]:
  1103. closest_median_ids[label_id] = closest_median_id
  1104. clusters_modified = True
  1105. # Return true if the clusters were modified.
  1106. return clusters_modified
  def __add_node_to_median(self, best_config, best_label, median):
      """Append one node to the median and extend every node map accordingly.

      Parameters
      ----------
      best_config : dict
          Maps graph IDs to the target node assigned to the new median node.
      best_label : dict
          Attributes of the new median node.
      median : NetworkX graph
          Median graph, modified in place.
      """
      # Update the median: the new node gets the next free integer ID.
      median.add_node(nx.number_of_nodes(median), **best_label)
      order = nx.number_of_nodes(median)

      # Update the node maps.
      for graph_id, node_map in self.__node_maps_from_median.items():
          relation = []
          node_map.as_relation(relation)
          extended_map = NodeMap(order, node_map.num_target_nodes())
          # Copy the old assignments, then add the one for the new node.
          for assignment in relation:
              extended_map.add_assignment(assignment[0], assignment[1])
          extended_map.add_assignment(order - 1, best_config[graph_id])
          self.__node_maps_from_median[graph_id] = extended_map

      # Increase overall number of increases.
      self.__num_increase_order += 1
  1122. def __improve_sum_of_distances(self, timer):
  1123. pass
  1124. def __median_available(self):
  1125. return self.__median_id != np.inf
  1126. # def __get_node_image_from_map(self, node_map, node):
  1127. # """
  1128. # Return ID of the node mapping of `node` in `node_map`.
  1129. # Parameters
  1130. # ----------
  1131. # node_map : list[tuple(int, int)]
  1132. # List of node maps where the mapping node is found.
  1133. #
  1134. # node : int
  1135. # The mapping node of this node is returned
  1136. # Raises
  1137. # ------
  1138. # Exception
  1139. # If the node with ID `node` is not contained in the source nodes of the node map.
  1140. # Returns
  1141. # -------
  1142. # int
  1143. # ID of the mapping of `node`.
  1144. #
  1145. # Notes
  1146. # -----
  1147. # This function is not implemented in the `ged::MedianGraphEstimator` class of the `GEDLIB` library. Instead it is a Python implementation of the `ged::NodeMap::image` function.
  1148. # """
  1149. # if node < len(node_map):
  1150. # return node_map[node][1] if node_map[node][1] < len(node_map) else np.inf
  1151. # else:
  1152. # raise Exception('The node with ID ', str(node), ' is not contained in the source nodes of the node map.')
  1153. # return np.inf
  1154. def __are_graphs_equal(self, g1, g2):
  1155. """
  1156. Check if the two graphs are equal.
  1157. Parameters
  1158. ----------
  1159. g1 : NetworkX graph object
  1160. Graph 1 to be compared.
  1161. g2 : NetworkX graph object
  1162. Graph 2 to be compared.
  1163. Returns
  1164. -------
  1165. bool
  1166. True if the two graph are equal.
  1167. Notes
  1168. -----
  1169. This is not an identical check. Here the two graphs are equal if and only if their original_node_ids, nodes, all node labels, edges and all edge labels are equal. This function is specifically designed for class `MedianGraphEstimator` and should not be used elsewhere.
  1170. """
  1171. # check original node ids.
  1172. if not g1.graph['original_node_ids'] == g2.graph['original_node_ids']:
  1173. return False
  1174. # check nodes.
  1175. nlist1 = [n for n in g1.nodes(data=True)]
  1176. nlist2 = [n for n in g2.nodes(data=True)]
  1177. if not nlist1 == nlist2:
  1178. return False
  1179. # check edges.
  1180. elist1 = [n for n in g1.edges(data=True)]
  1181. elist2 = [n for n in g2.edges(data=True)]
  1182. if not elist1 == elist2:
  1183. return False
  1184. return True
  def compute_my_cost(g, h, node_map):
      """Unfinished placeholder.

      Walks over the nodes of ``g`` adding zero to a local accumulator and
      returns ``None``; ``h`` and ``node_map`` are not used.
      """
      # NOTE(review): no ``self`` parameter and no return value -- looks like work in progress.
      cost = 0.0
      for _ in g.nodes:
          cost += 0
  1189. def set_label_names(self, node_labels=[], edge_labels=[], node_attrs=[], edge_attrs=[]):
  1190. self.__label_names = {'node_labels': node_labels, 'edge_labels': edge_labels,
  1191. 'node_attrs': node_attrs, 'edge_attrs': edge_attrs}
  1192. def __get_median_node_label(self, node_labels):
  1193. if len(self.__label_names['node_labels']) > 0:
  1194. return self.__get_median_label_symbolic(node_labels)
  1195. elif len(self.__label_names['node_attrs']) > 0:
  1196. return self.__get_median_label_nonsymbolic(node_labels)
  1197. else:
  1198. raise Exception('Node label names are not given.')
  1199. def __get_median_edge_label(self, edge_labels):
  1200. if len(self.__label_names['edge_labels']) > 0:
  1201. return self.__get_median_label_symbolic(edge_labels)
  1202. elif len(self.__label_names['edge_attrs']) > 0:
  1203. return self.__get_median_label_nonsymbolic(edge_labels)
  1204. else:
  1205. raise Exception('Edge label names are not given.')
  1206. def __get_median_label_symbolic(self, labels):
  1207. # Construct histogram.
  1208. hist = {}
  1209. for label in labels:
  1210. label = tuple([kv for kv in label.items()]) # @todo: this may be slow.
  1211. if label not in hist:
  1212. hist[label] = 1
  1213. else:
  1214. hist[label] += 1
  1215. # Return the label that appears most frequently.
  1216. best_count = 0
  1217. median_label = {}
  1218. for label, count in hist.items():
  1219. if count > best_count:
  1220. best_count = count
  1221. median_label = {kv[0]: kv[1] for kv in label}
  1222. return median_label
  1223. def __get_median_label_nonsymbolic(self, labels):
  1224. if len(labels) == 0:
  1225. return {} # @todo
  1226. else:
  1227. # Transform the labels into coordinates and compute mean label as initial solution.
  1228. labels_as_coords = []
  1229. sums = {}
  1230. for key, val in labels[0].items():
  1231. sums[key] = 0
  1232. for label in labels:
  1233. coords = {}
  1234. for key, val in label.items():
  1235. label_f = float(val)
  1236. sums[key] += label_f
  1237. coords[key] = label_f
  1238. labels_as_coords.append(coords)
  1239. median = {}
  1240. for key, val in sums.items():
  1241. median[key] = val / len(labels)
  1242. # Run main loop of Weiszfeld's Algorithm.
  1243. epsilon = 0.0001
  1244. delta = 1.0
  1245. num_itrs = 0
  1246. all_equal = False
  1247. while ((delta > epsilon) and (num_itrs < 100) and (not all_equal)):
  1248. numerator = {}
  1249. for key, val in sums.items():
  1250. numerator[key] = 0
  1251. denominator = 0
  1252. for label_as_coord in labels_as_coords:
  1253. norm = 0
  1254. for key, val in label_as_coord.items():
  1255. norm += (val - median[key]) ** 2
  1256. norm = np.sqrt(norm)
  1257. if norm > 0:
  1258. for key, val in label_as_coord.items():
  1259. numerator[key] += val / norm
  1260. denominator += 1.0 / norm
  1261. if denominator == 0:
  1262. all_equal = True
  1263. else:
  1264. new_median = {}
  1265. delta = 0.0
  1266. for key, val in numerator.items():
  1267. this_median = val / denominator
  1268. new_median[key] = this_median
  1269. delta += np.abs(median[key] - this_median)
  1270. median = new_median
  1271. num_itrs += 1
  1272. # Transform the solution to strings and return it.
  1273. median_label = {}
  1274. for key, val in median.items():
  1275. median_label[key] = str(val)
  1276. return median_label
  1277. # def __get_median_edge_label_symbolic(self, edge_labels):
  1278. # pass
  1279. # def __get_median_edge_label_nonsymbolic(self, edge_labels):
  1280. # if len(edge_labels) == 0:
  1281. # return {}
  1282. # else:
  1283. # # Transform the labels into coordinates and compute mean label as initial solution.
  1284. # edge_labels_as_coords = []
  1285. # sums = {}
  1286. # for key, val in edge_labels[0].items():
  1287. # sums[key] = 0
  1288. # for edge_label in edge_labels:
  1289. # coords = {}
  1290. # for key, val in edge_label.items():
  1291. # label = float(val)
  1292. # sums[key] += label
  1293. # coords[key] = label
  1294. # edge_labels_as_coords.append(coords)
  1295. # median = {}
  1296. # for key, val in sums.items():
  1297. # median[key] = val / len(edge_labels)
  1298. #
  1299. # # Run main loop of Weiszfeld's Algorithm.
  1300. # epsilon = 0.0001
  1301. # delta = 1.0
  1302. # num_itrs = 0
  1303. # all_equal = False
  1304. # while ((delta > epsilon) and (num_itrs < 100) and (not all_equal)):
  1305. # numerator = {}
  1306. # for key, val in sums.items():
  1307. # numerator[key] = 0
  1308. # denominator = 0
  1309. # for edge_label_as_coord in edge_labels_as_coords:
  1310. # norm = 0
  1311. # for key, val in edge_label_as_coord.items():
  1312. # norm += (val - median[key]) ** 2
  1313. # norm += np.sqrt(norm)
  1314. # if norm > 0:
  1315. # for key, val in edge_label_as_coord.items():
  1316. # numerator[key] += val / norm
  1317. # denominator += 1.0 / norm
  1318. # if denominator == 0:
  1319. # all_equal = True
  1320. # else:
  1321. # new_median = {}
  1322. # delta = 0.0
  1323. # for key, val in numerator.items():
  1324. # this_median = val / denominator
  1325. # new_median[key] = this_median
  1326. # delta += np.abs(median[key] - this_median)
  1327. # median = new_median
  1328. #
  1329. # num_itrs += 1
  1330. #
  1331. # # Transform the solution to ged::GXLLabel and return it.
  1332. # median_label = {}
  1333. # for key, val in median.items():
  1334. # median_label[key] = str(val)
  1335. # return median_label
  1336. def _compute_medoid_parallel(graph_ids, itr):
  1337. g_id = itr[0]
  1338. i = itr[1]
  1339. # @todo: timer not considered here.
  1340. # if timer.expired():
  1341. # self.__state = AlgorithmState.CALLED
  1342. # break
  1343. sum_of_distances = 0
  1344. for h_id in graph_ids:
  1345. G_ged_env.run_method(g_id, h_id)
  1346. sum_of_distances += G_ged_env.get_upper_bound(g_id, h_id)
  1347. return i, sum_of_distances
  1348. def _compute_init_node_maps_parallel(gen_median_id, itr):
  1349. graph_id = itr
  1350. G_ged_env.run_method(gen_median_id, graph_id)
  1351. node_maps_from_median = G_ged_env.get_node_map(gen_median_id, graph_id)
  1352. # print(self.__node_maps_from_median[graph_id])
  1353. sum_of_distance = node_maps_from_median.induced_cost()
  1354. # print(self.__sum_of_distances)
  1355. return graph_id, sum_of_distance, node_maps_from_median
  1356. def _update_node_maps_parallel(median_id, epsilon, itr):
  1357. graph_id = itr[0]
  1358. node_map = itr[1]
  1359. node_maps_were_modified = False
  1360. G_ged_env.run_method(median_id, graph_id)
  1361. if G_ged_env.get_upper_bound(median_id, graph_id) < node_map.induced_cost() - epsilon:
  1362. node_map = G_ged_env.get_node_map(median_id, graph_id)
  1363. node_maps_were_modified = True
  1364. return graph_id, node_map, node_maps_were_modified

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.