spark.mllib源码阅读-优化算法1-Gradient
来源:互联网 发布:全球淘宝下载 编辑:程序博客网 时间:2024/06/09 20:00
Spark中定义的损失函数及梯度,在看源码之前,先回顾一下机器学习中定义了哪些损失函数,毕竟梯度求解是为优化求解损失函数服务的。
监督学习问题是在假设空间F中选取模型f作为决策函数,对于给定的输入X,由f(X)给出相应的输出Y,这个输出的预测值f(X)与真实值Y可能一致也可能不一致,用一个损失函数(lossfunction)或代价函数(cost function)来度量预测错误的程度。损失函数是f(X)和Y的非负实值函数,记作L(Y, f(X)).
统计学习中常用的损失函数有以下几种:
(1) 0-1损失函数(0-1 loss function):
(2) 平方损失函数(quadraticloss function)
(3) 绝对损失函数(absolute lossfunction)
(4) 对数损失函数(logarithmicloss function) 或对数似然损失函数(log-likelihood loss function)
(5)间隔损失函数(hinge loss)
在不考虑过拟合的情况下,损失函数越小,模型就越好。
Spark中定义梯度和损失函数求解的类包括一个Gradient基类及其三个实现类:
Gradient
梯度计算的抽象类,定义了计算梯度值和损失函数值的compute函数:
def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val gradient = Vectors.zeros(weights.size) val loss = compute(data, label, weights, gradient) (gradient, loss)}
后面的梯度计算类都继承子Gradient类并实现compute函数。
LeastSquaresGradient
实现了最小二乘法进行线性回归的梯度计算方法。
其对compute函数进行的覆写override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val diff = dot(data, weights) - label val loss = diff * diff / 2.0 val gradient = data.copy scal(diff, gradient)//常数乘以向量 更新后的gradient即为梯度 gradient=(y - lable)* x (gradient, loss)}
使用场景:
1、 参数估计的方法是最小化误差的平方和,其它估计方法不适合用此梯度算子。
2、 Spark实现的是线性回归的梯度计算,非线性回归的梯度计算不适合使用此算子。
HingeGradient
实现了最大化分类间距的hinge loss进行参数估计的梯度下降方法,对compute函数进行的覆写:
class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { val gradient = data.copy scal(-labelScaled, gradient) (gradient, 1.0 - labelScaled * dotProduct) } else { (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0) } }
使用场景:
适用于利用最大化分类间隔思想来构建分类器,典型的使用如SVM。
LogisticGradient
使用对数似然损失函数对Logistic分类/回归进行参数估计的梯度下降方法。实现的代码比较长,在此就不贴了,在内部分了2分类和多分类两种情况进行计算。
- spark.mllib源码阅读-优化算法1-Gradient
- spark.mllib源码阅读-优化算法2-Updater
- spark.mllib源码阅读-优化算法3-Optimizer
- spark.mllib源码阅读-回归算法1-LinearRegression
- spark.mllib源码阅读-分类算法1-LogisticRegression
- spark.mllib源码阅读-聚类算法1-KMeans
- spark.mllib源码阅读-回归算法2-IsotonicRegression
- spark.mllib源码阅读-分类算法3-SVM
- spark.mllib源码阅读-分类算法4-DecisionTree
- spark.mllib源码阅读-分类算法5-GradientBoostedTrees
- spark 1.1 mllib中 NaiveBayes 源码阅读
- Spark MLlib Gaussian Mixture Clustering源码阅读
- spark.mllib源码阅读-bagging方法
- Apache Spark MLlib学习笔记(五)MLlib决策树类算法源码解析 1
- Spark MLlib SVM算法
- Spark MLlib FPGrowth算法
- Spark MLlib 算法
- Spark MLlib SVM算法
- hdu 5828
- Android进阶知识点(ViewPager,Fragment argument )
- LVS源码剖析-init
- javascript技术难点(三)之this、new、apply和call详解
- hdu 5818
- spark.mllib源码阅读-优化算法1-Gradient
- 回答阿里社招面试如何准备,顺便谈谈对于Java程序猿学习当中各个阶段的建议
- qos实现之tc流量控制
- matcaffe错误:Invalid MEX-file: caffe.mexa64,x86_64-linux-gnu/libstdc++.so.6
- 常见Android面试题及答案(详细整理)
- 快速排序-Java实现
- hdu 5800
- Python脚本在crontab定时任务中无法运行,但是对应脚本手动执行成功
- HTTP长连接和短连接原理浅析