随机森林分类器的实现
来源:互联网 发布:深圳 知乎 编辑:程序博客网 时间:2024/06/10 01:05
参见博客Random Forest随机森林算法
下面是我实现的简易版本
决策树ID3Tree.h
//#pragma once
// ID3Tree.h -- ID3 decision tree over categorical attributes.
#ifndef ID3
#define ID3

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Added to counts and sizes in the entropy / gain formulas so that log(0) and
// 0/0 cannot occur, e.g. for an empty subset produced by a split.
const double Epsilon = 0.000000001;

/// One sample: a categorical value index per attribute plus its class label.
class Record
{
public:
    std::vector<int> attri;  // attri[k] is the value index of attribute k
    int label = -1;          // class index; -1 until known
    static int attributeNums;                 // number of attributes
    static int labelNums;                     // number of classes
    static std::vector<int> each_attr_class;  // each_attr_class[k]: number of values attribute k can take
};

// Declared here, defined elsewhere; not used by this header.
extern int value_int;
extern std::vector<float> value_vector;

/// ID3 decision tree.
/// Usage: set_paras(...); create_root(...); build_tree(get_root(), data); classify(rd).
/// The tree owns its nodes (freed in the destructor), so copying is disabled.
class ID3Tree
{
    struct ID3Node
    {
        int attri;                      // attribute this node splits on; -1 for leaves
        std::vector<ID3Node*> child;    // child[v] handles records whose value of `attri` is v
        std::vector<int> remain_attri;  // attributes still available for splitting below this node
        int label;                      // leaf: predicted class; internal node: majority class of its data
        ID3Node() : attri(-1), label(-1) {}
    };

private:
    ID3Node* root;
    int attributeNums;                 // number of attributes
    int labelNums;                     // number of classes
    int sampleNums;                    // number of training samples (stored only)
    std::vector<int> label_count;      // per-class counts from the last majority() call
    double threshold;                  // a node becomes a leaf once its majority share exceeds this
    std::vector<int> each_attr_class;  // each_attr_class[k]: number of values attribute k can take

    double cal_entropy(std::vector<Record>& Dataset);
    std::vector<std::vector<Record>> splitDataset(std::vector<Record>& Dataset, const int k);
    int majority(std::vector<Record>& Dataset);
    void destroy();  // frees every node and resets root

public:
    ID3Tree();
    ~ID3Tree();
    ID3Tree(const ID3Tree&) = delete;
    ID3Tree& operator=(const ID3Tree&) = delete;

    void build_tree(ID3Node* node, std::vector<Record>& Dataset);
    int classify(Record& rd);
    bool create_root();
    bool create_root(const std::vector<int>& aa);
    ID3Node* get_root() { return root; }

    /// Sets the problem sizes; attribute cardinalities come from Record::each_attr_class.
    void set_paras(int attributeNums, int labelNums, int sampleNums)
    {
        set_paras(attributeNums, labelNums, sampleNums, Record::each_attr_class);
    }

    /// Same as above, with explicit attribute cardinalities.
    void set_paras(int attributeNums, int labelNums, int sampleNums, const std::vector<int>& each_attr_class)
    {
        this->attributeNums = attributeNums;
        this->labelNums = labelNums;
        this->sampleNums = sampleNums;
        this->each_attr_class = each_attr_class;
    }
};

// Fills label_count with per-class counts of Dataset and returns the most
// frequent label (lowest index on ties; 0 for an empty data set).
inline int ID3Tree::majority(std::vector<Record>& Dataset)
{
    label_count.assign(static_cast<std::size_t>(labelNums), 0);
    for (std::size_t i = 0; i < Dataset.size(); ++i)
        ++label_count[Dataset[i].label];
    return static_cast<int>(std::max_element(label_count.begin(), label_count.end()) - label_count.begin());
}

// Partitions Dataset by the value of attribute k; result[v] holds the records with value v.
inline std::vector<std::vector<Record>> ID3Tree::splitDataset(std::vector<Record>& Dataset, const int k)
{
    std::vector<std::vector<Record>> parts(static_cast<std::size_t>(each_attr_class[k]));
    for (std::size_t i = 0; i < Dataset.size(); ++i)
        parts[Dataset[i].attri[k]].push_back(Dataset[i]);
    return parts;
}

