#!/usr/bin/env python
"""Grid-search helpers for libsvm's ``svm-train``.

This part of the module holds the option parser (``GridOption``), the gnuplot
contour drawing (``redraw``), the job scheduler (``calculate_jobs``) and the
worker threads that run ``svm-train`` locally, over ssh, or over telnet.
The code stays compatible with both Python 2 and Python 3.
"""

import os, sys, traceback, getpass, time, re
from threading import Thread
from subprocess import *

if sys.version_info[0] < 3:
    from Queue import Queue
else:
    from queue import Queue

telnet_workers = []
ssh_workers = []
nr_local_worker = 1


def _to_bytes(text):
    """Return *text* as bytes (telnetlib needs bytes on Python 3)."""
    if isinstance(text, bytes):
        return text
    return text.encode()


def _parse_cv_rate(lines):
    """Return the rate from the first line containing 'Cross', else None.

    *lines* may hold bytes or str.  The rate is the last whitespace-separated
    token of the line with its final character (the '%' sign) removed.
    """
    for line in lines:
        if not isinstance(line, str):
            line = line.decode('utf-8', 'replace')
        if 'Cross' in line:
            return float(line.split()[-1][0:-1])
    return None


def _run_and_collect(cmdline):
    """Run *cmdline* through the shell and return its stdout lines.

    The pipes are closed and the child is waited on afterwards so that no
    file descriptors or zombie processes pile up over a long grid search.
    stdin is kept open until stdout reaches EOF, as the original code did.
    """
    proc = Popen(cmdline, shell=True, stdout=PIPE, stderr=PIPE, stdin=PIPE)
    try:
        return proc.stdout.readlines()
    finally:
        for pipe in (proc.stdin, proc.stdout, proc.stderr):
            pipe.close()
        proc.wait()


