spark.mllib源码阅读-优化算法1-Gradient

来源:互联网 发布:全球淘宝下载 编辑:程序博客网 时间:2024/06/09 20:00

Spark中定义的损失函数及梯度,在看源码之前,先回顾一下机器学习中定义了哪些损失函数,毕竟梯度求解是为优化求解损失函数服务的。

监督学习问题是在假设空间F中选取模型f作为决策函数,对于给定的输入X,由f(X)给出相应的输出Y,这个输出的预测值f(X)与真实值Y可能一致也可能不一致,用一个损失函数(lossfunction)或代价函数(cost function)来度量预测错误的程度。损失函数是f(X)Y的非负实值函数,记作L(Y, f(X)).

统计学习中常用的损失函数有以下几种:

(1) 0-1损失函数(0-1 loss function):


(2) 平方损失函数(quadraticloss function)


(3) 绝对损失函数(absolute lossfunction)


(4) 对数损失函数(logarithmicloss function) 或对数似然损失函数(log-likelihood loss function)

 

(5)间隔损失函数(hinge loss)


在不考虑过拟合的情况下,损失函数越小,模型就越好。

 

Spark中定义梯度和损失函数求解的类包括一个Gradient基类及其三个实现类:


Gradient

梯度计算的抽象类,定义了计算梯度值和损失函数值的compute函数:

def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {  val gradient = Vectors.zeros(weights.size)  val loss = compute(data, label, weights, gradient)  (gradient, loss)}

后面的梯度计算类都继承子Gradient类并实现compute函数。

LeastSquaresGradient

实现了最小二乘法进行线性回归的梯度计算方法。

其对compute函数进行的覆写

override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {  val diff = dot(data, weights) - label  val loss = diff * diff / 2.0  val gradient = data.copy  scal(diff, gradient)//常数乘以向量 更新后的gradient即为梯度 gradient=(y - lable)* x  (gradient, loss)}

使用场景:

1、 参数估计的方法是最小化误差的平方和,其它估计方法不适合用此梯度算子。

2、 Spark实现的是线性回归的梯度计算,非线性回归的梯度计算不适合使用此算子。


HingeGradient

实现了最大化分类间距的hinge loss进行参数估计的梯度下降方法,compute函数进行的覆写:

class HingeGradient extends Gradient {  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {    val dotProduct = dot(data, weights)    // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))    // Therefore the gradient is -(2y - 1)*x    val labelScaled = 2 * label - 1.0    if (1.0 > labelScaled * dotProduct) {      val gradient = data.copy      scal(-labelScaled, gradient)      (gradient, 1.0 - labelScaled * dotProduct)    } else {      (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0)    }  }

使用场景:

适用于利用最大化分类间隔思想来构建分类器,典型的使用如SVM。

 LogisticGradient

使用对数似然损失函数对Logistic分类/回归进行参数估计的梯度下降方法。实现的代码比较长,在此就不贴了,在内部分了2分类和多分类两种情况进行计算。






1 0
原创粉丝点击