You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

svm.py 9.6 kB

(line numbers 1–330)
  1. #!/usr/bin/env python
  2. from ctypes import *
  3. from ctypes.util import find_library
  4. from os import path
  5. import sys
  # Python 3 has no ``xrange``; alias it to ``range`` so code below (e.g.
  # svm_model.get_sv_coef) can use a single name on Python 2 and 3.
  if sys.version_info >= (3,):
      xrange = range
  # Public names exported by ``from svm import *``.
  __all__ = [
      'libsvm',
      'svm_problem',
      'svm_parameter',
      'toPyModel',
      'gen_svm_nodearray',
      'print_null',
      'svm_node',
      'C_SVC',
      'EPSILON_SVR',
      'LINEAR',
      'NU_SVC',
      'NU_SVR',
      'ONE_CLASS',
      'POLY',
      'PRECOMPUTED',
      'PRINT_STRING_FUN',
      'RBF',
      'SIGMOID',
      'c_double',
      'svm_model',
  ]
  # Load the LIBSVM shared library. First try the file at a fixed path
  # relative to this module; if that fails, fall back to whatever the
  # system linker can locate via ctypes.util.find_library.
  try:
      dirname = path.dirname(path.abspath(__file__))
      if sys.platform == 'win32':
          libsvm = CDLL(path.join(dirname, r'..\windows\libsvm.dll'))
      else:
          libsvm = CDLL(path.join(dirname, '../libsvm.so.2'))
  except (OSError, NameError):
      # OSError: CDLL could not load the file. NameError: __file__ is not
      # defined (e.g. the code was exec'd without a source file). The former
      # bare ``except:`` also swallowed KeyboardInterrupt and SystemExit.
      # For unix the prefix 'lib' is not considered, so try both spellings.
      _libsvm_path = find_library('svm') or find_library('libsvm')
      if not _libsvm_path:
          raise Exception('LIBSVM library not found.')
      libsvm = CDLL(_libsvm_path)
  # svm_type codes: stored in the c_int field svm_parameter.svm_type and
  # handed to the C library unchanged.
  C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR = range(5)

  # kernel_type codes: stored in the c_int field svm_parameter.kernel_type.
  LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED = range(5)

  # ctypes type of the callback registered through
  # svm_set_print_string_function: returns nothing, takes one C string.
  PRINT_STRING_FUN = CFUNCTYPE(None, c_char_p)
  def print_null(s):
      """Ignore *s*; wrapped as the print callback when the ``-q`` option is given."""
      return None
  def genFields(names, types):
      """Pair each field name with its ctypes type, as a ``_fields_`` list.

      Extra entries in the longer argument are dropped (``zip`` semantics).
      """
      return [(field_name, field_type) for field_name, field_type in zip(names, types)]
  def fillprototype(f, restype, argtypes):
      """Declare the C signature of the foreign function *f* for ctypes.

      Sets ``f.argtypes`` and ``f.restype`` so that ctypes converts arguments
      and the return value for every later call.
      """
      f.argtypes = argtypes
      f.restype = restype
  class svm_node(Structure):
      """One sparse feature entry: an integer ``index`` and a double ``value``."""

      _names = ["index", "value"]
      _types = [c_int, c_double]
      # Same list genFields would build: [("index", c_int), ("value", c_double)].
      _fields_ = list(zip(_names, _types))

      def __str__(self):
          """Render as ``index:value``, e.g. ``3:0.5``."""
          return '{0:d}:{1:g}'.format(self.index, self.value)
  def gen_svm_nodearray(xi, feature_max=None, isKernel=None):
      """Convert one instance into a ctypes array of svm_node ending in index -1.

      Args:
          xi: the instance, either a dict ``{index: value}`` or a list/tuple
              of values. For a list/tuple the position is the index: unless
              ``isKernel`` is true, the first element gets index 1;
              with ``isKernel`` indexing starts at 0.
          feature_max: if truthy, must be an int; indices above it are dropped.
          isKernel: when true, zero-valued entries are kept and list/tuple
              positions are not shifted.

      Returns:
          ``(nodes, max_idx)``: the ctypes array, whose last element has
          index -1, and the largest index stored (0 if nothing was stored).

      Raises:
          TypeError: if ``xi`` is not a dict, list or tuple.
      """
      if isinstance(xi, dict):
          index_range = xi.keys()
      elif isinstance(xi, (list, tuple)):
          if not isKernel:
              # idx should start from 1. list() is required for tuples:
              # ``[0] + (1, 2)`` raises TypeError.
              xi = [0] + list(xi)
          index_range = range(len(xi))
      else:
          raise TypeError('xi should be a dictionary, list or tuple')

      if feature_max:
          assert(isinstance(feature_max, int))
          index_range = [j for j in index_range if j <= feature_max]
      if not isKernel:
          # Sparse representation: zero-valued features are not stored.
          index_range = [j for j in index_range if xi[j] != 0]
      index_range = sorted(index_range)

      ret = (svm_node * (len(index_range) + 1))()
      ret[-1].index = -1  # terminator node closing the list
      for idx, j in enumerate(index_range):
          ret[idx].index = j
          ret[idx].value = xi[j]
      max_idx = index_range[-1] if index_range else 0
      return ret, max_idx
  class svm_problem(Structure):
      """ctypes structure holding ``l`` training instances: labels ``y`` and rows ``x``."""

      _names = ["l", "y", "x"]
      _types = [c_int, POINTER(c_double), POINTER(POINTER(svm_node))]
      _fields_ = genFields(_names, _types)

      def __init__(self, y, x, isKernel=None):
          """Build the problem from labels *y* and instances *x*.

          Every instance is converted with ``gen_svm_nodearray`` (forwarding
          *isKernel*). Also sets ``self.n`` to the largest feature index seen
          and ``self.x_space`` to the list of node arrays.

          Raises:
              ValueError: if ``len(y) != len(x)``.
          """
          if len(y) != len(x):
              raise ValueError("len(y) != len(x)")
          count = len(y)
          self.l = count

          # self.x only holds raw pointers; this Python list is what keeps the
          # node arrays referenced for the lifetime of the problem.
          rows = []
          self.x_space = rows
          largest_index = 0
          for instance in x:
              nodes, instance_max = gen_svm_nodearray(instance, isKernel=isKernel)
              rows.append(nodes)
              if instance_max > largest_index:
                  largest_index = instance_max
          self.n = largest_index

          labels = (c_double * count)()
          for pos, label in enumerate(y):
              labels[pos] = label
          self.y = labels

          row_pointers = (POINTER(svm_node) * count)()
          for pos, nodes in enumerate(self.x_space):
              row_pointers[pos] = nodes
          self.x = row_pointers
  class svm_parameter(Structure):
      """ctypes structure of training options, filled from LIBSVM-style flags.

      Besides the C fields listed in ``_names``, instances carry the
      Python-only attributes ``cross_validation``, ``nr_fold`` and
      ``print_func``.
      """

      _names = ["svm_type", "kernel_type", "degree", "gamma", "coef0",
                "cache_size", "eps", "C", "nr_weight", "weight_label", "weight",
                "nu", "p", "shrinking", "probability"]
      _types = [c_int, c_int, c_int, c_double, c_double,
                c_double, c_double, c_double, c_int, POINTER(c_int), POINTER(c_double),
                c_double, c_double, c_int, c_int]
      _fields_ = genFields(_names, _types)

      # Flags that take exactly one value and just set one attribute:
      # flag -> (attribute name, converter applied to the value token).
      _value_options = {
          "-s": ("svm_type", int),
          "-t": ("kernel_type", int),
          "-d": ("degree", int),
          "-g": ("gamma", float),
          "-r": ("coef0", float),
          "-n": ("nu", float),
          "-m": ("cache_size", float),
          "-c": ("C", float),
          "-e": ("eps", float),
          "-p": ("p", float),
          "-h": ("shrinking", int),
          "-b": ("probability", int),
      }

      def __init__(self, options = None):
          """Create parameters from *options* (str or list); ``None`` means defaults."""
          if options is None:
              options = ''
          self.parse_options(options)

      def __str__(self):
          """List every C field, then every Python-only attribute, one per line."""
          attrs = svm_parameter._names + list(self.__dict__.keys())
          lines = [' %s: %s' % (attr, getattr(self, attr)) for attr in attrs]
          return '\n'.join(lines).strip()

      def set_to_default_values(self):
          """Reset every option to its default value."""
          self.svm_type = C_SVC
          self.kernel_type = RBF
          self.degree = 3
          self.gamma = 0
          self.coef0 = 0
          self.nu = 0.5
          self.cache_size = 100
          self.C = 1
          self.eps = 0.001
          self.p = 0.1
          self.shrinking = 1
          self.probability = 0
          self.nr_weight = 0
          self.weight_label = None
          self.weight = None
          self.cross_validation = False
          self.nr_fold = 0
          self.print_func = cast(None, PRINT_STRING_FUN)

      def parse_options(self, options):
          """Reset to defaults, then apply command-line style *options*.

          Args:
              options: a ``str`` (split on whitespace) or a ``list`` of
                  tokens, e.g. ``'-s 0 -c 10 -w1 5 -q'``. ``-q`` takes no
                  value, ``-wLABEL WEIGHT`` adds a per-label weight, and every
                  other flag takes exactly one value.

          Raises:
              TypeError: if *options* is neither a list nor a str.
              ValueError: on an unknown flag, a flag missing its value, a
                  non-numeric value, or ``-v`` with fewer than 2 folds.
          """
          if isinstance(options, list):
              argv = options
          elif isinstance(options, str):
              argv = options.split()
          else:
              raise TypeError("arg 1 should be a list or a str.")
          self.set_to_default_values()
          self.print_func = cast(None, PRINT_STRING_FUN)
          weight_label = []
          weight = []

          i = 0
          while i < len(argv):
              flag = argv[i]
              if flag == "-q":
                  self.print_func = PRINT_STRING_FUN(print_null)
                  i += 1
                  continue
              if (flag not in self._value_options and flag != "-v"
                      and not flag.startswith("-w")):
                  raise ValueError("Wrong options")
              if i + 1 >= len(argv):
                  # Previously surfaced as a bare IndexError.
                  raise ValueError("Wrong options: %s requires a value" % flag)
              value = argv[i + 1]
              if flag in self._value_options:
                  attr, convert = self._value_options[flag]
                  setattr(self, attr, convert(value))
              elif flag == "-v":
                  self.cross_validation = 1
                  self.nr_fold = int(value)
                  if self.nr_fold < 2:
                      raise ValueError("n-fold cross validation: n must >= 2")
              else:
                  # -wLABEL WEIGHT: the class label is embedded in the flag.
                  weight_label.append(int(flag[2:]))
                  weight.append(float(value))
              i += 2

          # NOTE(review): the callback object is referenced only from
          # self.print_func; presumably the C side keeps the raw pointer after
          # this call, so self must outlive its use — confirm.
          libsvm.svm_set_print_string_function(self.print_func)
          self.nr_weight = len(weight)
          self.weight_label = (c_int * self.nr_weight)(*weight_label)
          self.weight = (c_double * self.nr_weight)(*weight)
  class svm_model(Structure):
      """ctypes structure of a trained model.

      Objects produced by ``toPyModel`` have ``__createfrom__ == 'C'`` and
      release their C memory in ``__del__``; objects created in Python are
      marked ``'python'`` and free nothing.
      """

      _names = ['param', 'nr_class', 'l', 'SV', 'sv_coef', 'rho',
                'probA', 'probB', 'sv_indices', 'label', 'nSV', 'free_sv']
      _types = [svm_parameter, c_int, c_int, POINTER(POINTER(svm_node)),
                POINTER(POINTER(c_double)), POINTER(c_double),
                POINTER(c_double), POINTER(c_double), POINTER(c_int),
                POINTER(c_int), POINTER(c_int), c_int]
      _fields_ = genFields(_names, _types)

      def __init__(self):
          self.__createfrom__ = 'python'

      def __del__(self):
          # free memory created by C to avoid memory leak.
          # pointer(self) is an LP_svm_model; against the
          # POINTER(POINTER(svm_model)) prototype ctypes passes it by reference.
          if hasattr(self, '__createfrom__') and self.__createfrom__ == 'C':
              libsvm.svm_free_and_destroy_model(pointer(self))

      def get_svm_type(self):
          """Return the svm_type code reported by the C library."""
          return libsvm.svm_get_svm_type(self)

      def get_nr_class(self):
          """Return the number of classes reported by the C library."""
          return libsvm.svm_get_nr_class(self)

      def get_svr_probability(self):
          """Return the value of svm_get_svr_probability for this model."""
          return libsvm.svm_get_svr_probability(self)

      def get_labels(self):
          """Return the class labels as a list of ``get_nr_class()`` ints."""
          nr_class = self.get_nr_class()
          labels = (c_int * nr_class)()
          libsvm.svm_get_labels(self, labels)
          return labels[:nr_class]

      def get_sv_indices(self):
          """Return the ``get_nr_sv()`` indices filled in by svm_get_sv_indices."""
          total_sv = self.get_nr_sv()
          sv_indices = (c_int * total_sv)()
          libsvm.svm_get_sv_indices(self, sv_indices)
          return sv_indices[:total_sv]

      def get_nr_sv(self):
          """Return the support-vector count reported by svm_get_nr_sv."""
          return libsvm.svm_get_nr_sv(self)

      def is_probability_model(self):
          """Return True when svm_check_probability_model reports 1."""
          return (libsvm.svm_check_probability_model(self) == 1)

      def get_sv_coef(self):
          """Return one tuple of ``nr_class - 1`` coefficients per support vector."""
          return [tuple(self.sv_coef[j][i] for j in xrange(self.nr_class - 1))
                  for i in xrange(self.l)]

      def get_SV(self):
          """Return the support vectors as a list of ``{index: value}`` dicts.

          Each row is read up to, but excluding, its index -1 terminator node.
          Previously the terminator was stored too, so every dict gained a
          spurious ``-1`` key.
          """
          result = []
          for sparse_sv in self.SV[:self.l]:
              row = dict()
              i = 0
              while sparse_sv[i].index != -1:
                  row[sparse_sv[i].index] = sparse_sv[i].value
                  i += 1
              result.append(row)
          return result
  def toPyModel(model_ptr):
      """
      toPyModel(model_ptr) -> svm_model

      Convert a ctypes POINTER(svm_model) to a Python svm_model.

      The result is ``model_ptr.contents``: it shares the C memory (no copy)
      and is marked ``__createfrom__ = 'C'``, so svm_model.__del__ frees that
      memory. Converting the same pointer twice would presumably free it twice.

      Raises:
          ValueError: if *model_ptr* is NULL.
      """
      if not model_ptr:  # NULL ctypes pointers are falsy
          raise ValueError("Null pointer")
      m = model_ptr.contents
      m.__createfrom__ = 'C'
      return m
  # C prototypes of every LIBSVM entry point used from Python, listed as
  # (symbol name, return type, argument types) and applied in this order.
  for _fname, _restype, _argtypes in (
      ('svm_train', POINTER(svm_model), [POINTER(svm_problem), POINTER(svm_parameter)]),
      ('svm_cross_validation', None, [POINTER(svm_problem), POINTER(svm_parameter), c_int, POINTER(c_double)]),
      ('svm_save_model', c_int, [c_char_p, POINTER(svm_model)]),
      ('svm_load_model', POINTER(svm_model), [c_char_p]),
      ('svm_get_svm_type', c_int, [POINTER(svm_model)]),
      ('svm_get_nr_class', c_int, [POINTER(svm_model)]),
      ('svm_get_labels', None, [POINTER(svm_model), POINTER(c_int)]),
      ('svm_get_sv_indices', None, [POINTER(svm_model), POINTER(c_int)]),
      ('svm_get_nr_sv', c_int, [POINTER(svm_model)]),
      ('svm_get_svr_probability', c_double, [POINTER(svm_model)]),
      ('svm_predict_values', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
      ('svm_predict', c_double, [POINTER(svm_model), POINTER(svm_node)]),
      ('svm_predict_probability', c_double, [POINTER(svm_model), POINTER(svm_node), POINTER(c_double)]),
      ('svm_free_model_content', None, [POINTER(svm_model)]),
      ('svm_free_and_destroy_model', None, [POINTER(POINTER(svm_model))]),
      ('svm_destroy_param', None, [POINTER(svm_parameter)]),
      ('svm_check_parameter', c_char_p, [POINTER(svm_problem), POINTER(svm_parameter)]),
      ('svm_check_probability_model', c_int, [POINTER(svm_model)]),
      ('svm_set_print_string_function', None, [PRINT_STRING_FUN]),
  ):
      fillprototype(getattr(libsvm, _fname), _restype, _argtypes)
  # Keep the loop variables out of the module namespace.
  del _fname, _restype, _argtypes

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.