class GridOption:
    """Settings of one grid search, built from defaults plus user options.

    Attributes set here: executable paths (``svmtrain_pathname``,
    ``gnuplot_pathname``), the cross-validation ``fold``, the log2(C) and
    log2(gamma) ranges (``*_begin``, ``*_end``, ``*_step``), the
    ``grid_with_c`` / ``grid_with_g`` switches, output paths and the options
    passed through unchanged to svm-train (``pass_through_string``).
    """

    def __init__(self, dataset_pathname, options):
        dirname = os.path.dirname(__file__)
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            # example for windows
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5, 15, 2
        self.g_begin, self.g_end, self.g_step = 3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]
        self.out_pathname = '{0}.out'.format(self.dataset_title)
        self.png_pathname = '{0}.png'.format(self.dataset_title)
        self.pass_through_string = ' '
        self.resume_pathname = None
        self.parse_options(options)

    def parse_options(self, options):
        """Apply *options* (a list of tokens or one whitespace-separated string).

        Unrecognized tokens are collected into ``pass_through_string``.

        Raises:
            ValueError: for ``-c``/``-g``, for an option given without its
                value, for a malformed range, or when both ranges are 'null'.
            IOError: when svm-train, the dataset or the resume file is missing.
        """
        if isinstance(options, str):
            options = options.split()
        i = 0
        pass_through_options = []

        def value_after(index):
            # Fail with a readable message instead of a bare IndexError.
            if index + 1 >= len(options):
                raise ValueError('option {0} requires a value'.format(options[index]))
            return options[index + 1]

        while i < len(options):
            flag = options[i]
            if flag == '-log2c':
                value = value_after(i)
                i = i + 1
                if value == 'null':
                    self.grid_with_c = False
                else:
                    self.c_begin, self.c_end, self.c_step = map(float, value.split(','))
            elif flag == '-log2g':
                value = value_after(i)
                i = i + 1
                if value == 'null':
                    self.grid_with_g = False
                else:
                    self.g_begin, self.g_end, self.g_step = map(float, value.split(','))
            elif flag == '-v':
                self.fold = value_after(i)
                i = i + 1
            elif flag in ('-c', '-g'):
                raise ValueError('Use -log2c and -log2g.')
            elif flag == '-svmtrain':
                self.svmtrain_pathname = value_after(i)
                i = i + 1
            elif flag == '-gnuplot':
                value = value_after(i)
                i = i + 1
                self.gnuplot_pathname = None if value == 'null' else value
            elif flag == '-out':
                value = value_after(i)
                i = i + 1
                self.out_pathname = None if value == 'null' else value
            elif flag == '-png':
                self.png_pathname = value_after(i)
                i = i + 1
            elif flag == '-resume':
                # The file name is optional; default to "<dataset title>.out".
                if i == (len(options) - 1) or options[i + 1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i = i + 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(flag)
            i = i + 1

        self.pass_through_string = ' '.join(pass_through_options)
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')
        if not self.grid_with_c and not self.grid_with_g:
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
            # A missing gnuplot only disables plotting; it is not fatal.
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None


def redraw(db, best_param, gnuplot, options, tofile=False):
    """Send gnuplot the commands and data for the accuracy contour plot.

    Args:
        db: list of (log2c, log2g, rate) records; it is sorted in place.
        best_param: (best_log2c, best_log2g, best_rate) shown in the labels.
        gnuplot: object with ``write(bytes)`` and ``flush()``.
        options: object providing ``png_pathname``, ``dataset_title`` and the
            ``c_*`` / ``g_*`` range bounds.
        tofile: when true, render to ``options.png_pathname`` instead of a
            window.
    """
    if len(db) == 0:
        return
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c, best_log2g, best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    if all(x[0] == db[0][0] for x in db): return
    if all(x[1] == db[0][1] for x in db): return
    if all(x[2] == db[0][2] for x in db): return

    def send(text):
        gnuplot.write(text.encode())

    if tofile:
        send("set term png transparent small linewidth 2 medium enhanced\n")
        send("set output \"{0}\"\n".format(options.png_pathname.replace('\\', '\\\\')))
    elif sys.platform == 'win32':
        send("set term windows\n")
    else:
        send("set term x11\n")
    send("set xlabel \"log2(C)\"\n")
    send("set ylabel \"log2(gamma)\"\n")
    send("set xrange [{0}:{1}]\n".format(options.c_begin, options.c_end))
    send("set yrange [{0}:{1}]\n".format(options.g_begin, options.g_end))
    send("set contour\n")
    send("set cntrparam levels incremental {0},{1},100\n".format(begin_level, step_size))
    send("unset surface\n")
    send("unset ztics\n")
    send("set view 0,0\n")
    send("set title \"{0}\"\n".format(options.dataset_title))
    send("unset label\n")
    send("set label \"Best log2(C) = {0} log2(gamma) = {1} accuracy = {2}%\""
         " at screen 0.5,0.85 center\n".format(best_log2c, best_log2g, best_rate))
    send("set label \"C = {0} gamma = {1}\""
         " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g))
    send("set key at screen 0.9,0.9\n")
    send("splot \"-\" with lines\n")

    db.sort(key=lambda x: (x[0], -x[1]))

    # Records sharing the same log2c form one scan line; a blank line
    # separates consecutive scan lines in gnuplot's inline data.
    prevc = db[0][0]
    for line in db:
        if prevc != line[0]:
            send("\n")
            prevc = line[0]
        send("{0[0]} {0[1]} {0[2]}\n".format(line))
    send("e\n")
    send("\n")  # force gnuplot back to prompt when term set failure
    gnuplot.flush()


def calculate_jobs(options):
    """Build the job schedule and load already finished results.

    Returns:
        (jobs, resumed_jobs): ``jobs`` is a list of lists of
        (log2c, log2g) pairs, ordered so that coarse coverage of the grid
        comes first; ``resumed_jobs`` maps (log2c, log2g) to the rate read
        from ``options.resume_pathname`` (empty when no resume file is set).

    Raises:
        ValueError: when a searched range has a zero step (which would never
            terminate) or contains no values.
    """

    def range_f(begin, end, step):
        # like range, but works on non-integer too
        if step == 0:
            raise ValueError('grid step must be non-zero')
        seq = []
        while True:
            if step > 0 and begin > end: break
            if step < 0 and begin < end: break
            seq.append(begin)
            begin = begin + step
        if not seq:
            raise ValueError('grid range {0},{1},{2} is empty'.format(begin, end, step))
        return seq

    def permute_sequence(seq):
        # Middle element first, then the two halves interleaved.
        n = len(seq)
        if n <= 1: return seq

        mid = int(n/2)
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid+1:])

        ret = [seq[mid]]
        while left or right:
            if left: ret.append(left.pop(0))
            if right: ret.append(right.pop(0))

        return ret

    if options.grid_with_c:
        c_seq = permute_sequence(range_f(options.c_begin, options.c_end, options.c_step))
    else:
        c_seq = [None]
    if options.grid_with_g:
        g_seq = permute_sequence(range_f(options.g_begin, options.g_end, options.g_step))
    else:
        g_seq = [None]

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution
            line = [(c_seq[i], g_seq[k]) for k in range(0, j)]
            i = i + 1
            jobs.append(line)
        else:
            # increase g resolution
            line = [(c_seq[k], g_seq[j]) for k in range(0, i)]
            j = j + 1
            jobs.append(line)

    resumed_jobs = {}

    if options.resume_pathname is None:
        return jobs, resumed_jobs

    rate_re = re.compile(r'rate=([0-9.]+)')
    log2c_re = re.compile(r'log2c=([0-9.-]+)')
    log2g_re = re.compile(r'log2g=([0-9.-]+)')

    with open(options.resume_pathname, 'r') as resume_file:
        for line in resume_file:
            line = line.strip()
            rst = rate_re.findall(line)
            if not rst:
                continue
            rate = float(rst[0])

            c, g = None, None
            rst = log2c_re.findall(line)
            if rst:
                c = float(rst[0])
            rst = log2g_re.findall(line)
            if rst:
                g = float(rst[0])

            resumed_jobs[(c, g)] = rate

    return jobs, resumed_jobs


