python实现朴素贝叶斯分类器
本文用的是sciki-learn库的iris数据集进行测试。用的模型也是最简单的,就是用贝叶斯定理P(A|B)=P(B|A)*P(A)/P(B),计算每个类别在样本中概率(代码中是pLabel变量)
以及每个类下每个特征的概率(代码中是pNum变量)。
写得比较粗糙,对于某个类下没有此特征的情况采用p=1/样本数量。
有什么错误有人发现麻烦提出,谢谢。
[python]viewplaincopy #-*-coding:utf-8-*- fromnumpyimport* fromsklearnimportdatasets importnumpyasnp classNaiveBayesClassifier(object): def__init__(self): self.dataMat=list() self.labelMat=list() self.pLabel={} self.pNum={} defloadDataSet(self): iris=datasets.load_iris() self.dataMat=iris.data self.labelMat=iris.target labelSet=set(iris.target) labelList=[iforiinlabelSet] labelNum=len(labelList) foriinrange(labelNum): self.pLabel.setdefault(labelList[i]) self.pLabel[labelList[i]]=np.sum(self.labelMat==labelList[i])/float(len(self.labelMat)) defseperateByClass(self): seperated={} foriinrange(len(self.dataMat)): vector=self.dataMat[i] ifself.labelMat[i]notinseperated: seperated[self.labelMat[i]]=[] seperated[self.labelMat[i]].append(vector) returnseperated #通过numpyarray二维数组来获取每一维每种数的概率 defgetProbByArray(self,data): prob={} foriinrange(len(data[0])): ifinotinprob: prob[i]={} dataSetList=list(set(data[:,i])) forjindataSetList: ifjnotinprob[i]: prob[i][j]=0 prob[i][j]=np.sum(data[:,i]==j)/float(len(data[:,i])) prob[0]=[1/float(len(data[:,0]))]#防止feature不存在的情况 returnprob deftrain(self): featureNum=len(self.dataMat[0]) seperated=self.seperateByClass() t_pNum={}#存储每个类别下每个特征每种情况出现的概率 forlabel,datainseperated.iteritems(): iflabelnotint_pNum: t_pNum[label]={} t_pNum[label]=self.getProbByArray(np.array(data)) self.pNum=t_pNum defclassify(self,data): label=0 pTest=np.ones(3) foriinself.pLabel: forjinself.pNum[i]: ifdata[j]notinself.pNum[i][j]: pTest[i]*=self.pNum[i][0][0] else: pTest[i]*=self.pNum[i][j][data[j]] pMax=np.max(pTest) ind=np.where(pTest==pMax) returnind[0][0] deftest(self): self.loadDataSet() self.train() pred=[] right=0 fordinself.dataMat: pred.append(self.classify(d)) foriinrange(len(self.labelMat)): ifpred[i]==self.labelMat[i]: right+=1 printright/float(len(self.labelMat)) if__name__=='__main__': NB=NaiveBayesClassifier() NB.test()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。