You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

svmpredict.c 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include "svm.h"
  5. #include "mex.h"
  6. #include "svm_model_matlab.h"
  /*
   * Compatibility shim: when building against a MATLAB API version older than
   * 0x07030000, define mwIndex as int (presumably those releases do not
   * provide the mwIndex type used by read_sparse_instance -- confirm).
   */
  #if defined(MX_API_VER) && MX_API_VER < 0x07030000
  typedef int mwIndex;
  #endif

  /* Size of the char buffer mexFunction uses to hold the libsvm_options string. */
  #define CMD_LEN 2048
  /*
   * Discarding printer installed as `info` by the -q (quiet) option.
   * It has the same printf-like signature as `info`; all arguments are ignored.
   * Returns 0. A non-void function that falls off its end is undefined
   * behavior if the caller uses the result.
   */
  int print_null(const char *s,...)
  {
      (void)s;
      return 0;
  }
  14. int (*info)(const char *fmt,...) = &mexPrintf;
  15. void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
  16. {
  17. int i, j, low, high;
  18. mwIndex *ir, *jc;
  19. double *samples;
  20. ir = mxGetIr(prhs);
  21. jc = mxGetJc(prhs);
  22. samples = mxGetPr(prhs);
  23. // each column is one instance
  24. j = 0;
  25. low = (int)jc[index], high = (int)jc[index+1];
  26. for(i=low;i<high;i++)
  27. {
  28. x[j].index = (int)ir[i] + 1;
  29. x[j].value = samples[i];
  30. j++;
  31. }
  32. x[j].index = -1;
  33. }
  34. static void fake_answer(int nlhs, mxArray *plhs[])
  35. {
  36. int i;
  37. for(i=0;i<nlhs;i++)
  38. plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
  39. }
  40. void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
  41. {
  42. int label_vector_row_num, label_vector_col_num;
  43. int feature_number, testing_instance_number;
  44. int instance_index;
  45. double *ptr_instance, *ptr_label, *ptr_predict_label;
  46. double *ptr_prob_estimates, *ptr_dec_values, *ptr;
  47. struct svm_node *x;
  48. mxArray *pplhs[1]; // transposed instance sparse matrix
  49. mxArray *tplhs[3]; // temporary storage for plhs[]
  50. int correct = 0;
  51. int total = 0;
  52. double error = 0;
  53. double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
  54. int svm_type=svm_get_svm_type(model);
  55. int nr_class=svm_get_nr_class(model);
  56. double *prob_estimates=NULL;
  57. // prhs[1] = testing instance matrix
  58. feature_number = (int)mxGetN(prhs[1]);
  59. testing_instance_number = (int)mxGetM(prhs[1]);
  60. label_vector_row_num = (int)mxGetM(prhs[0]);
  61. label_vector_col_num = (int)mxGetN(prhs[0]);
  62. if(label_vector_row_num!=testing_instance_number)
  63. {
  64. mexPrintf("Length of label vector does not match # of instances.\n");
  65. fake_answer(nlhs, plhs);
  66. return;
  67. }
  68. if(label_vector_col_num!=1)
  69. {
  70. mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
  71. fake_answer(nlhs, plhs);
  72. return;
  73. }
  74. ptr_instance = mxGetPr(prhs[1]);
  75. ptr_label = mxGetPr(prhs[0]);
  76. // transpose instance matrix
  77. if(mxIsSparse(prhs[1]))
  78. {
  79. if(model->param.kernel_type == PRECOMPUTED)
  80. {
  81. // precomputed kernel requires dense matrix, so we make one
  82. mxArray *rhs[1], *lhs[1];
  83. rhs[0] = mxDuplicateArray(prhs[1]);
  84. if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
  85. {
  86. mexPrintf("Error: cannot full testing instance matrix\n");
  87. fake_answer(nlhs, plhs);
  88. return;
  89. }
  90. ptr_instance = mxGetPr(lhs[0]);
  91. mxDestroyArray(rhs[0]);
  92. }
  93. else
  94. {
  95. mxArray *pprhs[1];
  96. pprhs[0] = mxDuplicateArray(prhs[1]);
  97. if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
  98. {
  99. mexPrintf("Error: cannot transpose testing instance matrix\n");
  100. fake_answer(nlhs, plhs);
  101. return;
  102. }
  103. }
  104. }
  105. if(predict_probability)
  106. {
  107. if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
  108. info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
  109. else
  110. prob_estimates = (double *) malloc(nr_class*sizeof(double));
  111. }
  112. tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
  113. if(predict_probability)
  114. {
  115. // prob estimates are in plhs[2]
  116. if(svm_type==C_SVC || svm_type==NU_SVC)
  117. tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
  118. else
  119. tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
  120. }
  121. else
  122. {
  123. // decision values are in plhs[2]
  124. if(svm_type == ONE_CLASS ||
  125. svm_type == EPSILON_SVR ||
  126. svm_type == NU_SVR ||
  127. nr_class == 1) // if only one class in training data, decision values are still returned.
  128. tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
  129. else
  130. tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
  131. }
  132. ptr_predict_label = mxGetPr(tplhs[0]);
  133. ptr_prob_estimates = mxGetPr(tplhs[2]);
  134. ptr_dec_values = mxGetPr(tplhs[2]);
  135. x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
  136. for(instance_index=0;instance_index<testing_instance_number;instance_index++)
  137. {
  138. int i;
  139. double target_label, predict_label;
  140. target_label = ptr_label[instance_index];
  141. if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
  142. read_sparse_instance(pplhs[0], instance_index, x);
  143. else
  144. {
  145. for(i=0;i<feature_number;i++)
  146. {
  147. x[i].index = i+1;
  148. x[i].value = ptr_instance[testing_instance_number*i+instance_index];
  149. }
  150. x[feature_number].index = -1;
  151. }
  152. if(predict_probability)
  153. {
  154. if(svm_type==C_SVC || svm_type==NU_SVC)
  155. {
  156. predict_label = svm_predict_probability(model, x, prob_estimates);
  157. ptr_predict_label[instance_index] = predict_label;
  158. for(i=0;i<nr_class;i++)
  159. ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
  160. } else {
  161. predict_label = svm_predict(model,x);
  162. ptr_predict_label[instance_index] = predict_label;
  163. }
  164. }
  165. else
  166. {
  167. if(svm_type == ONE_CLASS ||
  168. svm_type == EPSILON_SVR ||
  169. svm_type == NU_SVR)
  170. {
  171. double res;
  172. predict_label = svm_predict_values(model, x, &res);
  173. ptr_dec_values[instance_index] = res;
  174. }
  175. else
  176. {
  177. double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
  178. predict_label = svm_predict_values(model, x, dec_values);
  179. if(nr_class == 1)
  180. ptr_dec_values[instance_index] = 1;
  181. else
  182. for(i=0;i<(nr_class*(nr_class-1))/2;i++)
  183. ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
  184. free(dec_values);
  185. }
  186. ptr_predict_label[instance_index] = predict_label;
  187. }
  188. if(predict_label == target_label)
  189. ++correct;
  190. error += (predict_label-target_label)*(predict_label-target_label);
  191. sump += predict_label;
  192. sumt += target_label;
  193. sumpp += predict_label*predict_label;
  194. sumtt += target_label*target_label;
  195. sumpt += predict_label*target_label;
  196. ++total;
  197. }
  198. if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
  199. {
  200. info("Mean squared error = %g (regression)\n",error/total);
  201. info("Squared correlation coefficient = %g (regression)\n",
  202. ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
  203. ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
  204. );
  205. }
  206. else
  207. info("Accuracy = %g%% (%d/%d) (classification)\n",
  208. (double)correct/total*100,correct,total);
  209. // return accuracy, mean squared error, squared correlation coefficient
  210. tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
  211. ptr = mxGetPr(tplhs[1]);
  212. ptr[0] = (double)correct/total*100;
  213. ptr[1] = error/total;
  214. ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
  215. ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
  216. free(x);
  217. if(prob_estimates != NULL)
  218. free(prob_estimates);
  219. switch(nlhs)
  220. {
  221. case 3:
  222. plhs[2] = tplhs[2];
  223. plhs[1] = tplhs[1];
  224. case 1:
  225. case 0:
  226. plhs[0] = tplhs[0];
  227. }
  228. }
  /*
   * Print the svmpredict usage text through mexPrintf.
   * Despite its name this only prints; every caller returns on its own afterwards.
   */
  void exit_with_help()
  {
      static const char usage_text[] =
          "Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
          " [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
          "Parameters:\n"
          " model: SVM model structure from svmtrain.\n"
          " libsvm_options:\n"
          " -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
          " -q : quiet mode (no outputs)\n"
          "Returns:\n"
          " predicted_label: SVM prediction output vector.\n"
          " accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
          " prob_estimates: If selected, probability estimate vector.\n";

      mexPrintf("%s", usage_text);
  }
  245. void mexFunction( int nlhs, mxArray *plhs[],
  246. int nrhs, const mxArray *prhs[] )
  247. {
  248. int prob_estimate_flag = 0;
  249. struct svm_model *model;
  250. info = &mexPrintf;
  251. if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
  252. {
  253. exit_with_help();
  254. fake_answer(nlhs, plhs);
  255. return;
  256. }
  257. if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
  258. mexPrintf("Error: label vector and instance matrix must be double\n");
  259. fake_answer(nlhs, plhs);
  260. return;
  261. }
  262. if(mxIsStruct(prhs[2]))
  263. {
  264. const char *error_msg;
  265. // parse options
  266. if(nrhs==4)
  267. {
  268. int i, argc = 1;
  269. char cmd[CMD_LEN], *argv[CMD_LEN/2];
  270. // put options in argv[]
  271. mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
  272. if((argv[argc] = strtok(cmd, " ")) != NULL)
  273. while((argv[++argc] = strtok(NULL, " ")) != NULL)
  274. ;
  275. for(i=1;i<argc;i++)
  276. {
  277. if(argv[i][0] != '-') break;
  278. if((++i>=argc) && argv[i-1][1] != 'q')
  279. {
  280. exit_with_help();
  281. fake_answer(nlhs, plhs);
  282. return;
  283. }
  284. switch(argv[i-1][1])
  285. {
  286. case 'b':
  287. prob_estimate_flag = atoi(argv[i]);
  288. break;
  289. case 'q':
  290. i--;
  291. info = &print_null;
  292. break;
  293. default:
  294. mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
  295. exit_with_help();
  296. fake_answer(nlhs, plhs);
  297. return;
  298. }
  299. }
  300. }
  301. model = matlab_matrix_to_model(prhs[2], &error_msg);
  302. if (model == NULL)
  303. {
  304. mexPrintf("Error: can't read model: %s\n", error_msg);
  305. fake_answer(nlhs, plhs);
  306. return;
  307. }
  308. if(prob_estimate_flag)
  309. {
  310. if(svm_check_probability_model(model)==0)
  311. {
  312. mexPrintf("Model does not support probabiliy estimates\n");
  313. fake_answer(nlhs, plhs);
  314. svm_free_and_destroy_model(&model);
  315. return;
  316. }
  317. }
  318. else
  319. {
  320. if(svm_check_probability_model(model)!=0)
  321. info("Model supports probability estimates, but disabled in predicton.\n");
  322. }
  323. predict(nlhs, plhs, prhs, model, prob_estimate_flag);
  324. // destroy model
  325. svm_free_and_destroy_model(&model);
  326. }
  327. else
  328. {
  329. mexPrintf("model file should be a struct array\n");
  330. fake_answer(nlhs, plhs);
  331. }
  332. return;
  333. }

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.