// Recursively grows the subtree under `node` from Dataset using information gain.
inline void ID3Tree::build_tree(ID3Node* node, std::vector<Record>& Dataset)
{
    const int label = majority(Dataset);
    node->label = label;  // kept on leaves; fallback for unknown values on internal nodes

    if (Dataset.empty())
        return;  // nothing to split (majority() fell back to label 0)
    if (double(label_count[label]) / double(Dataset.size()) > threshold)
        return;  // (nearly) pure node
    if (node->remain_attri.empty())
        return;  // no attribute left to split on
    if (Dataset.size() == 1)
        return;  // a single record: its label is already the majority

    // Pick the remaining attribute with the largest information gain.
    const double base_entropy = cal_entropy(Dataset);
    double maxgain = -std::numeric_limits<double>::max();
    int selectAttri = -1;
    std::vector<std::vector<Record>> best_split;
    for (std::size_t i = 0; i < node->remain_attri.size(); ++i)
    {
        std::vector<std::vector<Record>> parts = splitDataset(Dataset, node->remain_attri[i]);
        double gain = base_entropy;
        for (std::size_t j = 0; j < parts.size(); ++j)
            gain -= double(parts[j].size()) / double(Dataset.size() + Epsilon) * cal_entropy(parts[j]);
        if (gain > maxgain)
        {
            maxgain = gain;
            selectAttri = static_cast<int>(i);
            best_split.swap(parts);
        }
    }
    assert(selectAttri >= 0);

    node->attri = node->remain_attri[selectAttri];
    std::vector<int> child_attri = node->remain_attri;
    child_attri.erase(child_attri.begin() + selectAttri);

    node->child.reserve(best_split.size());  // push_back below then cannot throw and leak `nn`
    for (std::size_t v = 0; v < best_split.size(); ++v)
    {
        ID3Node* nn = new ID3Node;
        node->child.push_back(nn);
        nn->remain_attri = child_attri;
        if (best_split[v].empty())
            nn->label = label;  // value unseen in this subset: predict the parent's majority
        else
            build_tree(nn, best_split[v]);
    }
}

// Shannon entropy (bits) of the label distribution of Dataset, with Epsilon smoothing.
inline double ID3Tree::cal_entropy(std::vector<Record>& Dataset)
{
    const double len = static_cast<double>(Dataset.size());
    std::vector<int> count(static_cast<std::size_t>(labelNums), 0);
    for (std::size_t i = 0; i < Dataset.size(); ++i)
        ++count[Dataset[i].label];

    double entropy = 0;
    for (int i = 0; i < labelNums; ++i)
    {
        const double p = (count[i] + Epsilon) / (len + Epsilon);
        entropy += -p * std::log(p) / std::log(2.0);
    }
    assert(entropy >= 0.0);
    return entropy;
}

inline ID3Tree::ID3Tree()
    : root(nullptr), attributeNums(-1), labelNums(0), sampleNums(0), threshold(0.99)
{
}

// Creates a root that may split on every attribute 0..attributeNums-1.
// Returns false if set_paras() has not been called yet.
inline bool ID3Tree::create_root()
{
    if (attributeNums < 0)
        return false;
    std::vector<int> all(static_cast<std::size_t>(attributeNums));
    for (int i = 0; i < attributeNums; ++i)
        all[i] = i;
    return create_root(all);
}

// Creates a root restricted to the attributes listed in `aa`, discarding any earlier tree.
inline bool ID3Tree::create_root(const std::vector<int>& aa)
{
    destroy();
    attributeNums = static_cast<int>(aa.size());
    root = new ID3Node;
    root->remain_attri = aa;
    return true;
}

// Deletes all nodes iteratively (no recursion, so deep trees cannot overflow the stack).
inline void ID3Tree::destroy()
{
    std::vector<ID3Node*> pending;
    if (root != nullptr)
        pending.push_back(root);
    while (!pending.empty())
    {
        ID3Node* node = pending.back();
        pending.pop_back();
        pending.insert(pending.end(), node->child.begin(), node->child.end());
        delete node;
    }
    root = nullptr;
}

inline ID3Tree::~ID3Tree()
{
    destroy();
}

// Walks the tree with rd's attribute values, stores the prediction in rd.label
// and returns it. An unknown or out-of-range value stops at the current node and
// uses its majority label; an unbuilt tree yields -1.
inline int ID3Tree::classify(Record& rd)
{
    if (root == nullptr)
    {
        rd.label = -1;
        return -1;
    }
    const ID3Node* node = root;
    while (!node->child.empty())
    {
        if (node->attri < 0 || static_cast<std::size_t>(node->attri) >= rd.attri.size())
            break;
        const int value = rd.attri[node->attri];
        if (value < 0 || static_cast<std::size_t>(value) >= node->child.size())
            break;
        node = node->child[value];
    }
    rd.label = node->label;
    return node->label;
}
#endif
// RandomForest.h
#ifndef RANDOMFOREST
#define RANDOMFOREST

