机器学习实战决策树之眼镜男买眼镜

来源:互联网 发布:中国近十年的gdp数据 编辑:程序博客网 时间:2024/06/08 04:21

决策树是个极其易懂的算法,建好模型后就是一连串嵌套的if..else...或嵌套的switch。

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据;

缺点:可能会产生过度匹配的问题;

适用数据类型:数值型和标称型。


决策树的Python实现:

(一)先实现几个工具函数:计算熵函数,划分数据集工具函数,计算最大概率属性;

(1)计算熵:熵代表集合的无序程度,集合越无序,熵越大;

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def entropy(dataset):  
  2.     from math import log    
  3.     log2 = lambda x:log(x)/log(2)   
  4.       
  5.     results={}    
  6.     for row in dataset:    
  7.         r = row[len(row)-1]  
  8.         results[r] = results.get(r, 0) + 1  
  9.       
  10.     ent = 0.0  
  11.     for r in results.keys():    
  12.         p = float(results[r]) / len(dataset)    
  13.         ent=ent-p*log2(p)    
  14.     return ent    
  15.       

(2)按属性和值获取数据集:

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def fetch_subdataset(dataset, k, v):  
  2.     return [d[:k]+d[k+1:] for d in dataset if d[k] == v]  
这个函数只有短短一行,他的意义是:从dataset序列中取得第k列的值为v的子集,并从获得的子集中去掉第k列。python的简单优美显现无遗。

(3)计算最大概率属性。在构建决策树时,在处理所有决策属性后,还不能唯一区分数据时,我们采用多数表决的方法来选择最终分类:

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def get_max_feature(class_list):  
  2.     class_count = {}  
  3.     for cla in class_list:  
  4.         class_count[cla] = class_count.get(cla, 0) + 1  
  5.     sorted_class_count =  sorted(class_count.items(), key=lambda d: d[1], reverse=True)   
  6.     return sorted_class_count[0][0]  

(二)选取最优数据划分方式函数:

选择集合的最优划分方式:以哪一列的值划分集合,从而获得最大的信息增益呢?

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def choose_decision_feature(dataset):  
  2.     ent, feature = 100000000, -1  
  3.     for i in range(len(dataset[0]) - 1):  
  4.         feat_list = [e[i] for e in dataset]  
  5.         unq_feat_list = set(feat_list)  
  6.         ent_t = 0.0  
  7.         for f in unq_feat_list:  
  8.             sub_data = fetch_subdataset(dataset, i, f)  
  9.             ent_t += entropy(sub_data) * len(sub_data) / len(dataset)  
  10.               
  11.         if ent_t < ent:  
  12.             ent, feature = ent_t, i  
  13.               
  14.     return feature  

