python实现决策树ID3算法的示例代码
在周志华的西瓜书和李航的统计机器学习中对决策树ID3算法都有很详细的解释,如何实现呢?核心点有如下几个步骤
step1:计算香农熵
frommathimportlog importoperator #计算香农熵 defcalculate_entropy(data): label_counts={} forfeature_dataindata: laber=feature_data[-1]#最后一行是laber iflabernotinlabel_counts.keys(): label_counts[laber]=0 label_counts[laber]+=1 count=len(data) entropy=0.0 forkeyinlabel_counts: prob=float(label_counts[key])/count entropy-=prob*log(prob,2) returnentropy
step2.计算某个feature的信息增益的方法
#计算某个feature的信息增益 #index:要计算信息增益的feature对应的在data的第几列 #data的香农熵 defcalculate_relative_entropy(data,index,entropy): feat_list=[number[index]fornumberindata]#得到某个特征下所有值(某列) uniqual_vals=set(feat_list) new_entropy=0 forvalueinuniqual_vals: sub_data=split_data(data,index,value) prob=len(sub_data)/float(len(data)) new_entropy+=prob*calculate_entropy(sub_data)#对各子集香农熵求和 relative_entropy=entropy-new_entropy#计算信息增益 returnrelative_entropy
step3.选择最大信息增益的feature
#选择最大信息增益的feature defchoose_max_relative_entropy(data): num_feature=len(data[0])-1 base_entropy=calculate_entropy(data)#香农熵 best_infor_gain=0 best_feature=-1 foriinrange(num_feature): info_gain=calculate_relative_entropy(data,i,base_entropy) #最大信息增益 if(info_gain>best_infor_gain): best_infor_gain=info_gain best_feature=i returnbest_feature
step4.构建决策树
defcreate_decision_tree(data,labels): class_list=[example[-1]forexampleindata] #类别相同,停止划分 ifclass_list.count(class_list[-1])==len(class_list): returnclass_list[-1] #判断是否遍历完所有的特征时返回个数最多的类别 iflen(data[0])==1: returnmost_class(class_list) #按照信息增益最高选取分类特征属性 best_feat=choose_max_relative_entropy(data) best_feat_lable=labels[best_feat]#该特征的label decision_tree={best_feat_lable:{}}#构建树的字典 del(labels[best_feat])#从labels的list中删除该label feat_values=[example[best_feat]forexampleindata] unique_values=set(feat_values) forvalueinunique_values: sub_lables=labels[:] #构建数据的子集合,并进行递归 decision_tree[best_feat_lable][value]=create_decision_tree(split_data(data,best_feat,value),sub_lables) returndecision_tree
在构建决策树的过程中会用到两个工具方法:
#当遍历完所有的特征时返回个数最多的类别 defmost_class(classList): class_count={} forvoteinclassList: ifvotenotinclass_count.keys():class_count[vote]=0 class_count[vote]+=1 sorted_class_count=sorted(class_count.items,key=operator.itemgetter(1),reversed=True) returnsorted_class_count[0][0] #工具函数输入三个变量(待划分的数据集,特征,分类值)返回不含划分特征的子集 defsplit_data(data,axis,value): ret_data=[] forfeat_vecindata: iffeat_vec[axis]==value: reduce_feat_vec=feat_vec[:axis] reduce_feat_vec.extend(feat_vec[axis+1:]) ret_data.append(reduce_feat_vec) returnret_data
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。