#!/usr/bin/env python
__all__ = ['find_parameters']

import getpass
import os
import re
import sys
import time
import traceback
from subprocess import *  # provides Popen and PIPE
from threading import Thread

if sys.version_info[0] < 3:
    from Queue import Empty, Queue
else:
    from queue import Empty, Queue

# Worker configuration read by find_parameters(); edit to distribute the search.
telnet_workers = []   # hosts; one TelnetWorker is started per entry
ssh_workers = []      # hosts; one SSHWorker is started per entry
nr_local_worker = 1   # number of LocalWorker threads started on this machine

class GridOption:
    """Settings for one grid search over log2(C) and log2(gamma).

    The constructor installs platform-dependent default executable paths and
    the default grid ranges, then lets parse_options() override them.

    Attributes:
        svmtrain_pathname, gnuplot_pathname -- executables; gnuplot_pathname
            becomes None when given as "null" or when the file is missing
        fold -- cross-validation fold count (int)
        c_begin, c_end, c_step / g_begin, g_end, g_step -- log2 grid ranges
        grid_with_c, grid_with_g -- False when the range was given as "null"
        dataset_pathname, dataset_title -- dataset path and its last component
        out_pathname -- result file name, or None when given as "null"
        png_pathname -- plot file name
        pass_through_string -- unrecognised options joined by single spaces
        resume_pathname -- earlier result file to resume from, or None
    """

    def __init__(self, dataset_pathname, options):
        dirname = os.path.dirname(__file__)
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            # example for windows
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5, 15, 2
        self.g_begin, self.g_end, self.g_step = 3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]
        self.out_pathname = '{0}.out'.format(self.dataset_title)
        self.png_pathname = '{0}.png'.format(self.dataset_title)
        self.pass_through_string = ' '
        self.resume_pathname = None
        self.parse_options(options)

    def parse_options(self, options):
        """Apply option tokens on top of the defaults and validate the result.

        options is a list of tokens or a whitespace-separated string.  Tokens
        that are not grid options are kept, in order, in pass_through_string.

        Raises:
            ValueError -- for -c/-g, for an option whose argument is missing
                or malformed, and when both ranges are "null".
            IOError -- when the svm-train executable, the dataset or the
                resume file does not exist.
        A missing gnuplot executable only prints a warning and disables
        plotting.
        """
        if isinstance(options, str):
            options = options.split()
        pass_through_options = []

        def argument_of(index):
            # A trailing flag used to fail with IndexError; ValueError is the
            # error type this method already uses for bad options.
            if index + 1 >= len(options):
                raise ValueError(
                    'option {0} requires an argument'.format(options[index]))
            return options[index + 1]

        def parse_range(flag, text):
            fields = text.split(',')
            if len(fields) != 3:
                raise ValueError(
                    '{0} expects begin,end,step or null'.format(flag))
            return [float(field) for field in fields]

        i = 0
        while i < len(options):
            flag = options[i]
            if flag == '-log2c':
                value = argument_of(i)
                i = i + 1
                if value == 'null':
                    self.grid_with_c = False
                else:
                    self.c_begin, self.c_end, self.c_step = parse_range(flag, value)
            elif flag == '-log2g':
                value = argument_of(i)
                i = i + 1
                if value == 'null':
                    self.grid_with_g = False
                else:
                    self.g_begin, self.g_end, self.g_step = parse_range(flag, value)
            elif flag == '-v':
                # int() rejects a non-numeric fold count here (ValueError)
                # instead of handing it on to svm-train.
                self.fold = int(argument_of(i))
                i = i + 1
            elif flag in ('-c', '-g'):
                raise ValueError('Use -log2c and -log2g.')
            elif flag == '-svmtrain':
                self.svmtrain_pathname = argument_of(i)
                i = i + 1
            elif flag == '-gnuplot':
                value = argument_of(i)
                i = i + 1
                self.gnuplot_pathname = None if value == 'null' else value
            elif flag == '-out':
                value = argument_of(i)
                i = i + 1
                self.out_pathname = None if value == 'null' else value
            elif flag == '-png':
                self.png_pathname = argument_of(i)
                i = i + 1
            elif flag == '-resume':
                # The pathname is optional: absent when -resume is the last
                # token or is directly followed by another option.
                if i == (len(options) - 1) or options[i + 1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i = i + 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(flag)
            i = i + 1

        self.pass_through_string = ' '.join(pass_through_options)
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')
        if not self.grid_with_c and not self.grid_with_g:
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None


def redraw(db, best_param, gnuplot, options, tofile=False):
    """Send gnuplot the commands and data for a contour plot of the results.

    db -- list of (log2c, log2g, rate) tuples; it is not modified.
    best_param -- (best_log2c, best_log2g, best_rate), shown in two labels.
    gnuplot -- writable binary stream (gnuplot's stdin); flushed at the end.
    options -- object providing png_pathname, dataset_title, c_begin, c_end,
        g_begin and g_end.
    tofile -- when true the plot goes to options.png_pathname as a PNG,
        otherwise to the "windows" (win32) or "x11" terminal.

    Nothing is written when db is empty or when all entries share the same
    log2c, the same log2g or the same rate.
    """
    if not db:
        return
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    first = db[0]
    for column in range(3):
        if all(x[column] == first[column] for x in db):
            return

    if tofile:
        gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
        # Backslashes are doubled because the path sits inside a gnuplot
        # double-quoted string.
        gnuplot.write("set output \"{0}\"\n".format(
            options.png_pathname.replace('\\', '\\\\')).encode())
    elif sys.platform == 'win32':
        gnuplot.write(b"set term windows\n")
    else:
        gnuplot.write(b"set term x11\n")
    gnuplot.write(b"set xlabel \"log2(C)\"\n")
    gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
    gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin, options.c_end).encode())
    gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin, options.g_end).encode())
    gnuplot.write(b"set contour\n")
    gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(
        begin_level, step_size).encode())
    gnuplot.write(b"unset surface\n")
    gnuplot.write(b"unset ztics\n")
    gnuplot.write(b"set view 0,0\n")
    gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
    gnuplot.write(b"unset label\n")
    gnuplot.write(("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\""
                   " at screen 0.5,0.85 center\n").format(
                       best_log2c, best_log2g, best_rate).encode())
    gnuplot.write(("set label \"C = {0} gamma = {1}\""
                   " at screen 0.5,0.8 center\n").format(
                       2**best_log2c, 2**best_log2g).encode())
    gnuplot.write(b"set key at screen 0.9,0.9\n")
    gnuplot.write(b"splot \"-\" with lines\n")

    # Sort a copy so the caller's list keeps its order: ascending log2c,
    # and descending log2g within one log2c.
    rows = sorted(db, key=lambda x: (x[0], -x[1]))

    prevc = rows[0][0]
    for row in rows:
        if prevc != row[0]:
            # a blank line separates the rows of one log2c from the next
            gnuplot.write(b"\n")
            prevc = row[0]
        gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(row).encode())
    gnuplot.write(b"e\n")
    gnuplot.write(b"\n")  # force gnuplot back to prompt when term set failure
    gnuplot.flush()


