kNN_手写识别系统代码实现

来源:互联网 发布:网络变压器品牌 编辑:程序博客网 时间:2024/06/08 14:19

【准备数据】将图像转换为向量

import os

import numpy as np

# Each digit image is stored as 32 text lines of 32 digit characters.
_IMG_SIDE = 32


def img2Vect(filename):
    """Convert one text-encoded digit image into a 1 x 1024 row vector.

    The first 32 characters of each of the first 32 lines of *filename*
    are parsed with ``int`` and written row by row into the result.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``numpy`` array of shape ``(1, 1024)``.

    Raises:
        IndexError: If a line has fewer than 32 characters.
        ValueError: If a character cannot be parsed by ``int``.
    """
    returnVect = np.zeros((1, _IMG_SIDE * _IMG_SIDE))
    # The context manager closes the file even if parsing fails part-way.
    with open(filename) as fr:
        for row in range(_IMG_SIDE):
            lineStr = fr.readline()
            for col in range(_IMG_SIDE):
                returnVect[0, row * _IMG_SIDE + col] = int(lineStr[col])
    return returnVect


def imgs2Mat(pathname):
    """Load every image file in a directory into a matrix plus label list.

    Each file name is expected to look like ``<label>_<index>.txt``; the
    integer before the first underscore is used as the class label.

    Args:
        pathname: Directory whose entries are all digit image files.

    Returns:
        A tuple ``(dataSet, labels)``: ``dataSet`` is an ``(m, 1024)``
        array with one image per row, and ``labels`` is a list of ``m``
        integers in the same order (the order of ``os.listdir``).
    """
    dataFileList = os.listdir(pathname)
    dataSet = np.zeros((len(dataFileList), _IMG_SIDE * _IMG_SIDE))
    labels = []
    for i, filename in enumerate(dataFileList):
        labels.append(int(filename.split('.')[0].split('_')[0]))
        # os.path.join keeps the path valid on every platform.
        dataSet[i, :] = img2Vect(os.path.join(pathname, filename))[0]
    return dataSet, labels


if __name__ == "__main__":
    # Demo from the original post: preview the start of one image vector,
    # then load the whole training set.
    img2Vect(os.path.join('trainingDigits', '0_0.txt'))[0, 0:31]
    dataSet, labels = imgs2Mat('trainingDigits')

【实施算法】k-近邻算法实现

import operator


def classify0(dataVect, dataSet, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Distances are Euclidean, computed between ``dataVect`` and every row
    of ``dataSet``.

    Args:
        dataVect: The sample to classify, a vector with as many features
            as ``dataSet`` has columns (1-D, or 2-D with a single row).
        dataSet: Training matrix with one sample per row.
        labels: Sequence of class labels, ``labels[i]`` belonging to
            ``dataSet[i]``.
        k: Number of nearest neighbours that vote. Values larger than the
            number of training rows are clamped to that number.

    Returns:
        The label with the most votes; on a tie, the tied label whose
        first vote came from the closest neighbour.

    Raises:
        IndexError: If no votes are cast (``k`` below 1 or empty
            ``dataSet``).
    """
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(dataVect, (dataSetSize, 1)) - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    sortedDistIndicies = distances.argsort()

    # Never ask for more neighbours than there are training samples;
    # otherwise indexing past the end of the sorted indices would fail.
    neighbourCount = min(k, dataSetSize)

    classCount = {}
    for i in range(neighbourCount):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # sorted() is stable, so equal vote counts keep insertion order.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedClassCount[0][0]


# Example:
# dataVect = img2Vect('testDigits\\7_18.txt')
# classify0(dataVect, dataSet, labels, 5)

【测试算法】测试错误率函数

def handwritingClassTest():
    """Measure the classifier's error rate on the held-out test digits.

    Loads the training images from ``trainingDigits`` and the test images
    from ``testDigits``, classifies each test image with ``classify0``
    using 5 neighbours, and prints the number of misclassified images
    followed by the error rate.
    """
    trainMat, trainLabels = imgs2Mat('trainingDigits')
    testMat, expected = imgs2Mat('testDigits')
    total = testMat.shape[0]

    misses = 0.0
    for sample, truth in zip(testMat, expected):
        predicted = classify0(sample, trainMat, trainLabels, 5)
        if predicted != truth:
            misses += 1.0

    print("the total number of errors is: %d" % misses)
    print("the total error rate is: %f" % (misses / float(total)))


handwritingClassTest()
the total number of errors is: 17
the total error rate is: 0.017970

【使用算法】输入图像并输出预测结果

def classifyHandwriting(filename):
    """Predict the digit stored in a single text-encoded image file.

    The image is vectorised with ``img2Vect``, the training set is loaded
    from ``trainingDigits``, and the label chosen by ``classify0`` with
    5 neighbours is returned.

    Args:
        filename: Path of the image file to classify.

    Returns:
        The predicted class label.
    """
    query = img2Vect(filename)
    trainMat, trainLabels = imgs2Mat('trainingDigits')
    return classify0(query, trainMat, trainLabels, 5)


classifyHandwriting('dataImg.txt')
8
原创粉丝点击