You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

segmentation.py 9.5 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #coding=utf-8
  2. import cv2
  3. import numpy as np
  4. # from matplotlib import pyplot as plt
  5. import scipy.ndimage.filters as f
  6. import scipy
  7. import time
  8. import scipy.signal as l
  9. from keras.models import Sequential
  10. from keras.layers import Dense, Dropout, Activation, Flatten
  11. from keras.layers import Conv2D, MaxPool2D
  12. from keras.optimizers import SGD
  13. from keras import backend as K
  14. K.image_data_format()
  15. def Getmodel_tensorflow(nb_classes):
  16. # nb_classes = len(charset)
  17. img_rows, img_cols = 23, 23
  18. # number of convolutional filters to use
  19. nb_filters = 16
  20. # size of pooling area for max pooling
  21. nb_pool = 2
  22. # convolution kernel size
  23. nb_conv = 3
  24. # x = np.load('x.npy')
  25. # y = np_utils.to_categorical(range(3062)*45*5*2, nb_classes)
  26. # weight = ((type_class - np.arange(type_class)) / type_class + 1) ** 3
  27. # weight = dict(zip(range(3063), weight / weight.mean())) # 调整权重,高频字优先
  28. model = Sequential()
  29. model.add(Conv2D(nb_filters, (nb_conv, nb_conv),input_shape=(img_rows, img_cols,1)))
  30. model.add(Activation('relu'))
  31. model.add(MaxPool2D(pool_size=(nb_pool, nb_pool)))
  32. model.add(Conv2D(nb_filters, (nb_conv, nb_conv)))
  33. model.add(Activation('relu'))
  34. model.add(MaxPool2D(pool_size=(nb_pool, nb_pool)))
  35. model.add(Flatten())
  36. model.add(Dense(256))
  37. model.add(Dropout(0.5))
  38. model.add(Activation('relu'))
  39. model.add(Dense(nb_classes))
  40. model.add(Activation('softmax'))
  41. model.compile(loss='categorical_crossentropy',
  42. optimizer='sgd',
  43. metrics=['accuracy'])
  44. return model
  45. def Getmodel_tensorflow_light(nb_classes):
  46. # nb_classes = len(charset)
  47. img_rows, img_cols = 23, 23
  48. # number of convolutional filters to use
  49. nb_filters = 8
  50. # size of pooling area for max pooling
  51. nb_pool = 2
  52. # convolution kernel size
  53. nb_conv = 3
  54. # x = np.load('x.npy')
  55. # y = np_utils.to_categorical(range(3062)*45*5*2, nb_classes)
  56. # weight = ((type_class - np.arange(type_class)) / type_class + 1) ** 3
  57. # weight = dict(zip(range(3063), weight / weight.mean())) # 调整权重,高频字优先
  58. model = Sequential()
  59. model.add(Conv2D(nb_filters, (nb_conv, nb_conv),input_shape=(img_rows, img_cols, 1)))
  60. model.add(Activation('relu'))
  61. model.add(MaxPool2D(pool_size=(nb_pool, nb_pool)))
  62. model.add(Conv2D(nb_filters, (nb_conv * 2, nb_conv * 2)))
  63. model.add(Activation('relu'))
  64. model.add(MaxPool2D(pool_size=(nb_pool, nb_pool)))
  65. model.add(Flatten())
  66. model.add(Dense(32))
  67. # model.add(Dropout(0.25))
  68. model.add(Activation('relu'))
  69. model.add(Dense(nb_classes))
  70. model.add(Activation('softmax'))
  71. model.compile(loss='categorical_crossentropy',
  72. optimizer='adam',
  73. metrics=['accuracy'])
  74. return model
  75. model = Getmodel_tensorflow_light(3)
  76. model2 = Getmodel_tensorflow(3)
  77. import os
  78. model.load_weights("./model/char_judgement1.h5")
  79. # model.save("./model/char_judgement1.h5")
  80. model2.load_weights("./model/char_judgement.h5")
  81. # model2.save("./model/char_judgement.h5")
  82. model = model2
  83. def get_median(data):
  84. data = sorted(data)
  85. size = len(data)
  86. # print size
  87. if size % 2 == 0: # 判断列表长度为偶数
  88. median = (data[size//2]+data[size//2-1])/2
  89. data[0] = median
  90. if size % 2 == 1: # 判断列表长度为奇数
  91. median = data[(size-1)//2]
  92. data[0] = median
  93. return data[0]
  94. import time
  95. def searchOptimalCuttingPoint(rgb,res_map,start,width_boundingbox,interval_range):
  96. t0 = time.time()
  97. #
  98. # for x in xrange(10):
  99. # res_map = np.vstack((res_map,res_map[-1]))
  100. length = res_map.shape[0]
  101. refine_s = -2;
  102. if width_boundingbox>20:
  103. refine_s = -9
  104. score_list = []
  105. interval_big = int(width_boundingbox * 0.3) #
  106. p = 0
  107. for zero_add in xrange(start,start+50,3):
  108. # for interval_small in xrange(-0,width_boundingbox/2):
  109. for i in xrange(-8,int(width_boundingbox/1)-8):
  110. for refine in xrange(refine_s,width_boundingbox/2+3):
  111. p1 = zero_add# this point is province
  112. p2 = p1 + width_boundingbox +refine #
  113. p3 = p2 + width_boundingbox + interval_big+i+1
  114. p4 = p3 + width_boundingbox +refine
  115. p5 = p4 + width_boundingbox +refine
  116. p6 = p5 + width_boundingbox +refine
  117. p7 = p6 + width_boundingbox +refine
  118. if p7>=length:
  119. continue
  120. score = res_map[p1][2]*3 -(res_map[p3][1]+res_map[p4][1]+res_map[p5][1]+res_map[p6][1]+res_map[p7][1])+7
  121. # print score
  122. score_list.append([score,[p1,p2,p3,p4,p5,p6,p7]])
  123. p+=1
  124. print p
  125. score_list = sorted(score_list , key=lambda x:x[0])
  126. # for one in score_list[-1][1]:
  127. # cv2.line(debug,(one,0),(one,36),(255,0,0),1)
  128. # #
  129. # cv2.imshow("one",debug)
  130. # cv2.waitKey(0)
  131. #
  132. print "寻找最佳点",time.time()-t0
  133. return score_list[-1]
  134. import sys
  135. sys.path.append('../')
  136. import recognizer as cRP
  137. import niblack_thresholding as nt
  138. def refineCrop(sections,width=16):
  139. new_sections = []
  140. for section in sections:
  141. # cv2.imshow("section¡",section)
  142. # cv2.blur(section,(3,3),3)
  143. sec_center = np.array([section.shape[1]/2,section.shape[0]/2])
  144. binary_niblack = nt.niBlackThreshold(section,17,-0.255)
  145. imagex, contours, hierarchy = cv2.findContours(binary_niblack,cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
  146. boxs = []
  147. for contour in contours:
  148. x,y,w,h = cv2.boundingRect(contour)
  149. ratio = w/float(h)
  150. if ratio<1 and h>36*0.4 and y<16\
  151. :
  152. box = [x,y,w,h]
  153. boxs.append([box,np.array([x+w/2,y+h/2])])
  154. # cv2.rectangle(section,(x,y),(x+w,y+h),255,1)
  155. # print boxs
  156. dis_ = np.array([ ((one[1]-sec_center)**2).sum() for one in boxs])
  157. if len(dis_)==0:
  158. kernal = [0, 0, section.shape[1], section.shape[0]]
  159. else:
  160. kernal = boxs[dis_.argmin()][0]
  161. center_c = (kernal[0]+kernal[2]/2,kernal[1]+kernal[3]/2)
  162. w_2 = int(width/2)
  163. h_2 = kernal[3]/2
  164. if center_c[0] - w_2< 0:
  165. w_2 = center_c[0]
  166. new_box = [center_c[0] - w_2,kernal[1],width,kernal[3]]
  167. # print new_box[2]/float(new_box[3])
  168. if new_box[2]/float(new_box[3])>0.5:
  169. # print "异常"
  170. h = int((new_box[2]/0.35 )/2)
  171. if h>35:
  172. h = 35
  173. new_box[1] = center_c[1]- h
  174. if new_box[1]<0:
  175. new_box[1] = 1
  176. new_box[3] = h*2
  177. section = section[new_box[1]:new_box[1]+new_box[3],new_box[0]:new_box[0]+new_box[2]]
  178. # cv2.imshow("section",section)
  179. # cv2.waitKey(0)
  180. new_sections.append(section)
  181. # print new_box
  182. return new_sections
  183. def slidingWindowsEval(image):
  184. windows_size = 16;
  185. stride = 1
  186. height= image.shape[0]
  187. t0 = time.time()
  188. data_sets = []
  189. for i in range(0,image.shape[1]-windows_size+1,stride):
  190. data = image[0:height,i:i+windows_size]
  191. data = cv2.resize(data,(23,23))
  192. # cv2.imshow("image",data)
  193. data = cv2.equalizeHist(data)
  194. data = data.astype(np.float)/255
  195. data= np.expand_dims(data,3)
  196. data_sets.append(data)
  197. res = model2.predict(np.array(data_sets))
  198. print "分割",time.time() - t0
  199. pin = res
  200. p = 1 - (res.T)[1]
  201. p = f.gaussian_filter1d(np.array(p,dtype=np.float),3)
  202. lmin = l.argrelmax(np.array(p),order = 3)[0]
  203. interval = []
  204. for i in xrange(len(lmin)-1):
  205. interval.append(lmin[i+1]-lmin[i])
  206. if(len(interval)>3):
  207. mid = get_median(interval)
  208. else:
  209. return []
  210. pin = np.array(pin)
  211. res = searchOptimalCuttingPoint(image,pin,0,mid,3)
  212. cutting_pts = res[1]
  213. last = cutting_pts[-1] + mid
  214. if last < image.shape[1]:
  215. cutting_pts.append(last)
  216. else:
  217. cutting_pts.append(image.shape[1]-1)
  218. name = ""
  219. confidence =0.00
  220. seg_block = []
  221. for x in xrange(1,len(cutting_pts)):
  222. if x != len(cutting_pts)-1 and x!=1:
  223. section = image[0:36,cutting_pts[x-1]-2:cutting_pts[x]+2]
  224. elif x==1:
  225. c_head = cutting_pts[x - 1]- 2
  226. if c_head<0:
  227. c_head=0
  228. c_tail = cutting_pts[x] + 2
  229. section = image[0:36, c_head:c_tail]
  230. elif x==len(cutting_pts)-1:
  231. end = cutting_pts[x]
  232. diff = image.shape[1]-end
  233. c_head = cutting_pts[x - 1]
  234. c_tail = cutting_pts[x]
  235. if diff<7 :
  236. section = image[0:36, c_head-5:c_tail+5]
  237. else:
  238. diff-=1
  239. section = image[0:36, c_head - diff:c_tail + diff]
  240. elif x==2:
  241. section = image[0:36, cutting_pts[x - 1] - 3:cutting_pts[x-1]+ mid]
  242. else:
  243. section = image[0:36,cutting_pts[x-1]:cutting_pts[x]]
  244. seg_block.append(section)
  245. refined = refineCrop(seg_block,mid-1)
  246. t0 = time.time()
  247. for i,one in enumerate(refined):
  248. res_pre = cRP.SimplePredict(one, i )
  249. # cv2.imshow(str(i),one)
  250. # cv2.waitKey(0)
  251. confidence+=res_pre[0]
  252. name+= res_pre[1]
  253. print "字符识别",time.time() - t0
  254. return refined,name,confidence