def calculate_jobs(options):
    """Build the ordered job list and load results of an earlier run.

    options provides c_begin/c_end/c_step, g_begin/g_end/g_step,
    grid_with_c, grid_with_g and resume_pathname.

    Returns (jobs, resumed_jobs):
        jobs -- list of lists of (log2c, log2g) pairs; a coordinate is None
            when that parameter is not searched.  Each value sequence is
            reordered middle-first and the two sequences are interleaved.
        resumed_jobs -- dict mapping (log2c, log2g) to the rate read from
            options.resume_pathname; empty when resume_pathname is None.

    Raises ValueError when a searched range has a zero step or is empty.
    """

    def range_f(begin, end, step):
        # like range, but works on non-integer too
        if step == 0:
            # a zero step would never reach `end` and loop forever
            raise ValueError('grid step must not be zero')
        seq = []
        value = begin
        while (step > 0 and value <= end) or (step < 0 and value >= end):
            seq.append(value)
            value = value + step
        if not seq:
            # an empty sequence would later divide by zero
            raise ValueError(
                'grid range {0},{1},{2} is empty'.format(begin, end, step))
        return seq

    def permute_sequence(seq):
        # Middle element first, then the permuted halves alternately.
        n = len(seq)
        if n <= 1:
            return seq

        mid = int(n / 2)
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid + 1:])

        ret = [seq[mid]]
        while left or right:
            if left:
                ret.append(left.pop(0))
            if right:
                ret.append(right.pop(0))
        return ret

    # A range that is not searched is never expanded, so it is not validated.
    if options.grid_with_c:
        c_seq = permute_sequence(
            range_f(options.c_begin, options.c_end, options.c_step))
    else:
        c_seq = [None]
    if options.grid_with_g:
        g_seq = permute_sequence(
            range_f(options.g_begin, options.g_end, options.g_step))
    else:
        g_seq = [None]

    # floats keep the ratio comparison below a true division on Python 2
    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i / nr_c < j / nr_g:
            # increase C resolution
            jobs.append([(c_seq[i], g_seq[k]) for k in range(0, j)])
            i = i + 1
        else:
            # increase g resolution
            jobs.append([(c_seq[k], g_seq[j]) for k in range(0, i)])
            j = j + 1

    resumed_jobs = {}

    if options.resume_pathname is None:
        return jobs, resumed_jobs

    # `with` closes the resume file; it used to be left open.
    with open(options.resume_pathname, 'r') as resume_file:
        for line in resume_file:
            line = line.strip()
            rst = re.findall(r'rate=([0-9.]+)', line)
            if not rst:
                continue
            rate = float(rst[0])

            c, g = None, None
            rst = re.findall(r'log2c=([0-9.-]+)', line)
            if rst:
                c = float(rst[0])
            rst = re.findall(r'log2g=([0-9.-]+)', line)
            if rst:
                g = float(rst[0])

            resumed_jobs[(c, g)] = rate

    return jobs, resumed_jobs