class WorkerStopToken:  # used to notify the worker to stop or if a worker is dead
    pass


class Worker(Thread):
    """Thread that takes (log2c, log2g) jobs from a queue and reports rates.

    Subclasses provide ``run_one(c, g)``, which returns the cross-validation
    rate or None when it could not be obtained.
    """

    def __init__(self, name, job_queue, result_queue, options):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    def run(self):
        while True:
            (cexp, gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                # Put the token back so the other workers see it too.
                self.job_queue.put((cexp, gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:
                c, g = None, None
                if cexp is not None:
                    c = 2.0**cexp
                if gexp is not None:
                    g = 2.0**gexp
                rate = self.run_one(c, g)
                if rate is None: raise RuntimeError('get no rate')
            except Exception:
                # we failed, let others do that and we just quit
                traceback.print_exc()

                self.job_queue.put((cexp, gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name, cexp, gexp, rate))

    def get_cmd(self, c, g):
        """Return the svm-train command line for parameters *c* and *g*."""
        # NOTE(review): only the svm-train path is quoted; a dataset path or
        # pass-through option containing spaces reaches the shell unquoted.
        options = self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        cmdline += ' -v {0} {1} {2} '.format(
            options.fold, options.pass_through_string, options.dataset_pathname)
        return cmdline


class LocalWorker(Worker):
    """Worker that runs svm-train on the local machine."""

    def run_one(self, c, g):
        cmdline = self.get_cmd(c, g)
        return _parse_cv_rate(_run_and_collect(cmdline))


class SSHWorker(Worker):
    """Worker that runs svm-train on *host* through ``ssh``."""

    def __init__(self, name, job_queue, result_queue, host, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        self.cwd = os.getcwd()

    def run_one(self, c, g):
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format(
            self.host, self.cwd, self.get_cmd(c, g))
        return _parse_cv_rate(_run_and_collect(cmdline))


class TelnetWorker(Worker):
    """Worker that runs svm-train on *host* through a telnet session.

    All data exchanged with ``telnetlib`` is bytes, which Python 3 requires.
    """

    def __init__(self, name, job_queue, result_queue, host, username, password, options):
        Worker.__init__(self, name, job_queue, result_queue, options)
        self.host = host
        self.username = username
        self.password = password

    def run(self):
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        tn.read_until(b'login: ')
        tn.write(_to_bytes(self.username + '\n'))
        tn.read_until(b'Password: ')
        tn.write(_to_bytes(self.password + '\n'))

        # XXX: how to know whether login is successful?
        tn.read_until(_to_bytes(self.username))
        #
        print('login ok', self.host)
        tn.write(_to_bytes('cd ' + os.getcwd() + '\n'))
        Worker.run(self)
        tn.write(b'exit\n')

    def run_one(self, c, g):
        cmdline = self.get_cmd(c, g)
        self.tn.write(_to_bytes(cmdline + '\n'))
        (idx, matchm, output) = self.tn.expect([b'Cross.*\n'])
        return _parse_cv_rate(output.splitlines())


# NOTE(review): the original continues with ``find_parameters`` (the only name
# the module lists in ``__all__``). Its definition is cut off in this view, so
# neither it nor the ``__all__`` declaration is reproduced here.