#include <time.h>
#include <cstdlib>
#include <vector>

class ID3Tree;
class Record;

/// Random forest of ID3Tree classifiers over categorical attributes.
///
/// Call order: load_dataset(); set_paras(); create_forest(); classify(query).
/// set_paras() must follow load_dataset() because it captures the data-set size.
/// The forest owns its trees (deleted in the destructor), so copying is disabled.
class RandomForest
{
private:
    std::vector<ID3Tree*> forest;      // owned trees
    int treeNums;                      // number of trees to grow
    void boostrap();                   // fills subDataSet by random sampling from wholeDataSet
    std::vector<Record> wholeDataSet;  // all loaded records
    std::vector<Record> subDataSet;    // training sample for the tree currently being built
    int sizeofwholeDataSet;
    double ratioofsubDataset;          // bootstrap sample size as a fraction of wholeDataSet
    int attributeNums;                 // number of attributes
    int sub_attriNums;                 // number of attributes each tree may use
    int labelNums;                     // number of classes
    std::vector<int> ranom_select_feature();  // random attribute subset for one tree
    std::vector<int> vote(Record rd);         // per-class vote counts over the forest

public:
    void load_dataset();
    void create_forest();
    int classify(std::vector<int> query);
    void set_paras();
    RandomForest();
    ~RandomForest();
    RandomForest(const RandomForest&) = delete;
    RandomForest& operator=(const RandomForest&) = delete;
};
#endif
randomforest.cpp
// randomforest.cpp -- RandomForest member definitions and the data-set loader.
#include "stdafx.h"
#include "ID3Tree.h"
#include "RandomForest.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

namespace
{
// Uniform integer in [0, n) for n > 0. Scaling rand() into [0, 1) first avoids
// the signed-int overflow of `n * rand()` when RAND_MAX is large (2^31-1 on glibc).
int random_index(int n)
{
    const int idx = static_cast<int>(n * (rand() / (RAND_MAX + 1.0)));
    return idx < n ? idx : n - 1;
}

// Maps a categorical token to its position in `names`; throws std::runtime_error
// naming the column for an unknown token.
int encode(const std::string& token, const char* const names[], int count, const char* column)
{
    for (int i = 0; i < count; ++i)
        if (token == names[i])
            return i;
    throw std::runtime_error(std::string("load_dataset: unexpected ") + column + " value '" + token + "'");
}
}  // namespace

RandomForest::RandomForest()
    : treeNums(0), sizeofwholeDataSet(0), ratioofsubDataset(0.0),
      attributeNums(0), sub_attriNums(0), labelNums(0)
{
    srand(static_cast<unsigned>(time(nullptr)));  // seeds bootstrap sampling and feature selection
}

RandomForest::~RandomForest()
{
    for (size_t i = 0; i < forest.size(); ++i)
        delete forest[i];
}

// Fixes the training hyper-parameters. Call after load_dataset(): the data-set
// size captured here drives boostrap() and create_forest().
void RandomForest::set_paras()
{
    ratioofsubDataset = 0.5;  // each tree trains on half as many draws as there are records
    sizeofwholeDataSet = static_cast<int>(wholeDataSet.size());
    attributeNums = Record::attributeNums;
    sub_attriNums = attributeNums - 1;  // each tree sees all but one attribute
    treeNums = 100;
    labelNums = Record::labelNums;
}

// Splits `str` on the exact separator string `sep`, appending the non-empty pieces
// to `ret_`; leading characters that occur in `sep` are skipped. An empty `sep`
// yields `str` as a single piece. Always returns 0.
int split(const std::string& str, std::vector<std::string>& ret_, std::string sep = ",")
{
    if (str.empty())
        return 0;
    if (sep.empty())
    {
        ret_.push_back(str);  // an empty separator would otherwise never advance
        return 0;
    }
    std::string::size_type pos_begin = str.find_first_not_of(sep);
    while (pos_begin != std::string::npos)
    {
        const std::string::size_type comma_pos = str.find(sep, pos_begin);
        std::string tmp;
        if (comma_pos != std::string::npos)
        {
            tmp = str.substr(pos_begin, comma_pos - pos_begin);
            pos_begin = comma_pos + sep.length();
        }
        else
        {
            tmp = str.substr(pos_begin);
            pos_begin = std::string::npos;
        }
        if (!tmp.empty())
            ret_.push_back(tmp);
    }
    return 0;
}