class WorkerStopToken:  # used to notify the worker to stop or if a worker is dead
    pass


class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and reports rates.

    Subclasses implement run_one(c, g), which runs one cross validation for
    the actual parameter values and returns the rate, or None when no rate
    could be obtained.
    """

    def __init__(self, name, job_queue, result_queue, options):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    def run(self):
        """Process jobs until the stop token arrives or a job fails.

        Each success puts (name, log2c, log2g, rate) on result_queue.  On a
        failure the job is put back on job_queue and this worker quits.
        """
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # put the token back so the other workers receive it too
                self.job_queue.put((cexp, gexp))
                break
            try:
                c = None if cexp is None else 2.0**cexp
                g = None if gexp is None else 2.0**gexp
                rate = self.run_one(c, g)
                if rate is None:
                    raise RuntimeError('get no rate')
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exc()
                self.job_queue.put((cexp, gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))

    def get_cmd(self, c, g):
        """Return the svm-train command line for the given C and gamma."""
        options = self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        cmdline += ' -v {0} {1} {2} '.format(
            options.fold, options.pass_through_string, options.dataset_pathname)
        return cmdline

    @staticmethod
    def _parse_rate(output):
        """Return the rate from the first line containing 'Cross', or None.

        output is the captured text as bytes or str.  The rate is the last
        whitespace-separated token of that line with its final character
        removed.
        """
        if isinstance(output, bytes):
            output = output.decode('utf-8', 'replace')
        for line in output.splitlines():
            if 'Cross' in line:
                return float(line.split()[-1][0:-1])
        return None

    @staticmethod
    def _run_shell(cmdline):
        """Run cmdline through the shell and return everything it printed.

        communicate() drains stdout and stderr together, waits for the
        process and closes the pipes, so an unread stderr pipe cannot stall
        the child and no process or descriptor is left behind.
        """
        process = Popen(cmdline, shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        stdout_data, _ = process.communicate()
        return stdout_data


class LocalWorker(Worker):
    """Worker that runs svm-train on the local machine."""

    def run_one(self, c, g):
        return self._parse_rate(self._run_shell(self.get_cmd(c, g)))


class SSHWorker(Worker):
    """Worker that runs svm-train on a remote host through ssh."""

    def __init__(self, name, job_queue, result_queue, host, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        # the remote command changes into this directory first
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format(
            self.host, self.cwd, self.get_cmd(c, g))
        return self._parse_rate(self._run_shell(cmdline))


class TelnetWorker(Worker):
    """Worker that runs svm-train on a remote host over a telnet session."""

    def __init__(self, name, job_queue, result_queue, host, username, password, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        self.username = username
        self.password = password

    @staticmethod
    def _to_bytes(text):
        # telnetlib exchanges bytes; Python 2 str is already bytes
        if isinstance(text, bytes):
            return text
        return text.encode('utf-8')

    def run(self):
        """Log in, change to the current directory, then process jobs."""
        # NOTE(review): telnetlib was deprecated by PEP 594 and is absent
        # from Python 3.13+, where this import fails and the worker dies.
        import telnetlib
        to_bytes = self._to_bytes
        self.tn = tn = telnetlib.Telnet(self.host)
        tn.read_until(b'login: ')
        tn.write(to_bytes(self.username + '\n'))
        tn.read_until(b'Password: ')
        tn.write(to_bytes(self.password + '\n'))

        # XXX: how to know whether login is successful?
        tn.read_until(to_bytes(self.username))
        print('login ok', self.host)
        tn.write(to_bytes('cd ' + os.getcwd() + '\n'))
        Worker.run(self)
        tn.write(b'exit\n')

    def run_one(self, c, g):
        cmdline = self.get_cmd(c, g)
        self.tn.write(self._to_bytes(cmdline + '\n'))
        # read up to and including the first line that contains "Cross"
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        return self._parse_rate(output)


def find_parameters(dataset_pathname, options=''):
    """Run the grid search and return (best_rate, best_param).

    dataset_pathname -- path of the training data.
    options -- grid and svm-train options, as a list of tokens or a
        whitespace-separated string (see GridOption).

    best_param maps 'c' and/or 'g' to the best actual parameter values
    (2**log2c, 2**log2g); a parameter that was not searched is absent.
    Progress is printed, each new result is appended to the output file
    when one is configured, and the contour plot is redrawn when gnuplot is
    available and both parameters are searched.

    Raises IOError/ValueError from option parsing, and RuntimeError when
    every worker has quit while results are still missing.
    """

    def update_param(c, g, rate, best_c, best_g, best_rate, worker, resumed):
        # On equal rate and equal g the smaller c wins.  The None checks
        # keep Python 3 from comparing None < None.
        smaller_c = (rate == best_rate and g == best_g
                     and c is not None and best_c is not None and c < best_c)
        if rate > best_rate or smaller_c:
            best_rate, best_c, best_g = rate, c, g
        stdout_str = '[{0}] {1} {2} (best '.format(
            worker, ' '.join(str(x) for x in [c, g] if x is not None), rate)
        output_str = ''
        if c is not None:
            stdout_str += 'c={0}, '.format(2.0**best_c)
            output_str += 'log2c={0} '.format(c)
        if g is not None:
            stdout_str += 'g={0}, '.format(2.0**best_g)
            output_str += 'log2g={0} '.format(g)
        stdout_str += 'rate={0})'.format(best_rate)
        print(stdout_str)
        # resumed results are already in the file being resumed from
        if options.out_pathname and not resumed:
            output_str += 'rate={0}\n'.format(rate)
            result_file.write(output_str)
            result_file.flush()

        return best_c, best_g, best_rate

    def next_result():
        # Poll instead of blocking forever: a worker quits after a failed
        # job, so with every worker gone no result would ever arrive.
        while True:
            try:
                return result_queue.get(timeout=1)
            except Empty:
                if any(w.is_alive() for w in workers):
                    continue
                try:
                    # a result may have arrived just before the last exit
                    return result_queue.get_nowait()
                except Empty:
                    raise RuntimeError(
                        'all workers quit before the grid search finished')

    options = GridOption(dataset_pathname, options)

    if options.gnuplot_pathname:
        gnuplot = Popen(options.gnuplot_pathname,
                        stdin=PIPE, stdout=PIPE, stderr=PIPE).stdin
    else:
        gnuplot = None

    # put jobs in queue

    jobs, resumed_jobs = calculate_jobs(options)
    job_queue = Queue(0)
    result_queue = Queue(0)

    for (c, g) in resumed_jobs:
        result_queue.put(('resumed', c, g, resumed_jobs[(c, g)]))

    for line in jobs:
        for (c, g) in line:
            if (c, g) not in resumed_jobs:
                job_queue.put((c, g))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. If we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end

    job_queue._put = job_queue.queue.appendleft

    workers = []

    # fire telnet workers

    if telnet_workers:
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            workers.append(TelnetWorker(host, job_queue, result_queue,
                                        host, username, password, options))
            workers[-1].start()

    # fire ssh workers

    for host in ssh_workers:
        workers.append(SSHWorker(host, job_queue, result_queue, host, options))
        workers[-1].start()

    # fire local workers

    for i in range(nr_local_worker):
        workers.append(LocalWorker('local', job_queue, result_queue, options))
        workers[-1].start()

    # gather results

    done_jobs = {}
    result_file = None
    db = []
    best_rate = -1
    best_c, best_g = None, None

    try:
        if options.out_pathname:
            # append when resuming so the earlier results are kept
            mode = 'a' if options.resume_pathname else 'w'
            result_file = open(options.out_pathname, mode)

        for (c, g) in resumed_jobs:
            rate = resumed_jobs[(c, g)]
            best_c, best_g, best_rate = update_param(
                c, g, rate, best_c, best_g, best_rate, 'resumed', True)

        for line in jobs:
            for (c, g) in line:
                while (c, g) not in done_jobs:
                    (worker_name, c1, g1, rate1) = next_result()
                    done_jobs[(c1, g1)] = rate1
                    if (c1, g1) not in resumed_jobs:
                        best_c, best_g, best_rate = update_param(
                            c1, g1, rate1, best_c, best_g, best_rate,
                            worker_name, False)
                db.append((c, g, done_jobs[(c, g)]))
            if gnuplot and options.grid_with_c and options.grid_with_g:
                redraw(db, [best_c, best_g, best_rate], gnuplot, options)
                redraw(db, [best_c, best_g, best_rate], gnuplot, options, True)
    finally:
        # Runs on errors too: close the file and stop the workers, which
        # otherwise keep the process alive waiting for jobs.
        if result_file is not None:
            result_file.close()
        job_queue.put((WorkerStopToken, None))

    best_param, best_cg = {}, []
    if best_c is not None:
        best_param['c'] = 2.0**best_c
        best_cg += [2.0**best_c]
    if best_g is not None:
        best_param['g'] = 2.0**best_g
        best_cg += [2.0**best_g]
    print('{0} {1}'.format(' '.join(map(str, best_cg)), best_rate))

    return best_rate, best_param


if __name__ == '__main__':
    # Command-line entry point: the last argument is the dataset, everything
    # before it is passed to find_parameters() as options.

    def exit_with_help():
        """Print the usage text and exit with status 1."""
        print("""\
Usage: grid.py [grid_options] [svm_options] dataset

grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
    begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
    "null" -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
    begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
    "null" -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
    pathname -- set gnuplot executable path and name
    "null" -- do not plot
-out {pathname | "null"} : (default dataset.out)
    pathname -- set output file path and name
    "null" -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
    This is experimental. Try this option only if some parameters have been checked for the SAME data.

svm_options : additional options for svm-train""")
        sys.exit(1)

    if len(sys.argv) < 2:
        exit_with_help()
    dataset_pathname = sys.argv[-1]
    options = sys.argv[1:-1]
    try:
        find_parameters(dataset_pathname, options)
    except (IOError, ValueError) as e:
        # bad options or missing files: report briefly, no traceback
        sys.stderr.write(str(e) + '\n')
        sys.stderr.write('Try "grid.py" for more information.\n')
        sys.exit(1)