随机森林分类器的实现

来源:互联网 发布:深圳 知乎 编辑:程序博客网 时间:2024/06/10 01:05

算法原理参见博客《Random Forest随机森林算法》。

下面是我实现的一个简易版本,包含 ID3Tree.h、RandomForest.h、RandomForest.cpp 和 main.cpp 四个文件:

决策树ID3Tree.h

//#pragma once
// ID3Tree.h -- a minimal ID3 decision tree over integer-coded categorical attributes.
// Every out-of-class member definition is `inline`, so this header may be included
// from several translation units without duplicate-symbol link errors.
#ifndef ID3
#define ID3
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;  // NOTE(review): namespace-wide using in a header; kept because includers may rely on it

// Smoothing term added to counts in cal_entropy so log() never sees exactly zero.
#define Epsilon  0.000000001

/// One sample: integer-coded attribute values plus a class label.
class Record
{
public:
    std::vector<int> attri;  // attri[k] must lie in [0, each_attr_class[k]) (used as a child index)
    int label = -1;          // class index in [0, labelNums); -1 until assigned
    static int attributeNums;                 // number of attributes per record
    static int labelNums;                     // number of classes the data is split into
    static std::vector<int> each_attr_class;  // number of distinct values each attribute can take
};

extern int value_int;
extern std::vector<float> value_vector;

/// ID3 decision tree: each internal node splits on the remaining attribute with the
/// largest information gain. The tree owns its nodes and is therefore not copyable.
class ID3Tree
{
    struct ID3Node
    {
        int attri;                      // attribute tested at this node (internal nodes only)
        std::vector<ID3Node*> child;    // child[v] receives records whose `attri` value is v
        std::vector<int> remain_attri;  // attributes still available for splitting below this node
        int label;                      // predicted class at a leaf; -1 on internal nodes
        ID3Node() : attri(-1), label(-1) {}
    };

private:
    ID3Node* root;
    int attributeNums;                // number of attributes
    int labelNums;                    // number of classes
    int sampleNums;                   // number of training samples (informational only)
    std::vector<int> label_count;     // per-class counts of the dataset last passed to majority()
    double threshold;                 // majority share above which a node becomes a leaf
    std::vector<int> each_attr_class; // number of distinct values of each attribute

    double cal_entropy(const std::vector<Record>& Dataset);
    std::vector<std::vector<Record>> splitDataset(const std::vector<Record>& Dataset, const int k);
    int majority(const std::vector<Record>& Dataset);
    static void free_subtree(ID3Node* node);

public:
    ID3Tree();
    ~ID3Tree();
    // Owns raw node pointers: a shallow copy would delete every node twice.
    ID3Tree(const ID3Tree&) = delete;
    ID3Tree& operator=(const ID3Tree&) = delete;

    void build_tree(ID3Node* node, std::vector<Record>& Dataset);
    int classify(Record& rd);
    bool create_root();
    bool create_root(const std::vector<int>& aa);
    ID3Node* get_root() { return root; }
    void set_paras(int attributeNums, int labelNums, int sampleNums)
    {
        this->attributeNums = attributeNums;
        this->labelNums = labelNums;
        this->sampleNums = sampleNums;
        each_attr_class = Record::each_attr_class;
    }
};

/// Counts the labels of `Dataset` into `label_count` and returns the most frequent
/// label (the lowest index on ties; 0 for an empty dataset).
inline int ID3Tree::majority(const std::vector<Record>& Dataset)
{
    label_count.assign(static_cast<std::size_t>(labelNums), 0);
    for (const Record& r : Dataset)
        ++label_count[r.label];
    return static_cast<int>(std::max_element(label_count.begin(), label_count.end()) - label_count.begin());
}

/// Partitions `Dataset` by the value of attribute `k`: result[v] holds the records whose
/// attri[k] == v. Buckets for values that do not occur are left empty.
inline std::vector<std::vector<Record>> ID3Tree::splitDataset(const std::vector<Record>& Dataset, const int k)
{
    std::vector<std::vector<Record>> buckets(static_cast<std::size_t>(each_attr_class[k]));
    for (const Record& r : Dataset)
        buckets[r.attri[k]].push_back(r);
    return buckets;
}