// Schema of the BuysComputer data set: 4 attributes, 2 classes (No = 0, Yes = 1),
// and the number of values of Age, Income, Student, CreditRating respectively.
int Record::attributeNums = 4;
int Record::labelNums = 2;
vector<int> Record::each_attr_class = { 3, 3, 2, 2 };

// Reads input.txt into wholeDataSet, echoing the parsed record text to stdout.
// Throws std::runtime_error if the file cannot be opened or a record is malformed.
void RandomForest::load_dataset()
{
    /* Expected content of input.txt:
       Rid Age        Income Student CreditRating BuysComputer
       1   Youth      High   No      Fair         No
       2   Youth      High   No      Excellent    No
       3   MiddleAged High   No      Fair         Yes
       4   Senior     Medium No      Fair         Yes
       5   Senior     Low    Yes     Fair         Yes
       6   Senior     Low    Yes     Excellent    No
       7   MiddleAged Low    Yes     Excellent    Yes
       8   Youth      Medium No      Fair         No
       9   Youth      Low    Yes     Fair         Yes
       10  Senior     Medium Yes     Fair         Yes
       11  Youth      Medium Yes     Excellent    Yes
       12  MiddleAged Medium No      Excellent    Yes
       13  MiddleAged High   Yes     Fair         Yes
       14  Senior     Medium No      Excellent    No
    */
    static const char* const kAge[] = { "Youth", "Senior", "MiddleAged" };
    static const char* const kIncome[] = { "Low", "Medium", "High" };
    static const char* const kNoYes[] = { "No", "Yes" };  // Student column and class label
    static const char* const kCredit[] = { "Fair", "Excellent" };

    std::unique_ptr<FILE, int (*)(FILE*)> file(fopen("input.txt", "r"), &fclose);  // closed on every exit path
    if (!file)
        throw std::runtime_error("load_dataset: cannot open input.txt");
    FILE* fp = file.get();

    int ch = getc(fp);  // int, not char: a char cannot hold EOF distinctly from byte 0xFF
    while (ch != EOF)
    {
        // A record starts at its Rid, whose first digit is 1-9; everything else
        // (the header line, separators) is skipped.
        if (ch <= '0' || ch > '9')
        {
            ch = getc(fp);
            continue;
        }
        while (ch >= '0' && ch <= '9')  // skip the rest of the Rid
            ch = getc(fp);
        if (ch == EOF)
            break;

        // The record text runs up to the next Rid digit or the end of file.
        std::string str;
        while (ch != EOF && (ch < '0' || ch > '9'))
        {
            putchar(ch);
            str += static_cast<char>(ch);
            ch = getc(fp);
        }

        // Treat CR / LF / tab like spaces so line endings and tabs never end up inside a token.
        for (size_t i = 0; i < str.size(); ++i)
            if (str[i] == '\r' || str[i] == '\n' || str[i] == '\t')
                str[i] = ' ';

        std::vector<std::string> re;
        split(str, re, std::string(" "));
        if (re.size() != 5)
            throw std::runtime_error("load_dataset: expected 5 fields per record, got " + std::to_string(re.size()));

        Record rd;
        rd.attri.push_back(encode(re[0], kAge, 3, "Age"));
        rd.attri.push_back(encode(re[1], kIncome, 3, "Income"));
        rd.attri.push_back(encode(re[2], kNoYes, 2, "Student"));
        rd.attri.push_back(encode(re[3], kCredit, 2, "CreditRating"));
        rd.label = encode(re[4], kNoYes, 2, "BuysComputer");
        wholeDataSet.push_back(rd);
    }
}

// Fills subDataSet with ratioofsubDataset * sizeofwholeDataSet records drawn
// uniformly with replacement from wholeDataSet.
void RandomForest::boostrap()
{
    subDataSet.clear();
    if (sizeofwholeDataSet <= 0)
        return;
    for (int i = 0; i < ratioofsubDataset * sizeofwholeDataSet; ++i)
        subDataSet.push_back(wholeDataSet[random_index(sizeofwholeDataSet)]);
}

// Returns sub_attriNums distinct attribute indices drawn uniformly without
// replacement (capped at attributeNums).
std::vector<int> RandomForest::ranom_select_feature()
{
    std::vector<int> pool;
    for (int i = 0; i < attributeNums; ++i)
        pool.push_back(i);

    const int wanted = std::max(0, std::min(sub_attriNums, attributeNums));
    std::vector<int> chosen(static_cast<size_t>(wanted));
    for (int i = 0; i < wanted; ++i)
    {
        const int j = random_index(static_cast<int>(pool.size()));
        chosen[i] = pool[j];
        pool.erase(pool.begin() + j);
    }
    return chosen;
}

