# train_digits.py (3.8 KB)
  1. #!/usr/bin/env python
  2. import cv2
  3. import numpy as np
  4. SZ = 20
  5. CLASS_N = 10
  6. # local modules
  7. from common import clock, mosaic
  8. def split2d(img, cell_size, flatten=True):
  9. h, w = img.shape[:2]
  10. sx, sy = cell_size
  11. cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
  12. cells = np.array(cells)
  13. if flatten:
  14. cells = cells.reshape(-1, sy, sx)
  15. return cells
  16. def load_digits(fn):
  17. digits_img = cv2.imread(fn, 0)
  18. digits = split2d(digits_img, (SZ, SZ))
  19. labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
  20. return digits, labels
  21. def deskew(img):
  22. m = cv2.moments(img)
  23. if abs(m['mu02']) < 1e-2:
  24. return img.copy()
  25. skew = m['mu11']/m['mu02']
  26. M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
  27. img = cv2.warpAffine(img, M, (SZ, SZ), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)
  28. return img
  29. def svmInit(C=12.5, gamma=0.50625):
  30. model = cv2.ml.SVM_create()
  31. model.setGamma(gamma)
  32. model.setC(C)
  33. model.setKernel(cv2.ml.SVM_RBF)
  34. model.setType(cv2.ml.SVM_C_SVC)
  35. return model
  36. def svmTrain(model, samples, responses):
  37. model.train(samples, cv2.ml.ROW_SAMPLE, responses)
  38. return model
  39. def svmPredict(model, samples):
  40. return model.predict(samples)[1].ravel()
  41. def svmEvaluate(model, digits, samples, labels):
  42. predictions = svmPredict(model, samples)
  43. accuracy = (labels == predictions).mean()
  44. print('Percentage Accuracy: %.2f %%' % (accuracy*100))
  45. confusion = np.zeros((10, 10), np.int32)
  46. for i, j in zip(labels, predictions):
  47. confusion[int(i), int(j)] += 1
  48. print('confusion matrix:')
  49. print(confusion)
  50. vis = []
  51. for img, flag in zip(digits, predictions == labels):
  52. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  53. if not flag:
  54. img[...,:2] = 0
  55. vis.append(img)
  56. return mosaic(25, vis)
  57. def preprocess_simple(digits):
  58. return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
  59. def get_hog() :
  60. winSize = (20,20)
  61. blockSize = (8,8)
  62. blockStride = (4,4)
  63. cellSize = (8,8)
  64. nbins = 9
  65. derivAperture = 1
  66. winSigma = -1.
  67. histogramNormType = 0
  68. L2HysThreshold = 0.2
  69. gammaCorrection = 1
  70. nlevels = 64
  71. signedGradient = True
  72. hog = cv2.HOGDescriptor(winSize,blockSize,blockStride,cellSize,nbins,derivAperture,winSigma,histogramNormType,L2HysThreshold,gammaCorrection,nlevels, signedGradient)
  73. return hog
  74. affine_flags = cv2.WARP_INVERSE_MAP|cv2.INTER_LINEAR
  75. if __name__ == '__main__':
  76. print('Loading digits from digits.png ... ')
  77. # Load data.
  78. digits, labels = load_digits('digits.png')
  79. print('Shuffle data ... ')
  80. # Shuffle data
  81. rand = np.random.RandomState(10)
  82. shuffle = rand.permutation(len(digits))
  83. digits, labels = digits[shuffle], labels[shuffle]
  84. print('Deskew images ... ')
  85. digits_deskewed = list(map(deskew, digits))
  86. print('Defining HoG parameters ...')
  87. # HoG feature descriptor
  88. hog = get_hog();
  89. print('Calculating HoG descriptor for every image ... ')
  90. hog_descriptors = []
  91. for img in digits_deskewed:
  92. hog_descriptors.append(hog.compute(img))
  93. hog_descriptors = np.squeeze(hog_descriptors)
  94. print('Spliting data into training (90%) and test set (10%)... ')
  95. train_n=int(0.9*len(hog_descriptors))
  96. digits_train, digits_test = np.split(digits_deskewed, [train_n])
  97. hog_descriptors_train, hog_descriptors_test = np.split(hog_descriptors, [train_n])
  98. labels_train, labels_test = np.split(labels, [train_n])
  99. print('Training SVM model ...')
  100. model = svmInit()
  101. svmTrain(model, hog_descriptors_train, labels_train)
  102. print('Evaluating model ... ')
  103. vis = svmEvaluate(model, digits_test, hog_descriptors_test, labels_test)
  104. cv2.imwrite("digits-classification.jpg",vis)
  105. cv2.imshow("Vis", vis)
  106. cv2.waitKey(0)
  107. cv2.destroyAllWindows()