You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

grid.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. #!/usr/bin/env python
  2. __all__ = ['find_parameters']
  3. import os, sys, traceback, getpass, time, re
  4. from threading import Thread
  5. from subprocess import *
  6. if sys.version_info[0] < 3:
  7. from Queue import Queue
  8. else:
  9. from queue import Queue
# Hostnames for remote execution of svm-train. find_parameters starts one
# TelnetWorker per entry of telnet_workers (asking once for a password) and
# one SSHWorker per entry of ssh_workers.
telnet_workers = []
ssh_workers = []
# Number of LocalWorker threads find_parameters starts on this machine.
nr_local_worker = 1
  13. class GridOption:
  14. def __init__(self, dataset_pathname, options):
  15. dirname = os.path.dirname(__file__)
  16. if sys.platform != 'win32':
  17. self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
  18. self.gnuplot_pathname = '/usr/bin/gnuplot'
  19. else:
  20. # example for windows
  21. self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
  22. # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
  23. self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
  24. self.fold = 5
  25. self.c_begin, self.c_end, self.c_step = -5, 15, 2
  26. self.g_begin, self.g_end, self.g_step = 3, -15, -2
  27. self.grid_with_c, self.grid_with_g = True, True
  28. self.dataset_pathname = dataset_pathname
  29. self.dataset_title = os.path.split(dataset_pathname)[1]
  30. self.out_pathname = '{0}.out'.format(self.dataset_title)
  31. self.png_pathname = '{0}.png'.format(self.dataset_title)
  32. self.pass_through_string = ' '
  33. self.resume_pathname = None
  34. self.parse_options(options)
  35. def parse_options(self, options):
  36. if type(options) == str:
  37. options = options.split()
  38. i = 0
  39. pass_through_options = []
  40. while i < len(options):
  41. if options[i] == '-log2c':
  42. i = i + 1
  43. if options[i] == 'null':
  44. self.grid_with_c = False
  45. else:
  46. self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
  47. elif options[i] == '-log2g':
  48. i = i + 1
  49. if options[i] == 'null':
  50. self.grid_with_g = False
  51. else:
  52. self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
  53. elif options[i] == '-v':
  54. i = i + 1
  55. self.fold = options[i]
  56. elif options[i] in ('-c','-g'):
  57. raise ValueError('Use -log2c and -log2g.')
  58. elif options[i] == '-svmtrain':
  59. i = i + 1
  60. self.svmtrain_pathname = options[i]
  61. elif options[i] == '-gnuplot':
  62. i = i + 1
  63. if options[i] == 'null':
  64. self.gnuplot_pathname = None
  65. else:
  66. self.gnuplot_pathname = options[i]
  67. elif options[i] == '-out':
  68. i = i + 1
  69. if options[i] == 'null':
  70. self.out_pathname = None
  71. else:
  72. self.out_pathname = options[i]
  73. elif options[i] == '-png':
  74. i = i + 1
  75. self.png_pathname = options[i]
  76. elif options[i] == '-resume':
  77. if i == (len(options)-1) or options[i+1].startswith('-'):
  78. self.resume_pathname = self.dataset_title + '.out'
  79. else:
  80. i = i + 1
  81. self.resume_pathname = options[i]
  82. else:
  83. pass_through_options.append(options[i])
  84. i = i + 1
  85. self.pass_through_string = ' '.join(pass_through_options)
  86. if not os.path.exists(self.svmtrain_pathname):
  87. raise IOError('svm-train executable not found')
  88. if not os.path.exists(self.dataset_pathname):
  89. raise IOError('dataset not found')
  90. if self.resume_pathname and not os.path.exists(self.resume_pathname):
  91. raise IOError('file for resumption not found')
  92. if not self.grid_with_c and not self.grid_with_g:
  93. raise ValueError('-log2c and -log2g should not be null simultaneously')
  94. if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
  95. sys.stderr.write('gnuplot executable not found\n')
  96. self.gnuplot_pathname = None
  97. def redraw(db,best_param,gnuplot,options,tofile=False):
  98. if len(db) == 0: return
  99. begin_level = round(max(x[2] for x in db)) - 3
  100. step_size = 0.5
  101. best_log2c,best_log2g,best_rate = best_param
  102. # if newly obtained c, g, or cv values are the same,
  103. # then stop redrawing the contour.
  104. if all(x[0] == db[0][0] for x in db): return
  105. if all(x[1] == db[0][1] for x in db): return
  106. if all(x[2] == db[0][2] for x in db): return
  107. if tofile:
  108. gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
  109. gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
  110. #gnuplot.write(b"set term postscript color solid\n")
  111. #gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
  112. elif sys.platform == 'win32':
  113. gnuplot.write(b"set term windows\n")
  114. else:
  115. gnuplot.write( b"set term x11\n")
  116. gnuplot.write(b"set xlabel \"log2(C)\"\n")
  117. gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
  118. gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
  119. gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
  120. gnuplot.write(b"set contour\n")
  121. gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
  122. gnuplot.write(b"unset surface\n")
  123. gnuplot.write(b"unset ztics\n")
  124. gnuplot.write(b"set view 0,0\n")
  125. gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
  126. gnuplot.write(b"unset label\n")
  127. gnuplot.write("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\" \
  128. at screen 0.5,0.85 center\n". \
  129. format(best_log2c, best_log2g, best_rate).encode())
  130. gnuplot.write("set label \"C = {0} gamma = {1}\""
  131. " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
  132. gnuplot.write(b"set key at screen 0.9,0.9\n")
  133. gnuplot.write(b"splot \"-\" with lines\n")
  134. db.sort(key = lambda x:(x[0], -x[1]))
  135. prevc = db[0][0]
  136. for line in db:
  137. if prevc != line[0]:
  138. gnuplot.write(b"\n")
  139. prevc = line[0]
  140. gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
  141. gnuplot.write(b"e\n")
  142. gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
  143. gnuplot.flush()
  144. def calculate_jobs(options):
  145. def range_f(begin,end,step):
  146. # like range, but works on non-integer too
  147. seq = []
  148. while True:
  149. if step > 0 and begin > end: break
  150. if step < 0 and begin < end: break
  151. seq.append(begin)
  152. begin = begin + step
  153. return seq
  154. def permute_sequence(seq):
  155. n = len(seq)
  156. if n <= 1: return seq
  157. mid = int(n/2)
  158. left = permute_sequence(seq[:mid])
  159. right = permute_sequence(seq[mid+1:])
  160. ret = [seq[mid]]
  161. while left or right:
  162. if left: ret.append(left.pop(0))
  163. if right: ret.append(right.pop(0))
  164. return ret
  165. c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
  166. g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))
  167. if not options.grid_with_c:
  168. c_seq = [None]
  169. if not options.grid_with_g:
  170. g_seq = [None]
  171. nr_c = float(len(c_seq))
  172. nr_g = float(len(g_seq))
  173. i, j = 0, 0
  174. jobs = []
  175. while i < nr_c or j < nr_g:
  176. if i/nr_c < j/nr_g:
  177. # increase C resolution
  178. line = []
  179. for k in range(0,j):
  180. line.append((c_seq[i],g_seq[k]))
  181. i = i + 1
  182. jobs.append(line)
  183. else:
  184. # increase g resolution
  185. line = []
  186. for k in range(0,i):
  187. line.append((c_seq[k],g_seq[j]))
  188. j = j + 1
  189. jobs.append(line)
  190. resumed_jobs = {}
  191. if options.resume_pathname is None:
  192. return jobs, resumed_jobs
  193. for line in open(options.resume_pathname, 'r'):
  194. line = line.strip()
  195. rst = re.findall(r'rate=([0-9.]+)',line)
  196. if not rst:
  197. continue
  198. rate = float(rst[0])
  199. c, g = None, None
  200. rst = re.findall(r'log2c=([0-9.-]+)',line)
  201. if rst:
  202. c = float(rst[0])
  203. rst = re.findall(r'log2g=([0-9.-]+)',line)
  204. if rst:
  205. g = float(rst[0])
  206. resumed_jobs[(c,g)] = rate
  207. return jobs, resumed_jobs
  208. class WorkerStopToken: # used to notify the worker to stop or if a worker is dead
  209. pass
  210. class Worker(Thread):
  211. def __init__(self,name,job_queue,result_queue,options):
  212. Thread.__init__(self)
  213. self.name = name
  214. self.job_queue = job_queue
  215. self.result_queue = result_queue
  216. self.options = options
  217. def run(self):
  218. while True:
  219. (cexp,gexp) = self.job_queue.get()
  220. if cexp is WorkerStopToken:
  221. self.job_queue.put((cexp,gexp))
  222. # print('worker {0} stop.'.format(self.name))
  223. break
  224. try:
  225. c, g = None, None
  226. if cexp != None:
  227. c = 2.0**cexp
  228. if gexp != None:
  229. g = 2.0**gexp
  230. rate = self.run_one(c,g)
  231. if rate is None: raise RuntimeError('get no rate')
  232. except:
  233. # we failed, let others do that and we just quit
  234. traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])
  235. self.job_queue.put((cexp,gexp))
  236. sys.stderr.write('worker {0} quit.\n'.format(self.name))
  237. break
  238. else:
  239. self.result_queue.put((self.name,cexp,gexp,rate))
  240. def get_cmd(self,c,g):
  241. options=self.options
  242. cmdline = '"' + options.svmtrain_pathname + '"'
  243. if options.grid_with_c:
  244. cmdline += ' -c {0} '.format(c)
  245. if options.grid_with_g:
  246. cmdline += ' -g {0} '.format(g)
  247. cmdline += ' -v {0} {1} {2} '.format\
  248. (options.fold,options.pass_through_string,options.dataset_pathname)
  249. return cmdline
  250. class LocalWorker(Worker):
  251. def run_one(self,c,g):
  252. cmdline = self.get_cmd(c,g)
  253. result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
  254. for line in result.readlines():
  255. if str(line).find('Cross') != -1:
  256. return float(line.split()[-1][0:-1])
  257. class SSHWorker(Worker):
  258. def __init__(self,name,job_queue,result_queue,host,options):
  259. Worker.__init__(self,name,job_queue,result_queue,options)
  260. self.host = host
  261. self.cwd = os.getcwd()
  262. def run_one(self,c,g):
  263. cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
  264. (self.host,self.cwd,self.get_cmd(c,g))
  265. result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
  266. for line in result.readlines():
  267. if str(line).find('Cross') != -1:
  268. return float(line.split()[-1][0:-1])
  269. class TelnetWorker(Worker):
  270. def __init__(self,name,job_queue,result_queue,host,username,password,options):
  271. Worker.__init__(self,name,job_queue,result_queue,options)
  272. self.host = host
  273. self.username = username
  274. self.password = password
  275. def run(self):
  276. import telnetlib
  277. self.tn = tn = telnetlib.Telnet(self.host)
  278. tn.read_until('login: ')
  279. tn.write(self.username + '\n')
  280. tn.read_until('Password: ')
  281. tn.write(self.password + '\n')
  282. # XXX: how to know whether login is successful?
  283. tn.read_until(self.username)
  284. #
  285. print('login ok', self.host)
  286. tn.write('cd '+os.getcwd()+'\n')
  287. Worker.run(self)
  288. tn.write('exit\n')
  289. def run_one(self,c,g):
  290. cmdline = self.get_cmd(c,g)
  291. result = self.tn.write(cmdline+'\n')
  292. (idx,matchm,output) = self.tn.expect(['Cross.*\n'])
  293. for line in output.split('\n'):
  294. if str(line).find('Cross') != -1:
  295. return float(line.split()[-1][0:-1])
  296. def find_parameters(dataset_pathname, options=''):
  297. def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
  298. if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
  299. best_rate,best_c,best_g = rate,c,g
  300. stdout_str = '[{0}] {1} {2} (best '.format\
  301. (worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
  302. output_str = ''
  303. if c != None:
  304. stdout_str += 'c={0}, '.format(2.0**best_c)
  305. output_str += 'log2c={0} '.format(c)
  306. if g != None:
  307. stdout_str += 'g={0}, '.format(2.0**best_g)
  308. output_str += 'log2g={0} '.format(g)
  309. stdout_str += 'rate={0})'.format(best_rate)
  310. print(stdout_str)
  311. if options.out_pathname and not resumed:
  312. output_str += 'rate={0}\n'.format(rate)
  313. result_file.write(output_str)
  314. result_file.flush()
  315. return best_c,best_g,best_rate
  316. options = GridOption(dataset_pathname, options);
  317. if options.gnuplot_pathname:
  318. gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
  319. else:
  320. gnuplot = None
  321. # put jobs in queue
  322. jobs,resumed_jobs = calculate_jobs(options)
  323. job_queue = Queue(0)
  324. result_queue = Queue(0)
  325. for (c,g) in resumed_jobs:
  326. result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))
  327. for line in jobs:
  328. for (c,g) in line:
  329. if (c,g) not in resumed_jobs:
  330. job_queue.put((c,g))
  331. # hack the queue to become a stack --
  332. # this is important when some thread
  333. # failed and re-put a job. It we still
  334. # use FIFO, the job will be put
  335. # into the end of the queue, and the graph
  336. # will only be updated in the end
  337. job_queue._put = job_queue.queue.appendleft
  338. # fire telnet workers
  339. if telnet_workers:
  340. nr_telnet_worker = len(telnet_workers)
  341. username = getpass.getuser()
  342. password = getpass.getpass()
  343. for host in telnet_workers:
  344. worker = TelnetWorker(host,job_queue,result_queue,
  345. host,username,password,options)
  346. worker.start()
  347. # fire ssh workers
  348. if ssh_workers:
  349. for host in ssh_workers:
  350. worker = SSHWorker(host,job_queue,result_queue,host,options)
  351. worker.start()
  352. # fire local workers
  353. for i in range(nr_local_worker):
  354. worker = LocalWorker('local',job_queue,result_queue,options)
  355. worker.start()
  356. # gather results
  357. done_jobs = {}
  358. if options.out_pathname:
  359. if options.resume_pathname:
  360. result_file = open(options.out_pathname, 'a')
  361. else:
  362. result_file = open(options.out_pathname, 'w')
  363. db = []
  364. best_rate = -1
  365. best_c,best_g = None,None
  366. for (c,g) in resumed_jobs:
  367. rate = resumed_jobs[(c,g)]
  368. best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)
  369. for line in jobs:
  370. for (c,g) in line:
  371. while (c,g) not in done_jobs:
  372. (worker,c1,g1,rate1) = result_queue.get()
  373. done_jobs[(c1,g1)] = rate1
  374. if (c1,g1) not in resumed_jobs:
  375. best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
  376. db.append((c,g,done_jobs[(c,g)]))
  377. if gnuplot and options.grid_with_c and options.grid_with_g:
  378. redraw(db,[best_c, best_g, best_rate],gnuplot,options)
  379. redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)
  380. if options.out_pathname:
  381. result_file.close()
  382. job_queue.put((WorkerStopToken,None))
  383. best_param, best_cg = {}, []
  384. if best_c != None:
  385. best_param['c'] = 2.0**best_c
  386. best_cg += [2.0**best_c]
  387. if best_g != None:
  388. best_param['g'] = 2.0**best_g
  389. best_cg += [2.0**best_g]
  390. print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))
  391. return best_rate, best_param
  392. if __name__ == '__main__':
  393. def exit_with_help():
  394. print("""\
  395. Usage: grid.py [grid_options] [svm_options] dataset
  396. grid_options :
  397. -log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
  398. begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
  399. "null" -- do not grid with c
  400. -log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
  401. begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
  402. "null" -- do not grid with g
  403. -v n : n-fold cross validation (default 5)
  404. -svmtrain pathname : set svm executable path and name
  405. -gnuplot {pathname | "null"} :
  406. pathname -- set gnuplot executable path and name
  407. "null" -- do not plot
  408. -out {pathname | "null"} : (default dataset.out)
  409. pathname -- set output file path and name
  410. "null" -- do not output file
  411. -png pathname : set graphic output file path and name (default dataset.png)
  412. -resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
  413. This is experimental. Try this option only if some parameters have been checked for the SAME data.
  414. svm_options : additional options for svm-train""")
  415. sys.exit(1)
  416. if len(sys.argv) < 2:
  417. exit_with_help()
  418. dataset_pathname = sys.argv[-1]
  419. options = sys.argv[1:-1]
  420. try:
  421. find_parameters(dataset_pathname, options)
  422. except (IOError,ValueError) as e:
  423. sys.stderr.write(str(e) + '\n')
  424. sys.stderr.write('Try "grid.py" for more information.\n')
  425. sys.exit(1)

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.