(三)递归构建决策树:

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def build_decision_tree(dataset, datalabel):  
  2.     cla = [c[-1for c in dataset]  
  3.     if len(cla) == cla.count(cla[0]):  
  4.         return cla[0]  
  5.     if len(dataset[0]) == 1:  
  6.         return get_max_feature(dataset)  
  7.           
  8.     feature = choose_decision_feature(dataset)  
  9.     feature_label = datalabel[feature]  
  10.     decision_tree = {feature_label:{}}  
  11.     del(datalabel[feature])  
  12.       
  13.     feat_value = [d[feature] for d in dataset]  
  14.     unique_feat_value = set(feat_value)  
  15.     for value in unique_feat_value:  
  16.         sub_label = datalabel[:]  
  17.         decision_tree[feature_label][value] = build_decision_tree(\  
  18.             fetch_subdataset(dataset, feature, value), sub_label)  
  19.           
  20.     return decision_tree  


(四)使用决策树

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def classify(decision_tree, feat_labels, testVec):  
  2.     label = decision_tree.keys()[0]  
  3.     next_dict = decision_tree[label]  
  4.     feat_index = feat_labels.index(label)  
  5.     for key in next_dict.keys():  
  6.         if testVec[feat_index] == key:  
  7.             if type(next_dict[key]).__name__ == 'dict':  
  8.                 c_label = classify(next_dict[key], feat_labels, testVec)  
  9.             else:  
  10.                 c_label = next_dict[key]  
  11.     return c_label  

(五)决策树持久化

(1)存储

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def store_decision_tree(tree, filename):  
  2.     import pickle  
  3.     f = open(filename, 'w')  
  4.     pickle.dump(tree, f)  
  5.     f.close()  

(2)读取

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def load_decision_tree(filename):  
  2.     import pickle  
  3.     f = open(filename)  
  4.     return pickle.load(f)  

(六)到了最后了,该回到主题了,给眼镜男配眼镜了。

下面的隐形眼镜数据集来自UCI数据库,它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型,隐形眼镜类型包括硬材料、软材料和不适合佩戴隐形眼镜。

数据如下:

[plain] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. young   myope   no  reduced no lenses  
  2. young   myope   no  normal  soft  
  3. young   myope   yes reduced no lenses  
  4. young   myope   yes normal  hard  
  5. young   hyper   no  reduced no lenses  
  6. young   hyper   no  normal  soft  
  7. young   hyper   yes reduced no lenses  
  8. young   hyper   yes normal  hard  
  9. pre myope   no  reduced no lenses  
  10. pre myope   no  normal  soft  
  11. pre myope   yes reduced no lenses  
  12. pre myope   yes normal  hard  
  13. pre hyper   no  reduced no lenses  
  14. pre hyper   no  normal  soft  
  15. pre hyper   yes reduced no lenses  
  16. pre hyper   yes normal  no lenses  
  17. presbyopic  myope   no  reduced no lenses  
  18. presbyopic  myope   no  normal  no lenses  
  19. presbyopic  myope   yes reduced no lenses  
  20. presbyopic  myope   yes normal  hard  
  21. presbyopic  hyper   no  reduced no lenses  
  22. presbyopic  hyper   no  normal  soft  
  23. presbyopic  hyper   yes reduced no lenses  
  24. presbyopic  hyper   yes normal  no lenses  


测试程序如下:

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def test():  
  2.     f = open('lenses.txt')  
  3.     lense_data = [inst.strip().split('\t'for inst in f.readlines()]  
  4.     lense_label = ['age''prescript''astigmatic''tearRate']  
  5.     lense_tree = build_decision_tree(lense_data, lense_label)  
我这里测试结果如下:

 


眼镜男终于可以买到合适的眼镜啦。。。


所有代码黏在下面:

[python] view plaincopyprint?在CODE上查看代码片派生到我的代码片
  1. def entropy(dataset):  
  2.     from math import log    
  3.     log2 = lambda x:log(x)/log(2)   
  4.       
  5.     results={}    
  6.     for row in dataset:    
  7.         r = row[len(row)-1]  
  8.         results[r] = results.get(r, 0) + 1  
  9.       
  10.     ent = 0.0  
  11.     for r in results.keys():    
  12.         p = float(results[r]) / len(dataset)    
  13.         ent=ent-p*log2(p)    
  14.     return ent    
  15.       
  16. def fetch_subdataset(dataset, k, v):  
  17.     return [d[:k]+d[k+1:] for d in dataset if d[k] == v]  
  18.   
  19. def get_max_feature(class_list):  
  20.     class_count = {}  
  21.     for cla in class_list:  
  22.         class_count[cla] = class_count.get(cla, 0) + 1  
  23.     sorted_class_count =  sorted(class_count.items(), key=lambda d: d[1], reverse=True)   
  24.     return sorted_class_count[0][0]  
  25.   
  26. def choose_decision_feature(dataset):  
  27.     ent, feature = 100000000, -1  
  28.     for i in range(len(dataset[0]) - 1):  
  29.         feat_list = [e[i] for e in dataset]  
  30.         unq_feat_list = set(feat_list)  
  31.         ent_t = 0.0  
  32.         for f in unq_feat_list:  
  33.             sub_data = fetch_subdataset(dataset, i, f)  
  34.             ent_t += entropy(sub_data) * len(sub_data) / len(dataset)  
  35.               
  36.         if ent_t < ent:  
  37.             ent, feature = ent_t, i  
  38.               
  39.     return feature  
  40.               
  41. def build_decision_tree(dataset, datalabel):  
  42.     cla = [c[-1for c in dataset]  
  43.     if len(cla) == cla.count(cla[0]):  
  44.         return cla[0]  
  45.     if len(dataset[0]) == 1:  
  46.         return get_max_feature(dataset)  
  47.           
  48.     feature = choose_decision_feature(dataset)  
  49.     feature_label = datalabel[feature]  
  50.     decision_tree = {feature_label:{}}  
  51.     del(datalabel[feature])  
  52.       
  53.     feat_value = [d[feature] for d in dataset]  
  54.     unique_feat_value = set(feat_value)  
  55.     for value in unique_feat_value:  
  56.         sub_label = datalabel[:]  
  57.         decision_tree[feature_label][value] = build_decision_tree(\  
  58.             fetch_subdataset(dataset, feature, value), sub_label)  
  59.           
  60.     return decision_tree  
  61.           
  62. def store_decision_tree(tree, filename):  
  63.     import pickle  
  64.     f = open(filename, 'w')  
  65.     pickle.dump(tree, f)  
  66.     f.close()  
  67.   
  68. def load_decision_tree(filename):  
  69.     import pickle  
  70.     f = open(filename)  
  71.     return pickle.load(f)  
  72.       
  73. def classify(decision_tree, feat_labels, testVec):  
  74.     label = decision_tree.keys()[0]  
  75.     next_dict = decision_tree[label]  
  76.     feat_index = feat_labels.index(label)  
  77.     for key in next_dict.keys():  
  78.         if testVec[feat_index] == key:  
  79.             if type(next_dict[key]).__name__ == 'dict':  
  80.                 c_label = classify(next_dict[key], feat_labels, testVec)  
  81.             else:  
  82.                 c_label = next_dict[key]  
  83.     return c_label  
  84.       
  85. def test():  
  86.     f = open('lenses.txt')  
  87.     lense_data = [inst.strip().split('\t'for inst in f.readlines()]  
  88.     lense_label = ['age''prescript''astigmatic''tearRate']  
  89.     lense_tree = build_decision_tree(lense_data, lense_label)  
  90.     return lense_tree  
  91.       
  92. if __name__ == "__main__":  
  93.     tree = test()  
  94.     print tree  
0 0
原创粉丝点击