You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

svm-toy.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
#include <windows.h>
#include <windowsx.h>
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <list>
#include "../../svm.h"
using namespace std;
  9. #define DEFAULT_PARAM "-t 2 -c 100"
  10. #define XLEN 500
  11. #define YLEN 500
  12. #define DrawLine(dc,x1,y1,x2,y2,c) \
  13. do { \
  14. HPEN hpen = CreatePen(PS_SOLID,0,c); \
  15. HPEN horig = SelectPen(dc,hpen); \
  16. MoveToEx(dc,x1,y1,NULL); \
  17. LineTo(dc,x2,y2); \
  18. SelectPen(dc,horig); \
  19. DeletePen(hpen); \
  20. } while(0)
  21. using namespace std;
  22. COLORREF colors[] =
  23. {
  24. RGB(0,0,0),
  25. RGB(0,120,120),
  26. RGB(120,120,0),
  27. RGB(120,0,120),
  28. RGB(0,200,200),
  29. RGB(200,200,0),
  30. RGB(200,0,200)
  31. };
  32. HWND main_window;
  33. HBITMAP buffer;
  34. HDC window_dc;
  35. HDC buffer_dc;
  36. HBRUSH brush1, brush2, brush3;
  37. HWND edit;
  38. enum {
  39. ID_BUTTON_CHANGE, ID_BUTTON_RUN, ID_BUTTON_CLEAR,
  40. ID_BUTTON_LOAD, ID_BUTTON_SAVE, ID_EDIT
  41. };
  42. struct point {
  43. double x, y;
  44. signed char value;
  45. };
  46. list<point> point_list;
  47. int current_value = 1;
  48. LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);
  49. int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance,
  50. PSTR szCmdLine, int iCmdShow)
  51. {
  52. static char szAppName[] = "SvmToy";
  53. MSG msg;
  54. WNDCLASSEX wndclass;
  55. wndclass.cbSize = sizeof(wndclass);
  56. wndclass.style = CS_HREDRAW | CS_VREDRAW;
  57. wndclass.lpfnWndProc = WndProc;
  58. wndclass.cbClsExtra = 0;
  59. wndclass.cbWndExtra = 0;
  60. wndclass.hInstance = hInstance;
  61. wndclass.hIcon = LoadIcon(NULL, IDI_APPLICATION);
  62. wndclass.hCursor = LoadCursor(NULL, IDC_ARROW);
  63. wndclass.hbrBackground = (HBRUSH) GetStockObject(BLACK_BRUSH);
  64. wndclass.lpszMenuName = NULL;
  65. wndclass.lpszClassName = szAppName;
  66. wndclass.hIconSm = LoadIcon(NULL, IDI_APPLICATION);
  67. RegisterClassEx(&wndclass);
  68. main_window = CreateWindow(szAppName, // window class name
  69. "SVM Toy", // window caption
  70. WS_OVERLAPPEDWINDOW,// window style
  71. CW_USEDEFAULT, // initial x position
  72. CW_USEDEFAULT, // initial y position
  73. XLEN, // initial x size
  74. YLEN+52, // initial y size
  75. NULL, // parent window handle
  76. NULL, // window menu handle
  77. hInstance, // program instance handle
  78. NULL); // creation parameters
  79. ShowWindow(main_window, iCmdShow);
  80. UpdateWindow(main_window);
  81. CreateWindow("button", "Change", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  82. 0, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CHANGE, hInstance, NULL);
  83. CreateWindow("button", "Run", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  84. 50, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_RUN, hInstance, NULL);
  85. CreateWindow("button", "Clear", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  86. 100, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_CLEAR, hInstance, NULL);
  87. CreateWindow("button", "Save", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  88. 150, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_SAVE, hInstance, NULL);
  89. CreateWindow("button", "Load", WS_CHILD | WS_VISIBLE | BS_PUSHBUTTON,
  90. 200, YLEN, 50, 25, main_window, (HMENU) ID_BUTTON_LOAD, hInstance, NULL);
  91. edit = CreateWindow("edit", NULL, WS_CHILD | WS_VISIBLE,
  92. 250, YLEN, 250, 25, main_window, (HMENU) ID_EDIT, hInstance, NULL);
  93. Edit_SetText(edit,DEFAULT_PARAM);
  94. brush1 = CreateSolidBrush(colors[4]);
  95. brush2 = CreateSolidBrush(colors[5]);
  96. brush3 = CreateSolidBrush(colors[6]);
  97. window_dc = GetDC(main_window);
  98. buffer = CreateCompatibleBitmap(window_dc, XLEN, YLEN);
  99. buffer_dc = CreateCompatibleDC(window_dc);
  100. SelectObject(buffer_dc, buffer);
  101. PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS);
  102. while (GetMessage(&msg, NULL, 0, 0)) {
  103. TranslateMessage(&msg);
  104. DispatchMessage(&msg);
  105. }
  106. return msg.wParam;
  107. }
  108. int getfilename( HWND hWnd , char *filename, int len, int save)
  109. {
  110. OPENFILENAME OpenFileName;
  111. memset(&OpenFileName,0,sizeof(OpenFileName));
  112. filename[0]='\0';
  113. OpenFileName.lStructSize = sizeof(OPENFILENAME);
  114. OpenFileName.hwndOwner = hWnd;
  115. OpenFileName.lpstrFile = filename;
  116. OpenFileName.nMaxFile = len;
  117. OpenFileName.Flags = 0;
  118. return save?GetSaveFileName(&OpenFileName):GetOpenFileName(&OpenFileName);
  119. }
  120. void clear_all()
  121. {
  122. point_list.clear();
  123. PatBlt(buffer_dc, 0, 0, XLEN, YLEN, BLACKNESS);
  124. InvalidateRect(main_window, 0, 0);
  125. }
  126. HBRUSH choose_brush(int v)
  127. {
  128. if(v==1) return brush1;
  129. else if(v==2) return brush2;
  130. else return brush3;
  131. }
  132. void draw_point(const point & p)
  133. {
  134. RECT rect;
  135. rect.left = int(p.x*XLEN);
  136. rect.top = int(p.y*YLEN);
  137. rect.right = int(p.x*XLEN) + 3;
  138. rect.bottom = int(p.y*YLEN) + 3;
  139. FillRect(window_dc, &rect, choose_brush(p.value));
  140. FillRect(buffer_dc, &rect, choose_brush(p.value));
  141. }
  142. void draw_all_points()
  143. {
  144. for(list<point>::iterator p = point_list.begin(); p != point_list.end(); p++)
  145. draw_point(*p);
  146. }
  147. void button_run_clicked()
  148. {
  149. // guard
  150. if(point_list.empty()) return;
  151. svm_parameter param;
  152. int i,j;
  153. // default values
  154. param.svm_type = C_SVC;
  155. param.kernel_type = RBF;
  156. param.degree = 3;
  157. param.gamma = 0;
  158. param.coef0 = 0;
  159. param.nu = 0.5;
  160. param.cache_size = 100;
  161. param.C = 1;
  162. param.eps = 1e-3;
  163. param.p = 0.1;
  164. param.shrinking = 1;
  165. param.probability = 0;
  166. param.nr_weight = 0;
  167. param.weight_label = NULL;
  168. param.weight = NULL;
  169. // parse options
  170. char str[1024];
  171. Edit_GetLine(edit, 0, str, sizeof(str));
  172. const char *p = str;
  173. while (1) {
  174. while (*p && *p != '-')
  175. p++;
  176. if (*p == '\0')
  177. break;
  178. p++;
  179. switch (*p++) {
  180. case 's':
  181. param.svm_type = atoi(p);
  182. break;
  183. case 't':
  184. param.kernel_type = atoi(p);
  185. break;
  186. case 'd':
  187. param.degree = atoi(p);
  188. break;
  189. case 'g':
  190. param.gamma = atof(p);
  191. break;
  192. case 'r':
  193. param.coef0 = atof(p);
  194. break;
  195. case 'n':
  196. param.nu = atof(p);
  197. break;
  198. case 'm':
  199. param.cache_size = atof(p);
  200. break;
  201. case 'c':
  202. param.C = atof(p);
  203. break;
  204. case 'e':
  205. param.eps = atof(p);
  206. break;
  207. case 'p':
  208. param.p = atof(p);
  209. break;
  210. case 'h':
  211. param.shrinking = atoi(p);
  212. break;
  213. case 'b':
  214. param.probability = atoi(p);
  215. break;
  216. case 'w':
  217. ++param.nr_weight;
  218. param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
  219. param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
  220. param.weight_label[param.nr_weight-1] = atoi(p);
  221. while(*p && !isspace(*p)) ++p;
  222. param.weight[param.nr_weight-1] = atof(p);
  223. break;
  224. }
  225. }
  226. // build problem
  227. svm_problem prob;
  228. prob.l = point_list.size();
  229. prob.y = new double[prob.l];
  230. if(param.kernel_type == PRECOMPUTED)
  231. {
  232. }
  233. else if(param.svm_type == EPSILON_SVR ||
  234. param.svm_type == NU_SVR)
  235. {
  236. if(param.gamma == 0) param.gamma = 1;
  237. svm_node *x_space = new svm_node[2 * prob.l];
  238. prob.x = new svm_node *[prob.l];
  239. i = 0;
  240. for (list<point>::iterator q = point_list.begin(); q != point_list.end(); q++, i++)
  241. {
  242. x_space[2 * i].index = 1;
  243. x_space[2 * i].value = q->x;
  244. x_space[2 * i + 1].index = -1;
  245. prob.x[i] = &x_space[2 * i];
  246. prob.y[i] = q->y;
  247. }
  248. // build model & classify
  249. svm_model *model = svm_train(&prob, &param);
  250. svm_node x[2];
  251. x[0].index = 1;
  252. x[1].index = -1;
  253. int *j = new int[XLEN];
  254. for (i = 0; i < XLEN; i++)
  255. {
  256. x[0].value = (double) i / XLEN;
  257. j[i] = (int)(YLEN*svm_predict(model, x));
  258. }
  259. DrawLine(buffer_dc,0,0,0,YLEN,colors[0]);
  260. DrawLine(window_dc,0,0,0,YLEN,colors[0]);
  261. int p = (int)(param.p * YLEN);
  262. for(int i=1; i < XLEN; i++)
  263. {
  264. DrawLine(buffer_dc,i,0,i,YLEN,colors[0]);
  265. DrawLine(window_dc,i,0,i,YLEN,colors[0]);
  266. DrawLine(buffer_dc,i-1,j[i-1],i,j[i],colors[5]);
  267. DrawLine(window_dc,i-1,j[i-1],i,j[i],colors[5]);
  268. if(param.svm_type == EPSILON_SVR)
  269. {
  270. DrawLine(buffer_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]);
  271. DrawLine(window_dc,i-1,j[i-1]+p,i,j[i]+p,colors[2]);
  272. DrawLine(buffer_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]);
  273. DrawLine(window_dc,i-1,j[i-1]-p,i,j[i]-p,colors[2]);
  274. }
  275. }
  276. svm_free_and_destroy_model(&model);
  277. delete[] j;
  278. delete[] x_space;
  279. delete[] prob.x;
  280. delete[] prob.y;
  281. }
  282. else
  283. {
  284. if(param.gamma == 0) param.gamma = 0.5;
  285. svm_node *x_space = new svm_node[3 * prob.l];
  286. prob.x = new svm_node *[prob.l];
  287. i = 0;
  288. for (list<point>::iterator q = point_list.begin(); q != point_list.end(); q++, i++)
  289. {
  290. x_space[3 * i].index = 1;
  291. x_space[3 * i].value = q->x;
  292. x_space[3 * i + 1].index = 2;
  293. x_space[3 * i + 1].value = q->y;
  294. x_space[3 * i + 2].index = -1;
  295. prob.x[i] = &x_space[3 * i];
  296. prob.y[i] = q->value;
  297. }
  298. // build model & classify
  299. svm_model *model = svm_train(&prob, &param);
  300. svm_node x[3];
  301. x[0].index = 1;
  302. x[1].index = 2;
  303. x[2].index = -1;
  304. for (i = 0; i < XLEN; i++)
  305. for (j = 0; j < YLEN; j++) {
  306. x[0].value = (double) i / XLEN;
  307. x[1].value = (double) j / YLEN;
  308. double d = svm_predict(model, x);
  309. if (param.svm_type == ONE_CLASS && d<0) d=2;
  310. SetPixel(window_dc, i, j, colors[(int)d]);
  311. SetPixel(buffer_dc, i, j, colors[(int)d]);
  312. }
  313. svm_free_and_destroy_model(&model);
  314. delete[] x_space;
  315. delete[] prob.x;
  316. delete[] prob.y;
  317. }
  318. free(param.weight_label);
  319. free(param.weight);
  320. draw_all_points();
  321. }
  322. LRESULT CALLBACK WndProc(HWND hwnd, UINT iMsg, WPARAM wParam, LPARAM lParam)
  323. {
  324. HDC hdc;
  325. PAINTSTRUCT ps;
  326. switch (iMsg) {
  327. case WM_LBUTTONDOWN:
  328. {
  329. int x = LOWORD(lParam);
  330. int y = HIWORD(lParam);
  331. point p = {(double)x/XLEN, (double)y/YLEN, current_value};
  332. point_list.push_back(p);
  333. draw_point(p);
  334. }
  335. return 0;
  336. case WM_PAINT:
  337. {
  338. hdc = BeginPaint(hwnd, &ps);
  339. BitBlt(hdc, 0, 0, XLEN, YLEN, buffer_dc, 0, 0, SRCCOPY);
  340. EndPaint(hwnd, &ps);
  341. }
  342. return 0;
  343. case WM_COMMAND:
  344. {
  345. int id = LOWORD(wParam);
  346. switch (id) {
  347. case ID_BUTTON_CHANGE:
  348. ++current_value;
  349. if(current_value > 3) current_value = 1;
  350. break;
  351. case ID_BUTTON_RUN:
  352. button_run_clicked();
  353. break;
  354. case ID_BUTTON_CLEAR:
  355. clear_all();
  356. break;
  357. case ID_BUTTON_SAVE:
  358. {
  359. char filename[1024];
  360. if(getfilename(hwnd,filename,1024,1))
  361. {
  362. FILE *fp = fopen(filename,"w");
  363. char str[1024];
  364. Edit_GetLine(edit, 0, str, sizeof(str));
  365. const char *p = str;
  366. const char* svm_type_str = strstr(p, "-s ");
  367. int svm_type = C_SVC;
  368. if(svm_type_str != NULL)
  369. sscanf(svm_type_str, "-s %d", &svm_type);
  370. if(fp)
  371. {
  372. if(svm_type == EPSILON_SVR || svm_type == NU_SVR)
  373. {
  374. for(list<point>::iterator p = point_list.begin(); p != point_list.end();p++)
  375. fprintf(fp,"%f 1:%f\n", p->y, p->x);
  376. }
  377. else
  378. {
  379. for(list<point>::iterator p = point_list.begin(); p != point_list.end();p++)
  380. fprintf(fp,"%d 1:%f 2:%f\n", p->value, p->x, p->y);
  381. }
  382. fclose(fp);
  383. }
  384. }
  385. }
  386. break;
  387. case ID_BUTTON_LOAD:
  388. {
  389. char filename[1024];
  390. if(getfilename(hwnd,filename,1024,0))
  391. {
  392. FILE *fp = fopen(filename,"r");
  393. if(fp)
  394. {
  395. clear_all();
  396. char buf[4096];
  397. while(fgets(buf,sizeof(buf),fp))
  398. {
  399. int v;
  400. double x,y;
  401. if(sscanf(buf,"%d%*d:%lf%*d:%lf",&v,&x,&y)==3)
  402. {
  403. point p = {x,y,v};
  404. point_list.push_back(p);
  405. }
  406. else if(sscanf(buf,"%lf%*d:%lf",&y,&x)==2)
  407. {
  408. point p = {x,y,current_value};
  409. point_list.push_back(p);
  410. }
  411. else
  412. break;
  413. }
  414. fclose(fp);
  415. draw_all_points();
  416. }
  417. }
  418. }
  419. break;
  420. }
  421. }
  422. return 0;
  423. case WM_DESTROY:
  424. PostQuitMessage(0);
  425. return 0;
  426. }
  427. return DefWindowProc(hwnd, iMsg, wParam, lParam);
  428. }

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.