/// Recursively grows the subtree rooted at `node` from `Dataset`.
/// A node becomes a leaf (labelled with the majority class) when the majority share
/// exceeds `threshold`, when no attributes remain, or when only one record is left.
/// Otherwise it splits on the attribute with maximal information gain.
inline void ID3Tree::build_tree(ID3Node* node, std::vector<Record>& Dataset)
{
    const int label = majority(Dataset);  // local copy: recursion overwrites label_count
    if (Dataset.empty())
    {
        node->label = label;
        return;
    }
    if (double(label_count[label]) / double(Dataset.size()) > threshold)
    {
        node->label = label;
        return;
    }
    // Stop only when nothing is left to split on (the old `size() == 1` test never
    // used the last remaining attribute).
    if (node->remain_attri.empty())
    {
        node->label = label;
        return;
    }
    if (Dataset.size() == 1)
    {
        node->label = Dataset[0].label;
        return;
    }

    const double base_entropy = cal_entropy(Dataset);
    double maxgain = -10000;
    int selectAttri = -1;
    std::vector<std::vector<Record>> bb;  // partition produced by the best attribute so far
    for (std::size_t i = 0; i < node->remain_attri.size(); ++i)
    {
        std::vector<std::vector<Record>> aa = splitDataset(Dataset, node->remain_attri[i]);
        double gain = base_entropy;
        for (const std::vector<Record>& subset : aa)
            gain -= double(subset.size()) / double(Dataset.size() + Epsilon) * cal_entropy(subset);
        if (gain > maxgain)
        {
            maxgain = gain;
            selectAttri = static_cast<int>(i);
            bb = std::move(aa);
        }
    }
    assert(selectAttri >= 0);

    node->attri = node->remain_attri[selectAttri];
    std::vector<int> rest = node->remain_attri;
    rest.erase(rest.begin() + selectAttri);
    for (std::vector<Record>& subset : bb)
    {
        ID3Node* nn = new ID3Node;
        nn->remain_attri = rest;
        node->child.push_back(nn);
        if (subset.empty())
        {
            // No training record has this value: predict the parent's majority class
            // instead of an arbitrary label.
            nn->label = label;
            continue;
        }
        build_tree(nn, subset);
    }
}

/// Shannon entropy (bits) of the label distribution of `Dataset`, with Epsilon added
/// to every count so that empty classes contribute a negligible, finite term.
inline double ID3Tree::cal_entropy(const std::vector<Record>& Dataset)
{
    const int len = static_cast<int>(Dataset.size());
    std::vector<int> count(static_cast<std::size_t>(labelNums), 0);
    for (const Record& r : Dataset)
        ++count[r.label];
    double entropy = 0;
    for (int c : count)
    {
        const double p = double(c + Epsilon) / double(len + Epsilon);
        entropy += -p * std::log(p) / std::log(double(2));
    }
    assert(entropy >= 0.0);
    return entropy;
}

inline ID3Tree::ID3Tree()
    : root(NULL), attributeNums(-1), labelNums(0), sampleNums(0), threshold(0.99)
{
}

/// Deletes `node` and all of its descendants (iteratively, so deep trees cannot
/// overflow the stack). Accepts NULL.
inline void ID3Tree::free_subtree(ID3Node* node)
{
    std::vector<ID3Node*> pending;
    if (node != NULL)
        pending.push_back(node);
    while (!pending.empty())
    {
        ID3Node* nn = pending.back();
        pending.pop_back();
        pending.insert(pending.end(), nn->child.begin(), nn->child.end());
        delete nn;
    }
}

/// Creates a fresh root that may split on every attribute 0..attributeNums-1.
/// Any previously built tree is released first. Returns false if set_paras was not called.
inline bool ID3Tree::create_root()
{
    if (attributeNums < 0)
        return false;
    free_subtree(root);
    root = new ID3Node;
    for (int i = 0; i < attributeNums; ++i)
        root->remain_attri.push_back(i);
    return true;
}

