You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

svm_model_matlab.c 8.2 kB

(Line-number gutter for lines 1–374 of the original page view.)
  1. #include <stdlib.h>
  2. #include <string.h>
  3. #include "svm.h"
  4. #include "mex.h"
  /* Older MEX APIs (MX_API_VER below 0x07030000) predate mwIndex; alias it to int there. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/* Number of fields in the MATLAB model struct; must equal the length of field_names. */
#define NUM_OF_RETURN_FIELD 11

/* Typed allocation shorthand: Malloc(T, n) allocates n objects of type T.
 * The result is NOT checked for NULL here; callers must do that. */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))

/* Field names of the MATLAB model struct. The order matters: the model is
 * written field by field in this order and read back by field position. */
static const char *field_names[] = {
	"Parameters",
	"nr_class",
	"totalSV",
	"rho",
	"Label",
	"sv_indices",
	"ProbA",
	"ProbB",
	"nSV",
	"sv_coef",
	"SVs"
};

/* Compile-time guard (C89-compatible): the array size becomes negative, and the
 * build fails, if NUM_OF_RETURN_FIELD and field_names ever disagree. */
typedef char field_names_count_check[
	(sizeof(field_names) / sizeof(field_names[0]) == NUM_OF_RETURN_FIELD) ? 1 : -1];
  25. const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
  26. {
  27. int i, j, n;
  28. double *ptr;
  29. mxArray *return_model, **rhs;
  30. int out_id = 0;
  31. rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
  32. // Parameters
  33. rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
  34. ptr = mxGetPr(rhs[out_id]);
  35. ptr[0] = model->param.svm_type;
  36. ptr[1] = model->param.kernel_type;
  37. ptr[2] = model->param.degree;
  38. ptr[3] = model->param.gamma;
  39. ptr[4] = model->param.coef0;
  40. out_id++;
  41. // nr_class
  42. rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
  43. ptr = mxGetPr(rhs[out_id]);
  44. ptr[0] = model->nr_class;
  45. out_id++;
  46. // total SV
  47. rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
  48. ptr = mxGetPr(rhs[out_id]);
  49. ptr[0] = model->l;
  50. out_id++;
  51. // rho
  52. n = model->nr_class*(model->nr_class-1)/2;
  53. rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
  54. ptr = mxGetPr(rhs[out_id]);
  55. for(i = 0; i < n; i++)
  56. ptr[i] = model->rho[i];
  57. out_id++;
  58. // Label
  59. if(model->label)
  60. {
  61. rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
  62. ptr = mxGetPr(rhs[out_id]);
  63. for(i = 0; i < model->nr_class; i++)
  64. ptr[i] = model->label[i];
  65. }
  66. else
  67. rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
  68. out_id++;
  69. // sv_indices
  70. if(model->sv_indices)
  71. {
  72. rhs[out_id] = mxCreateDoubleMatrix(model->l, 1, mxREAL);
  73. ptr = mxGetPr(rhs[out_id]);
  74. for(i = 0; i < model->l; i++)
  75. ptr[i] = model->sv_indices[i];
  76. }
  77. else
  78. rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
  79. out_id++;
  80. // probA
  81. if(model->probA != NULL)
  82. {
  83. rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
  84. ptr = mxGetPr(rhs[out_id]);
  85. for(i = 0; i < n; i++)
  86. ptr[i] = model->probA[i];
  87. }
  88. else
  89. rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
  90. out_id ++;
  91. // probB
  92. if(model->probB != NULL)
  93. {
  94. rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
  95. ptr = mxGetPr(rhs[out_id]);
  96. for(i = 0; i < n; i++)
  97. ptr[i] = model->probB[i];
  98. }
  99. else
  100. rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
  101. out_id++;
  102. // nSV
  103. if(model->nSV)
  104. {
  105. rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
  106. ptr = mxGetPr(rhs[out_id]);
  107. for(i = 0; i < model->nr_class; i++)
  108. ptr[i] = model->nSV[i];
  109. }
  110. else
  111. rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
  112. out_id++;
  113. // sv_coef
  114. rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
  115. ptr = mxGetPr(rhs[out_id]);
  116. for(i = 0; i < model->nr_class-1; i++)
  117. for(j = 0; j < model->l; j++)
  118. ptr[(i*(model->l))+j] = model->sv_coef[i][j];
  119. out_id++;
  120. // SVs
  121. {
  122. int ir_index, nonzero_element;
  123. mwIndex *ir, *jc;
  124. mxArray *pprhs[1], *pplhs[1];
  125. if(model->param.kernel_type == PRECOMPUTED)
  126. {
  127. nonzero_element = model->l;
  128. num_of_feature = 1;
  129. }
  130. else
  131. {
  132. nonzero_element = 0;
  133. for(i = 0; i < model->l; i++) {
  134. j = 0;
  135. while(model->SV[i][j].index != -1)
  136. {
  137. nonzero_element++;
  138. j++;
  139. }
  140. }
  141. }
  142. // SV in column, easier accessing
  143. rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
  144. ir = mxGetIr(rhs[out_id]);
  145. jc = mxGetJc(rhs[out_id]);
  146. ptr = mxGetPr(rhs[out_id]);
  147. jc[0] = ir_index = 0;
  148. for(i = 0;i < model->l; i++)
  149. {
  150. if(model->param.kernel_type == PRECOMPUTED)
  151. {
  152. // make a (1 x model->l) matrix
  153. ir[ir_index] = 0;
  154. ptr[ir_index] = model->SV[i][0].value;
  155. ir_index++;
  156. jc[i+1] = jc[i] + 1;
  157. }
  158. else
  159. {
  160. int x_index = 0;
  161. while (model->SV[i][x_index].index != -1)
  162. {
  163. ir[ir_index] = model->SV[i][x_index].index - 1;
  164. ptr[ir_index] = model->SV[i][x_index].value;
  165. ir_index++, x_index++;
  166. }
  167. jc[i+1] = jc[i] + x_index;
  168. }
  169. }
  170. // transpose back to SV in row
  171. pprhs[0] = rhs[out_id];
  172. if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
  173. return "cannot transpose SV matrix";
  174. rhs[out_id] = pplhs[0];
  175. out_id++;
  176. }
  177. /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
  178. return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
  179. /* Fill struct matrix with input arguments */
  180. for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
  181. mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
  182. /* return */
  183. plhs[0] = return_model;
  184. mxFree(rhs);
  185. return NULL;
  186. }
  187. struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
  188. {
  189. int i, j, n, num_of_fields;
  190. double *ptr;
  191. int id = 0;
  192. struct svm_node *x_space;
  193. struct svm_model *model;
  194. mxArray **rhs;
  195. num_of_fields = mxGetNumberOfFields(matlab_struct);
  196. if(num_of_fields != NUM_OF_RETURN_FIELD)
  197. {
  198. *msg = "number of return field is not correct";
  199. return NULL;
  200. }
  201. rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
  202. for(i=0;i<num_of_fields;i++)
  203. rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
  204. model = Malloc(struct svm_model, 1);
  205. model->rho = NULL;
  206. model->probA = NULL;
  207. model->probB = NULL;
  208. model->label = NULL;
  209. model->sv_indices = NULL;
  210. model->nSV = NULL;
  211. model->free_sv = 1; // XXX
  212. ptr = mxGetPr(rhs[id]);
  213. model->param.svm_type = (int)ptr[0];
  214. model->param.kernel_type = (int)ptr[1];
  215. model->param.degree = (int)ptr[2];
  216. model->param.gamma = ptr[3];
  217. model->param.coef0 = ptr[4];
  218. id++;
  219. ptr = mxGetPr(rhs[id]);
  220. model->nr_class = (int)ptr[0];
  221. id++;
  222. ptr = mxGetPr(rhs[id]);
  223. model->l = (int)ptr[0];
  224. id++;
  225. // rho
  226. n = model->nr_class * (model->nr_class-1)/2;
  227. model->rho = (double*) malloc(n*sizeof(double));
  228. ptr = mxGetPr(rhs[id]);
  229. for(i=0;i<n;i++)
  230. model->rho[i] = ptr[i];
  231. id++;
  232. // label
  233. if(mxIsEmpty(rhs[id]) == 0)
  234. {
  235. model->label = (int*) malloc(model->nr_class*sizeof(int));
  236. ptr = mxGetPr(rhs[id]);
  237. for(i=0;i<model->nr_class;i++)
  238. model->label[i] = (int)ptr[i];
  239. }
  240. id++;
  241. // sv_indices
  242. if(mxIsEmpty(rhs[id]) == 0)
  243. {
  244. model->sv_indices = (int*) malloc(model->l*sizeof(int));
  245. ptr = mxGetPr(rhs[id]);
  246. for(i=0;i<model->l;i++)
  247. model->sv_indices[i] = (int)ptr[i];
  248. }
  249. id++;
  250. // probA
  251. if(mxIsEmpty(rhs[id]) == 0)
  252. {
  253. model->probA = (double*) malloc(n*sizeof(double));
  254. ptr = mxGetPr(rhs[id]);
  255. for(i=0;i<n;i++)
  256. model->probA[i] = ptr[i];
  257. }
  258. id++;
  259. // probB
  260. if(mxIsEmpty(rhs[id]) == 0)
  261. {
  262. model->probB = (double*) malloc(n*sizeof(double));
  263. ptr = mxGetPr(rhs[id]);
  264. for(i=0;i<n;i++)
  265. model->probB[i] = ptr[i];
  266. }
  267. id++;
  268. // nSV
  269. if(mxIsEmpty(rhs[id]) == 0)
  270. {
  271. model->nSV = (int*) malloc(model->nr_class*sizeof(int));
  272. ptr = mxGetPr(rhs[id]);
  273. for(i=0;i<model->nr_class;i++)
  274. model->nSV[i] = (int)ptr[i];
  275. }
  276. id++;
  277. // sv_coef
  278. ptr = mxGetPr(rhs[id]);
  279. model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
  280. for( i=0 ; i< model->nr_class -1 ; i++ )
  281. model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
  282. for(i = 0; i < model->nr_class - 1; i++)
  283. for(j = 0; j < model->l; j++)
  284. model->sv_coef[i][j] = ptr[i*(model->l)+j];
  285. id++;
  286. // SV
  287. {
  288. int sr, elements;
  289. int num_samples;
  290. mwIndex *ir, *jc;
  291. mxArray *pprhs[1], *pplhs[1];
  292. // transpose SV
  293. pprhs[0] = rhs[id];
  294. if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
  295. {
  296. svm_free_and_destroy_model(&model);
  297. *msg = "cannot transpose SV matrix";
  298. return NULL;
  299. }
  300. rhs[id] = pplhs[0];
  301. sr = (int)mxGetN(rhs[id]);
  302. ptr = mxGetPr(rhs[id]);
  303. ir = mxGetIr(rhs[id]);
  304. jc = mxGetJc(rhs[id]);
  305. num_samples = (int)mxGetNzmax(rhs[id]);
  306. elements = num_samples + sr;
  307. model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
  308. x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
  309. // SV is in column
  310. for(i=0;i<sr;i++)
  311. {
  312. int low = (int)jc[i], high = (int)jc[i+1];
  313. int x_index = 0;
  314. model->SV[i] = &x_space[low+i];
  315. for(j=low;j<high;j++)
  316. {
  317. model->SV[i][x_index].index = (int)ir[j] + 1;
  318. model->SV[i][x_index].value = ptr[j];
  319. x_index++;
  320. }
  321. model->SV[i][x_index].index = -1;
  322. }
  323. id++;
  324. }
  325. mxFree(rhs);
  326. return model;
  327. }

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.