
本文共 4807 字,大约阅读时间需要 16 分钟。
白话机器学习-逻辑斯蒂回归-理论+实践篇
@(2018年例会)
概述
前面讲述了线性回归,线性回归的模型 y=wT+b 。模型的预测值逼近真实标记y。那么可否令模型的预测值逼近真实标记y的衍生物呢。比如说模型的预测值逼近真实标记的对数函数。下面引入逻辑回归的知识。
转换函数
我们需要一个单调可微函数将分类任务的真实标记y与线性回归模型的预测值联系起来,所以需要一个转换函数将线性模型的值与实际的预测值关联起来。
考虑二分类问题,其输出标记是y属于{0,1},而线性模型产生的预测值是 z=wT+b 是实值,那么我们需要将这个实值转化成0/1值,最理想的函数是单位阶跃函数。
单位阶跃函数
单位阶跃函数(unit-step function),如下图,如果预测值大于零则判断为正例;如果预测值小于零则判断为反例;为零的则任意判断。如下图所示。
sigmoid function
从图中可以看出,单位阶跃函数不连续因此不适合用来计算。这里我们引入sigmoid函数,进行计算。
y=11+e−z将z值转化为一个接近0或1的y值,并且其输出值在z=0的附近变化很陡。那么我们现在的模型变化成
几率与对数几率
几率:如果将y作为正例的可能性,1-y作为负例的可能性,那么两者的比值 y1−y 称为几率,反应了x作为正例的相对可能性。则根据sigmoid函数可得。
lny1−y 称为对数几率;
由此可以看出, y=11+e−(wT+b) 实际上是用线性模型的预测结果去逼近真实标记的对数几率,因此,其对应的模型称为“对数几率回归”
下面介绍损失函数以及计算方法。
损失函数
因为: lny1−y=wT+b 。所以
我们采用极大似然估计法进行求解,由于是二分类问题,所以符合概率里面的0-1分布,所以似然函数为
令 p(y=1|x)=e(wT+b)1+e(wT+b)=f(x) , p(y=0|x)=11+e(wT+b)=1−f(x)
对数似然函数为:
求这个函数的最大值,加个负号,求最小值。运用前面章节介绍的梯度下降和牛顿法都可以求解,这里不再赘述。
代码实战
这里讲述通过梯度上升法进行求解,首先,我们需要界定、分析下这个问题,那么我们需要什么样的信息呢?
- 输入信息的变量
- 样本:包括特征与分类标记、正例与负例
- 回归系数的初始化;
- 步长的计算;
- 损失函数的已经确定;
- 损失函数的梯度的计算;
- 通过损失函数的梯度和步长确实每次迭代;
- 迭代的停止条件;
输入数据(确定上面的变量)
- 样本信息:包括特征、分类标记(从机器学习实战中提取,后续将贴出;
- 回归函数的系数都初始化为1.0;
- 为简单起见,设置步长alpha = 0.001;
- 损失函数,上面已经介绍:
- 损失函数的梯度:
[∂f2∂xi∂yj]nxn
l(w)=lnL(w)=∑i=1n[yi(wxi)−ln(1+ewxi)]
迭代的停止条件:maxCycles = 500,迭代500次。
- 代码
def loadDataSet():
dataMat = []; labelMat = []
fr = open('testSet.txt')
for line in fr.readlines():
lineArr = line.strip().split()
dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])
labelMat.append(int(lineArr[2]))
return dataMat,labelMatdef sigmoid(inX):
return 1.0/(1+exp(-inX))def gradAscent(dataMatIn, classLabels):
dataMatrix = mat(dataMatIn)
#convert to NumPy matrix
labelMat = mat(classLabels).transpose() #convert to NumPy matrix
m,n = shape(dataMatrix)
alpha = 0.001
maxCycles = 500
weights = ones((n,1))
for k in range(maxCycles):
#heavy on matrix operations
h = sigmoid(dataMatrix*weights)
#matrix mult
error = (labelMat - h)
#vector subtraction
weights = weights + alpha * dataMatrix.transpose()* error # 这里就是用梯度迭代修改参数的值
return weights
- 样本数据
-0.017612 14.053064 0-1.395634 4.662541
1-0.752157 6.538620
0-1.322371 7.152853
00.423363
11.054677 00.406704
7.067335
10.667394
12.741452 0-2.460150 6.866805
10.569411
9.548755
0-0.026632 10.427743 00.850433
6.920334
11.347183
13.175500 01.176813
3.167020
1-1.781871 9.097953
0-0.566606 5.749003
10.931635
1.589505
1-0.024205 6.151823
1-0.036453 2.690988
1-0.196949 0.444165
11.014459
5.754399
11.985298
3.230619
1-1.693453 -0.557540 1-0.576525 11.778922 0-0.346811 -1.678730 1-2.124484 2.672471
11.217916
9.597015
0-0.733928 9.098687
0-3.642001 -1.618087 10.315985
3.523953
11.416614
9.619232
0-0.386323 3.989286
10.556921
8.294984
11.224863
11.587360 0-1.347803 -2.406051 11.196604
4.951851
10.275221
9.543647
00.470575
9.332488
0-1.889567 9.542662
0-1.527893 12.150579 0-1.185247 11.309318 0-0.445678 3.297303
11.042222
6.105155
1-0.618787 10.320986 01.152083
0.548467
10.828534
2.676045
1-1.237728 10.549033 0-0.683565 -2.166125 10.229456
5.921938
1-0.959885 11.555336 00.492911
10.993324 00.184992
8.721488
0-0.355715 10.325976 0-0.397822 8.058397
00.824839
13.730343 01.507278
5.027866
10.099671
6.835839
1-0.344008 10.717485 01.785928
7.718645
1-0.918801 11.560217 0-0.364009 4.747300
1-0.841722 4.119083
10.490426
1.960539
1-0.007194 9.075792
00.356107
12.447863 00.342578
12.281162 0-0.810823 -1.466018 12.530777
6.476801
11.296683
11.607559 00.475487
12.040035 0-0.783277 11.009725 00.074798
11.023650 0-1.337472 0.468339
1-0.102781 13.763651 0-0.147324 2.874846
10.518389
9.887035
01.015399
7.571882
0-1.658086 -0.027255 11.319944
2.171228
12.056216
5.019981
1-0.851633 4.375691
1-1.510047 6.061992
0-1.076637 -3.181888 11.821096
10.283990 03.010150
8.401766
1-1.099458 1.688274
1-0.834872 -1.733869 1-0.846637 3.849075
11.400102
12.628781 01.752842
5.468166
10.078557
0.059736
10.089392
-0.715300 11.825662
12.693808 00.197445
9.744638
00.126117
0.922311
1-0.679797 1.220530
10.677983
2.556666
10.761349
10.693862 0-2.168791 0.143632
11.388610
9.341997
00.317029
14.739025 0
- 获得结果
>>> import logRegres>>> param_mat, label_mat = logRegres.loadDataSet()>>> >>> logRegres.gradAscent(param_mat, label_mat)matrix([[ 4.12414349],
[ 0.48007329],
[-0.6168482 ]])>>>
这里每次迭代采用的是全部的样本,计算量较大,可以修改为每次迭代选取随机的样本。
参考
- 机器学习实战
转载地址:https://blog.csdn.net/qq_22054285/article/details/79116567 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
关于作者