/// Creates a fresh root restricted to the attribute indices in `aa` (the random
/// feature subset of a forest tree). Any previously built tree is released first.
inline bool ID3Tree::create_root(const std::vector<int>& aa)
{
    attributeNums = static_cast<int>(aa.size());
    free_subtree(root);
    root = new ID3Node;
    root->remain_attri = aa;
    return true;
}

inline ID3Tree::~ID3Tree()
{
    free_subtree(root);
}

/// Walks from the root to a leaf following rd's attribute values, stores the leaf's
/// label into rd.label and returns it. Returns -1 if no tree has been created.
inline int ID3Tree::classify(Record& rd)
{
    if (root == NULL)
    {
        rd.label = -1;
        return -1;
    }
    ID3Node* node = root;
    while (!node->child.empty())
    {
        assert(node->attri >= 0 && static_cast<std::size_t>(node->attri) < rd.attri.size());
        const int value = rd.attri[node->attri];
        assert(value >= 0 && static_cast<std::size_t>(value) < node->child.size());
        node = node->child[value];
    }
    rd.label = node->label;
    return node->label;
}
#endif


RandomForest.h(文件名大小写须与 #include "RandomForest.h" 一致)

// RandomForest.h -- a bagged ensemble of ID3 trees (see ID3Tree.h), each trained on a
// bootstrap sample and a random subset of the attributes; prediction is a majority vote.
#ifndef RANDOMFOREST
#define RANDOMFOREST
#include <time.h>
#include <cstdlib>
#include <vector>

class ID3Tree;
class Record;

class RandomForest
{
private:
    std::vector<ID3Tree*> forest;  // owned trees; deleted in ~RandomForest
    int treeNums = 0;              // number of trees to grow
    void boostrap();               // fills subDataSet by sampling wholeDataSet with replacement
    std::vector<Record> wholeDataSet;
    std::vector<Record> subDataSet;
    int sizeofwholeDataSet = 0;
    double ratioofsubDataset = 0.0;  // bootstrap sample size as a fraction of the whole dataset
    int attributeNums = 0;           // number of attributes
    int sub_attriNums = 0;           // number of attributes each tree may use
    int labelNums = 0;               // number of classes
    std::vector<int> ranom_select_feature();  // draws sub_attriNums distinct attribute indices
    std::vector<int> vote(Record rd);         // per-class vote counts over all trees

public:
    void load_dataset();
    void create_forest();
    int classify(std::vector<int> query);
    void set_paras();
    RandomForest();
    ~RandomForest();
    // The forest owns raw ID3Tree pointers: a copy would delete every tree twice.
    RandomForest(const RandomForest&) = delete;
    RandomForest& operator=(const RandomForest&) = delete;
};
#endif

randomforest.cpp

#include "stdafx.h"
#include "ID3Tree.h"
#include "RandomForest.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <string>
#include <vector>
using namespace std;

namespace
{
// Uniform integer in [0, n). The product is formed in double: the former int
// expression `n * rand()` overflowed (undefined behaviour) wherever RAND_MAX is 2^31-1.
int random_index(int n)
{
    return static_cast<int>(static_cast<double>(n) * std::rand() / (RAND_MAX + 1.0));
}

// Removes trailing spaces, tabs, '\r' and '\n' (the last token of a line carries the newline).
void trim_right(std::string& s)
{
    const std::string::size_type last = s.find_last_not_of(" \t\r\n");
    s.erase(last == std::string::npos ? 0 : last + 1);
}

// Returns the index of `token` in names[0..n), or -1 if it is not listed.
int lookup_code(const std::string& token, const char* const names[], int n)
{
    for (int i = 0; i < n; ++i)
        if (token == names[i])
            return i;
    return -1;
}

// Token -> integer code tables; the position in each table is the code.
const char* const kAgeNames[] = { "Youth", "Senior", "MiddleAged" };
const char* const kIncomeNames[] = { "Low", "Medium", "High" };
const char* const kStudentNames[] = { "No", "Yes" };
const char* const kCreditNames[] = { "Fair", "Excellent" };
const char* const kLabelNames[] = { "No", "Yes" };
}

/// Seeds rand() from the current time so each run grows a different forest.
RandomForest::RandomForest()
{
    srand(static_cast<unsigned>(time(NULL)));
}

RandomForest::~RandomForest()
{
    for (std::size_t i = 0; i < forest.size(); i++)
        delete forest[i];
}

/// Sets the forest hyper-parameters. Must run after load_dataset(), because it
/// records the size of the loaded dataset.
void RandomForest::set_paras()
{
    ratioofsubDataset = 0.5;
    sizeofwholeDataSet = static_cast<int>(wholeDataSet.size());
    attributeNums = Record::attributeNums;
    sub_attriNums = attributeNums - 1;
    treeNums = 100;
    labelNums = Record::labelNums;
}

/// Appends to ret_ the non-empty pieces of `str` separated by `sep`. Always returns 0.
/// An empty separator yields the whole string (it used to loop forever).
int split(const std::string& str, std::vector<std::string>& ret_, std::string sep = ",")
{
    if (str.empty())
        return 0;
    if (sep.empty())
    {
        ret_.push_back(str);
        return 0;
    }
    std::string::size_type pos_begin = str.find_first_not_of(sep);
    while (pos_begin != std::string::npos)
    {
        const std::string::size_type comma_pos = str.find(sep, pos_begin);
        std::string tmp;
        if (comma_pos != std::string::npos)
        {
            tmp = str.substr(pos_begin, comma_pos - pos_begin);
            pos_begin = comma_pos + sep.length();
        }
        else
        {
            tmp = str.substr(pos_begin);
            pos_begin = comma_pos;
        }
        if (!tmp.empty())
            ret_.push_back(tmp);
    }
    return 0;
}

// Dataset schema: 4 attributes, 2 classes; attributes take 3, 3, 2 and 2 values.
int Record::attributeNums = 4;
int Record::labelNums = 2;
int aa[4] = { 3, 3, 2, 2 };
vector<int> nums(aa, aa + 4);
vector<int> Record::each_attr_class = nums;

/// Reads "input.txt" and appends one Record per data row to wholeDataSet.
///
/// Expected content (a header row without digits, then one numbered row per sample):
///   Rid Age Income Student CreditRating BuysComputer
///   1 Youth High No Fair No
///   ...
///   14 Senior Medium No Excellent No
/// A record starts at a digit 1-9; its id digits are skipped and the text up to the
/// next digit (or EOF) is split on spaces into 5 tokens. Codes: Age Youth/Senior/
/// MiddleAged = 0/1/2, Income Low/Medium/High = 0/1/2, Student No/Yes = 0/1,
/// CreditRating Fair/Excellent = 0/1, label No/Yes = 0/1.
/// The raw record text is echoed to stdout. Malformed records are reported on stderr
/// and skipped. If the file cannot be opened, nothing is loaded.
void RandomForest::load_dataset()
{
    FILE* fp = fopen("input.txt", "r");
    if (fp == NULL)
    {
        fprintf(stderr, "load_dataset: cannot open input.txt\n");
        return;
    }
    int ch = getc(fp);  // int, not char: EOF must stay distinguishable from every byte
    while (ch != EOF)
    {
        if (ch < '1' || ch > '9')
        {
            ch = getc(fp);
            continue;
        }
        while (ch >= '0' && ch <= '9')  // skip the record id
            ch = getc(fp);
        if (ch == EOF)
            break;

        std::string str;
        while (ch != EOF && (ch < '0' || ch > '9'))
        {
            putchar(ch);
            str += static_cast<char>(ch);
            ch = getc(fp);
        }

        std::vector<std::string> pieces;
        split(str, pieces, std::string(" "));
        std::vector<std::string> re;
        for (std::size_t i = 0; i < pieces.size(); ++i)
        {
            trim_right(pieces[i]);  // tolerate "\n", "\r\n" and trailing blanks
            if (!pieces[i].empty())
                re.push_back(pieces[i]);
        }

        Record rd;
        bool ok = (re.size() == 5);
        if (ok)
        {
            const int codes[4] = {
                lookup_code(re[0], kAgeNames, 3),
                lookup_code(re[1], kIncomeNames, 3),
                lookup_code(re[2], kStudentNames, 2),
                lookup_code(re[3], kCreditNames, 2),
            };
            for (int k = 0; k < 4; ++k)
            {
                if (codes[k] < 0)
                    ok = false;
                rd.attri.push_back(codes[k]);
            }
            rd.label = lookup_code(re[4], kLabelNames, 2);
            ok = ok && rd.label >= 0;
        }
        if (!ok)
        {
            fprintf(stderr, "load_dataset: skipping malformed record:%s\n", str.c_str());
            continue;
        }
        wholeDataSet.push_back(rd);
    }
    fclose(fp);
}

/// Refills subDataSet with ratioofsubDataset * sizeofwholeDataSet records drawn from
/// wholeDataSet uniformly with replacement.
void RandomForest::boostrap()
{
    subDataSet.clear();
    for (int i = 0; i < ratioofsubDataset * sizeofwholeDataSet; i++)
        subDataSet.push_back(wholeDataSet[random_index(sizeofwholeDataSet)]);
}

/// Returns sub_attriNums distinct attribute indices from [0, attributeNums), drawn
/// uniformly without replacement.
std::vector<int> RandomForest::ranom_select_feature()
{
    assert(sub_attriNums <= attributeNums);
    std::vector<int> pool(static_cast<std::size_t>(attributeNums));
    for (int i = 0; i < attributeNums; i++)
        pool[i] = i;
    std::vector<int> picked;
    for (int i = 0; i < sub_attriNums; i++)
    {
        const int j = random_index(static_cast<int>(pool.size()));
        picked.push_back(pool[j]);
        pool.erase(pool.begin() + j);
    }
    return picked;
}

/// Grows treeNums ID3 trees, each on a fresh bootstrap sample and random feature subset.
void RandomForest::create_forest()
{
    for (int i = 0; i < treeNums; i++)
    {
        boostrap();
        ID3Tree* tree = new ID3Tree;
        tree->set_paras(attributeNums, labelNums, static_cast<int>(subDataSet.size()));
        // Named local: binding the temporary to create_root's parameter directly only
        // compiled as an MSVC extension when that parameter is a non-const reference.
        std::vector<int> features = ranom_select_feature();
        tree->create_root(features);
        tree->build_tree(tree->get_root(), subDataSet);
        forest.push_back(tree);
    }
    assert(forest.size() == static_cast<std::size_t>(treeNums));
}

/// Predicts the class of `query` (one integer code per attribute) by majority vote;
/// ties go to the lowest class index.
int RandomForest::classify(vector<int> query)
{
    assert(query.size() == static_cast<std::size_t>(Record::attributeNums));
    Record rd;
    rd.attri = query;
    std::vector<int> votes = vote(rd);
    return static_cast<int>(std::max_element(votes.begin(), votes.end()) - votes.begin());
}

/// Returns how many trees voted for each class. Only trees that actually exist are
/// consulted, and out-of-range predictions are ignored.
std::vector<int> RandomForest::vote(Record rd)
{
    std::vector<int> counts(static_cast<std::size_t>(labelNums), 0);
    for (std::size_t i = 0; i < forest.size(); i++)
    {
        const int predicted = forest[i]->classify(rd);
        if (predicted >= 0 && predicted < labelNums)
            counts[predicted]++;
    }
    return counts;
}

main.cpp

#include "stdafx.h"
#include "RandomForest.h"
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;

// Demo: train a random forest on input.txt and classify one query sample.
int _tmain(int argc, _TCHAR* argv[])
{
    RandomForest rf;
    rf.load_dataset();
    rf.set_paras();  // must follow load_dataset(): it records the loaded dataset size
    rf.create_forest();

    // Query: Age=Youth(0), Income=Low(0), Student=No(0), CreditRating=Fair(0).
    // The original notes expect the prediction BuysComputer = No.
    int aa[4] = { 0, 0, 0, 0 };
    vector<int> query(aa, aa + 4);
    int re = rf.classify(query);
    // Class codes follow load_dataset: 0 = No, 1 = Yes.
    cout << "BuysComputer: " << (re == 1 ? "Yes" : "No") << endl;
    system("pause");
    return 0;
}


测试数据内容如下
input.txt
Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No


0 0
原创粉丝点击