// Builds treeNums trees, each on its own bootstrap sample and random attribute
// subset. Trees from an earlier call are released first.
void RandomForest::create_forest()
{
    for (size_t i = 0; i < forest.size(); ++i)
        delete forest[i];
    forest.clear();
    forest.reserve(static_cast<size_t>(std::max(treeNums, 0)));  // push_back below cannot throw

    const int sampleNums = static_cast<int>(ratioofsubDataset * sizeofwholeDataSet);
    for (int i = 0; i < treeNums; ++i)
    {
        boostrap();
        std::unique_ptr<ID3Tree> tree(new ID3Tree);  // freed automatically if building throws
        tree->set_paras(attributeNums, labelNums, sampleNums);
        // A named vector compiles whether create_root takes a const or non-const reference.
        std::vector<int> features = ranom_select_feature();
        tree->create_root(features);
        tree->build_tree(tree->get_root(), subDataSet);
        forest.push_back(tree.release());
    }
}

// Returns the majority-vote class for `query` (one value index per attribute).
// Throws std::invalid_argument if the query has the wrong number of attributes.
int RandomForest::classify(vector<int> query)
{
    if (query.size() != static_cast<size_t>(Record::attributeNums))
        throw std::invalid_argument("RandomForest::classify: query has the wrong number of attributes");
    Record rd;
    rd.attri.swap(query);
    const std::vector<int> votes = vote(rd);
    return static_cast<int>(std::max_element(votes.begin(), votes.end()) - votes.begin());
}

// Counts, per class, how many trees predict it for rd; out-of-range predictions are ignored.
std::vector<int> RandomForest::vote(Record rd)
{
    std::vector<int> tally(static_cast<size_t>(labelNums), 0);
    for (size_t i = 0; i < forest.size(); ++i)
    {
        const int label = forest[i]->classify(rd);
        if (label >= 0 && label < labelNums)
            ++tally[label];
    }
    return tally;
}
main.cpp
// main.cpp -- trains the forest on input.txt and classifies one sample.
#include "stdafx.h"

#include <cstdlib>
#include <exception>
#include <iostream>
#include <vector>

#include "RandomForest.h"

using namespace std;

int _tmain(int argc, _TCHAR* argv[])
{
    try
    {
        RandomForest rf;
        rf.load_dataset();  // must run before set_paras(), which records the data-set size
        rf.set_paras();
        rf.create_forest();

        // Query: Age=Youth, Income=Low, Student=No, CreditRating=Fair, encoded as in
        // load_dataset() (each of these values maps to 0). The original author
        // reports the prediction BuysComputer: No.
        const int aa[4] = { 0, 0, 0, 0 };
        vector<int> query(aa, aa + 4);
        const int re = rf.classify(query);
        cout << "BuysComputer: " << (re == 1 ? "Yes" : "No") << '\n';  // label 1 = Yes, 0 = No
    }
    catch (const exception& e)
    {
        cerr << "error: " << e.what() << '\n';
        system("pause");
        return EXIT_FAILURE;
    }
    system("pause");
    return 0;
}
input.txt
Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No
0 0
- 随机森林分类器的实现
- 随机森林分类器
- Scikit-Learn 随机森林分类器的使用
- 转:Scikit-Learn 随机森林分类器的使用
- 随机森林的简单实现
- 随机森林的简单实现
- 非等级式随机森林----随机蕨分类器
- 分类(5):组合分类器-随机森林
- 分类&回归算法-随机森林
- 随机森林二分类建模
- 随机森林算法的python实现
- 随机森林的原理与实现
- 分类模型的再考以及随机森林的应用
- 单一决策树与集成模型(随机森林分类器、梯度提升决策树)的比较
- 基于决策树的分类回归(随机森林,xgboost, gbdt)
- 使用scikit-learn的随机森林对西瓜进行分类
- 随机森林算法实现
- 随机森林算法实现
- MR基本步骤
- Regex入门
- Android数据库更新并保留原来数据的实现
- noip2012 国王游戏 (高精除,高精乘+大数比较与替换+数论)
- 1002. A+B for Polynomials (25)
- 随机森林分类器的实现
- 15个nosql数据库
- SCU 2016 GCD & LCM Inverse(素性测试+DFS)
- c/c++ 2048 120行左右~
- [13]EC_ECShop修改安装中密码长度
- angularjs路由例子
- CentOS7 增加tomcat 启动,停止,使用systemctl进行配置
- web.config connectionStrings 数据库连接字符串的解释(转载)
- GAL GAME 汉化攻略 辅助篇1 破解工具篇