李航《统计学习方法》第六章——用Python实现逻辑斯谛回归(MNIST数据集)
来源:互联网 发布:matlab高维数据可视化 编辑:程序博客网 时间:2024/06/09 19:55
相关文章:
- 李航《统计学习方法》第二章——用Python实现感知器模型(MNIST数据集)
- 李航《统计学习方法》第三章——用Python实现KNN算法(MNIST数据集)
- 李航《统计学习方法》第四章——用Python实现朴素贝叶斯分类器(MNIST数据集)
- 李航《统计学习方法》第五章——用Python实现决策树(MNIST数据集)
- 李航《统计学习方法》第六章——用Python实现最大熵模型(MNIST数据集)
- 李航《统计学习方法》第七章——用Python实现支持向量机模型(伪造数据集)
- 李航《统计学习方法》第八章——用Python+Cpp实现AdaBoost算法(MNIST数据集)
第六章有两个算法,分别是逻辑斯谛回归与最大熵模型
逻辑斯谛回归可以看成最大熵模型的一种特例,最大熵模型的代码已经写好了,小数据集下测试正常,但放到MNIST数据集下会产生指数爆炸的问题,我这几天再看看如何解决吧!
逻辑斯谛回归
首先贴一下书上的算法
算法
可以看到这个算法与感知器算法贼像
感知器算法
当 Y = 1 时,
wT⋅x 尽量等于 +1
当 Y = 0 时,wT⋅x 尽量等于 -1
而罗吉斯蒂算法
当 Y = 1 时,
wT⋅x 尽量等于+∞
当 Y = 0 时,wT⋅x 尽量等于−∞
根据我仅存一丢丢的集合论知识,好像(0,1)与(
P.S. 如果理解有错,希望能在评论区告诉我!
参数估计
书上没有写出完整的参数估计算法,但给出了其对数似然函数,经过简单的证明可以得出该函数是单调上升,且其极限为0
因此,我们可以将-L(w)作为损失函数,用随机梯度下降的方法求解
每次随机选取一个误分类点,用上述梯度对w进行更新即可,注意由于梯度中包含指数操作,所以需要一个很小的学习率。
数据集
数据集和感知器那个博文用的是同样的数据集。
数据地址:https://github.com/WenDesi/lihang_book_algorithm/blob/master/data/train_binary.csv
特征
将整个图作为特征
代码
代码已放到Github上,这边也贴出来,因为算法还挺简单的,所以没有加什么注释,
# encoding=utf-8# @Author: WenDesi# @Date: 08-11-16# @Email: wendesi@foxmail.com# @Last modified by: WenDesi# @Last modified time: 08-11-16import timeimport mathimport randomimport pandas as pdfrom sklearn.cross_validation import train_test_splitfrom sklearn.metrics import accuracy_scoreclass LogisticRegression(object): def __init__(self): self.learning_step = 0.00001 self.max_iteration = 5000 def predict_(self,x): wx = sum([self.w[j] * x[j] for j in xrange(len(self.w))]) exp_wx = math.exp(wx) predict1 = exp_wx / (1 + exp_wx) predict0 = 1 / (1 + exp_wx) if predict1 > predict0: return 1 else: return 0 def train(self,features, labels): self.w = [0.0] * (len(features[0]) + 1) correct_count = 0 time = 0 while time < self.max_iteration: index = random.randint(0, len(labels) - 1) x = list(features[index]) x.append(1.0) y = labels[index] if y == self.predict_(x): correct_count += 1 if correct_count > self.max_iteration: break continue # print 'iterater times %d' % time time += 1 correct_count = 0 wx = sum([self.w[i] * x[i] for i in xrange(len(self.w))]) exp_wx = math.exp(wx) for i in xrange(len(self.w)): self.w[i] -= self.learning_step * \ (-y * x[i] + float(x[i] * exp_wx) / float(1 + exp_wx)) def predict(self,features): labels = [] for feature in features: x = list(feature) x.append(1) labels.append(self.predict_(x)) return labelsif __name__ == "__main__": print 'Start read data' time_1 = time.time() raw_data = pd.read_csv('../data/train_binary.csv',header=0) data = raw_data.values imgs = data[0::,1::] labels = data[::,0] # 选取 2/3 数据作为训练集, 1/3 数据作为测试集 train_features, test_features, train_labels, test_labels = train_test_split(imgs, labels, test_size=0.33, random_state=23323) time_2 = time.time() print 'read data cost ',time_2 - time_1,' second','\n' print 'Start training' lr = LogisticRegression() lr.train(train_features, train_labels) time_3 = time.time() print 'training cost ',time_3 - time_2,' second','\n' print 'Start predicting' test_predict = lr.predict(test_features) time_4 = time.time() print 'predicting cost ',time_4 - time_3,' second','\n' score = accuracy_score(test_labels,test_predict) print "The accruacy socre is ", score
运行结果
训练速度挺慢的,但正确率还行
对比实验
感知器与逻辑斯谛实在是太像了,必须要比一比
我们依然用MNIST数据集,进行十次实验,下图是实验结果,蓝色是感知器,橙色是逻辑斯谛。
可以看出逻辑斯谛回归模型正确率上还是优于感知器模型的,原因可能就像我之前说的,是浮点数精度等问题
0 0
- 李航《统计学习方法》第六章——用Python实现逻辑斯谛回归(MNIST数据集)
- 李航《统计学习方法》第六章——用Python实现最大熵模型(MNIST数据集)
- 李航《统计学习方法》第三章——用Python实现KNN算法(MNIST数据集)
- 李航《统计学习方法》第五章——用Python实现决策树(MNIST数据集)
- 李航《统计学习方法》第二章——用Python实现感知器模型(MNIST数据集)
- 李航《统计学习方法》第四章——用Python实现朴素贝叶斯分类器(MNIST数据集)
- 《统计学习方法》 逻辑斯谛回归(logistic regression) Python实现
- 逻辑斯谛回归(Logistic regression)—《统计学习方法》
- 统计学习方法——逻辑斯蒂回归模型
- 《统计学习方法》1——逻辑斯蒂回归
- 《统计学习方法》第六章逻辑斯蒂回归与最大熵模型学习笔记
- 李航《统计学习方法》第七章——用Python实现支持向量机模型(伪造数据集)
- 统计学习方法 第6章 逻辑斯谛回归与最大熵模型(1)
- 统计学习方法 第6章 逻辑斯谛回归与最大熵模型(2)
- 最小二乘回归树Python实现——统计学习方法第五章课后题
- 统计学习方法-Logistic(逻辑斯蒂)回归
- TensorFlow学习笔记(3)--实现Softmax逻辑回归识别手写数字(MNIST数据集)
- 《统计学习方法》笔记(6):逻辑斯谛回归&最大熵模型
- java循环语句详解
- 【jzoj4878】【时空传送】【最短路】
- VirtualBox + CentOS7 安装PHP运行环境(一)
- 乐视视频转屏问题
- 淘宝一键搬家到微店
- 李航《统计学习方法》第六章——用Python实现逻辑斯谛回归(MNIST数据集)
- c++ stl---------set
- 带负数高精加
- 获取系统时间
- 指针学习笔记(上)
- 三星手机照相或选择图片,图片旋转截屏图片不旋转问题
- LightOJ 1220 Mysterious Bacteria(唯一分解定理+暴力)
- 利用jclasslib修改java编译后的.class文件
- Python中的__name__和__main__含义详解