
Machine Learning Notes 24: Decision Stump Principles and Python Implementation
Published: 2021-05-07 08:53:20
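Decision Stump
1. Overview
A decision stump (单层决策树) is a one-level decision tree: it classifies a sample by comparing a single feature against a given threshold.
[Figure: a decision stump that tests only petal length]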

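- From the data-structure point of view, it consists of a root node connected directly to its leaf nodes. When a decision stump is used as a classifier, each leaf is a final classification result.
- In practical terms, a decision stump decides the final class from a single test on a single attribute, for example judging whether a fruit is an apple by whether it is round. A single simple rule (or feature) does all the work.
- Clearly a decision stump can only serve as a weak base learning algorithm (it does only slightly better than random guessing, which is right 1/2 of the time, and the margin is very limited). It is commonly used as the base algorithm in ensemble learning rather than as a standalone classifier.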
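To make the "single rule" idea concrete, here is a minimal sketch of a decision stump written as a plain function. The threshold 2.5 and the class names are illustrative assumptions, not values taken from this article.

```python
def petal_stump(petal_length, threshold=2.5):
    """One test on one feature decides the class (threshold is a made-up value)."""
    # Root node: a single comparison; each branch leads directly to a leaf.
    if petal_length <= threshold:
        return "class A"  # left leaf (hypothetical label)
    return "class B"      # right leaf (hypothetical label)


print(petal_stump(1.4))
print(petal_stump(4.7))
```

Everything a stump can express is one such comparison, which is why it is only a weak learner.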
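2. Construction
Optimization objective: lowest error rate
Classification function stumpClassify():
Classifies the data by comparing one feature against a threshold. All samples on one side of the threshold are assigned to class -1, and the samples on the other side are assigned to class +1.
Stump builder buildStump(): finds the best decision stump on the dataset
Iterate over the candidate thresholds, compute the classification error for each one, and keep the split with the smallest error. That split is the best decision stump. Pseudocode: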
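- Set the minimum error rate minError to Inf (positive infinity)
- For each feature in the dataset (first loop):
  - For each step (second loop):
    - For each inequality direction (third loop):
      - Build a decision stump and test it on the weighted dataset
      - If its error rate is lower than minError, make this stump the best decision stump
- Return the best decision stump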
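The "error rate" tested in the innermost loop is a weighted error. As a sketch, in notation that does not appear in the original (the symbols $h$, $x_i$, $y_i$, $D_i$, $d$, $\theta$ and $\varepsilon$ are assumptions chosen to mirror the `D.T * errArr` computation in the code), it can be written as:

```latex
\varepsilon(h) = \sum_{i=1}^{m} D_i \,\mathbb{1}\!\left[ h(x_i) \neq y_i \right],
\qquad
h(x) =
\begin{cases}
-1, & x_{d} \le \theta \\
+1, & x_{d} > \theta
\end{cases}
\quad \text{(the 'lt' direction; 'gt' swaps the two labels)}
```

With uniform weights $D_i = 1/m$, this reduces to the fraction of misclassified samples.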
    import numpy as np

    # Decision stump classification function: compare one feature against a threshold
    def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
        retArray = np.ones((np.shape(dataMatrix)[0], 1))  # initialize every prediction to +1
        if threshIneq == 'lt':
            # 'lt': samples less than or equal to the threshold go to class -1
            retArray[dataMatrix[:, dimen] <= threshVal] = -1.0
        else:
            # 'gt': samples greater than the threshold go to class -1
            retArray[dataMatrix[:, dimen] > threshVal] = -1.0
        return retArray

    # Find the best decision stump on the dataset
    def buildStump(dataArr, classLabels, D):
        """
        Parameters:
            dataArr - data matrix
            classLabels - class labels
            D - sample weights
        Returns:
            bestStump - information about the best decision stump
            minError - minimum error
            bestClasEst - best classification result
        """
        dataMatrix = np.asmatrix(dataArr)
        labelMat = np.asmatrix(classLabels).T
        m, n = np.shape(dataMatrix)
        # Initialize the number of steps, the best stump, and the best classification result
        numSteps = 10.0
        bestStump = {}
        bestClasEst = np.asmatrix(np.zeros((m, 1)))
        minError = float('inf')  # minimum error starts at positive infinity
        for i in range(n):  # loop over all features
            # Smallest and largest value of this feature
            rangeMin = dataMatrix[:, i].min()
            rangeMax = dataMatrix[:, i].max()
            stepSize = (rangeMax - rangeMin) / numSteps  # step size
            for j in range(-1, int(numSteps) + 1):
                for inequal in ['lt', 'gt']:  # try both directions: less than, greater than
                    threshVal = rangeMin + float(j) * stepSize  # candidate threshold
                    predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                    errArr = np.asmatrix(np.ones((m, 1)))  # error indicator, 1 = wrong
                    errArr[predictedVals == labelMat] = 0  # correctly classified samples get 0
                    weightedError = D.T * errArr  # weighted error (1x1 matrix)
                    print("split: dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f"
                          % (i, threshVal, inequal, weightedError[0, 0]))
                    if weightedError < minError:  # keep the split with the smallest error
                        minError = weightedError
                        bestClasEst = predictedVals.copy()
                        bestStump['dim'] = i
                        bestStump['thresh'] = threshVal
                        bestStump['ineq'] = inequal
        return bestStump, minError, bestClasEst
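As a quick check of the two inequality directions, the sketch below re-implements the same rule on a plain ndarray. The name `stump_classify` and the toy feature values are assumptions made for illustration; the article's own function works on `np.matrix` input and returns a column.

```python
import numpy as np


def stump_classify(X, dim, thresh, ineq):
    """ndarray version of the stump rule: returns a 1-D array of +1/-1 labels."""
    pred = np.ones(X.shape[0])
    if ineq == 'lt':
        pred[X[:, dim] <= thresh] = -1.0  # at or below the threshold -> -1
    else:
        pred[X[:, dim] > thresh] = -1.0   # above the threshold -> -1
    return pred


# Toy feature column (made-up values, a single feature)
X = np.array([[0.5], [1.0], [1.5], [2.0]])

lt_pred = stump_classify(X, 0, 1.0, 'lt')  # low side is -1: -1, -1, +1, +1
gt_pred = stump_classify(X, 0, 1.0, 'gt')  # high side is -1: +1, +1, -1, -1
print(lt_pred)
print(gt_pred)
```

For the same threshold the two directions always give opposite labels, so each 'lt'/'gt' pair of weighted errors sums to 1 whenever the weights sum to 1.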
A simple example:
1. Create the dataset

    # Create the dataset for the decision stump
    def loadSimpData():
        datMat = np.matrix([[1. , 2.1],
                            [1.5, 1.6],
                            [1.3, 1. ],
                            [1. , 1. ],
                            [2. , 1. ]])
        classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
        return datMat, classLabels

2. Build a decision stump on the dataset above

    dataArr, classLabels = loadSimpData()
    D = np.asmatrix(np.ones((5, 1)) / 5)  # uniform sample weights
    bestStump, minError, bestClasEst = buildStump(dataArr, classLabels, D)
    print('bestStump:\n', bestStump)
    print('minError:\n', minError)
    print('bestClasEst:\n', bestClasEst)

Output:
    split: dim 0, threshVal 0.90, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 0.90, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.00, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.00, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.10, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.10, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.20, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.20, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.30, thresh ineqal: lt, the weighted error is 0.200
    split: dim 0, threshVal 1.30, thresh ineqal: gt, the weighted error is 0.800
    split: dim 0, threshVal 1.40, thresh ineqal: lt, the weighted error is 0.200
    split: dim 0, threshVal 1.40, thresh ineqal: gt, the weighted error is 0.800
    split: dim 0, threshVal 1.50, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.50, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.60, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.60, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.70, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.70, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.80, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.80, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 1.90, thresh ineqal: lt, the weighted error is 0.400
    split: dim 0, threshVal 1.90, thresh ineqal: gt, the weighted error is 0.600
    split: dim 0, threshVal 2.00, thresh ineqal: lt, the weighted error is 0.600
    split: dim 0, threshVal 2.00, thresh ineqal: gt, the weighted error is 0.400
    split: dim 1, threshVal 0.89, thresh ineqal: lt, the weighted error is 0.400
    split: dim 1, threshVal 0.89, thresh ineqal: gt, the weighted error is 0.600
    split: dim 1, threshVal 1.00, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.00, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.11, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.11, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.22, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.22, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.33, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.33, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.44, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.44, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.55, thresh ineqal: lt, the weighted error is 0.200
    split: dim 1, threshVal 1.55, thresh ineqal: gt, the weighted error is 0.800
    split: dim 1, threshVal 1.66, thresh ineqal: lt, the weighted error is 0.400
    split: dim 1, threshVal 1.66, thresh ineqal: gt, the weighted error is 0.600
    split: dim 1, threshVal 1.77, thresh ineqal: lt, the weighted error is 0.400
    split: dim 1, threshVal 1.77, thresh ineqal: gt, the weighted error is 0.600
    split: dim 1, threshVal 1.88, thresh ineqal: lt, the weighted error is 0.400
    split: dim 1, threshVal 1.88, thresh ineqal: gt, the weighted error is 0.600
    split: dim 1, threshVal 1.99, thresh ineqal: lt, the weighted error is 0.400
    split: dim 1, threshVal 1.99, thresh ineqal: gt, the weighted error is 0.600
    split: dim 1, threshVal 2.10, thresh ineqal: lt, the weighted error is 0.600
    split: dim 1, threshVal 2.10, thresh ineqal: gt, the weighted error is 0.400
    bestStump:
     {'dim': 0, 'thresh': 1.3, 'ineq': 'lt'}
    minError:
     [[0.2]]
    bestClasEst:
     [[-1.]
     [ 1.]
     [-1.]
     [-1.]
     [ 1.]]
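After the search, the best decision stump found has a minimum classification error of 0.2. In other words, on this dataset with uniform sample weights, no decision stump of this form can do better than an error of 0.2.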
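The claim that no stump beats 0.2 can be checked without relying on the 10-step grid. The sketch below is an assumption-laden rewrite rather than the article's method: it uses plain ndarrays and tries one threshold below the minimum plus one at every distinct feature value, which covers every partition that a `<=` test can produce. The helper name `candidate_thresholds` is hypothetical.

```python
import numpy as np

# Same data and labels as loadSimpData(), as plain ndarrays
X = np.array([[1.0, 2.1], [1.5, 1.6], [1.3, 1.0], [1.0, 1.0], [2.0, 1.0]])
y = np.array([1.0, 1.0, -1.0, -1.0, 1.0])
D = np.full(len(y), 1.0 / len(y))  # uniform weights, as in the example


def candidate_thresholds(col):
    """One threshold below all values, plus one at each distinct value."""
    vals = np.unique(col)
    return np.concatenate(([vals[0] - 1.0], vals))


best = (np.inf, None)
for dim in range(X.shape[1]):
    for thresh in candidate_thresholds(X[:, dim]):
        for ineq in ('lt', 'gt'):
            below = X[:, dim] <= thresh
            # 'lt' sends the low side to -1; 'gt' sends the high side to -1
            if ineq == 'lt':
                pred = np.where(below, -1.0, 1.0)
            else:
                pred = np.where(below, 1.0, -1.0)
            err = float(np.sum(D[pred != y]))  # weighted error
            if err < best[0]:
                best = (err, (dim, float(thresh), ineq))

best_err, best_rule = best
print(round(best_err, 3), best_rule)
```

Under these assumptions the exhaustive search also bottoms out at 0.2 and picks the same rule (feature 0, threshold 1.3, 'lt'), because ties are broken by the first split found.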
Complete code:
    import numpy as np

    # Create the dataset for the decision stump
    def loadSimpData():
        """
        Parameters:
            none
        Returns:
            datMat - data matrix
            classLabels - class labels
        """
        datMat = np.matrix([[1. , 2.1],
                            [1.5, 1.6],
                            [1.3, 1. ],
                            [1. , 1. ],
                            [2. , 1. ]])
        classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
        return datMat, classLabels

    # Decision stump classification function: compare one feature against a threshold
    def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
        """
        Parameters:
            dataMatrix - data matrix
            dimen - column index, i.e. which feature to test
            threshVal - threshold
            threshIneq - direction flag: 'lt' sends values less than or equal to
                         the threshold to -1; 'gt' sends values greater than the
                         threshold to -1
        Returns:
            retArray - classification result
        """
        retArray = np.ones((np.shape(dataMatrix)[0], 1))  # initialize every prediction to +1
        if threshIneq == 'lt':
            retArray[dataMatrix[:, dimen] <= threshVal] = -1.0  # at or below the threshold -> -1
        else:
            retArray[dataMatrix[:, dimen] > threshVal] = -1.0  # above the threshold -> -1
        return retArray

    # Find the best decision stump on the dataset
    def buildStump(dataArr, classLabels, D):
        """
        Parameters:
            dataArr - data matrix
            classLabels - class labels
            D - sample weights
        Returns:
            bestStump - information about the best decision stump
            minError - minimum error
            bestClasEst - best classification result
        """
        dataMatrix = np.asmatrix(dataArr)
        labelMat = np.asmatrix(classLabels).T
        m, n = np.shape(dataMatrix)
        # Initialize the number of steps, the best stump, and the best classification result
        numSteps = 10.0
        bestStump = {}
        bestClasEst = np.asmatrix(np.zeros((m, 1)))
        minError = float('inf')  # minimum error starts at positive infinity
        for i in range(n):  # loop over all features
            # Smallest and largest value of this feature
            rangeMin = dataMatrix[:, i].min()
            rangeMax = dataMatrix[:, i].max()
            stepSize = (rangeMax - rangeMin) / numSteps  # step size
            for j in range(-1, int(numSteps) + 1):
                for inequal in ['lt', 'gt']:  # try both directions: less than, greater than
                    threshVal = rangeMin + float(j) * stepSize  # candidate threshold
                    predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                    errArr = np.asmatrix(np.ones((m, 1)))  # error indicator, 1 = wrong
                    errArr[predictedVals == labelMat] = 0  # correctly classified samples get 0
                    weightedError = D.T * errArr  # weighted error (1x1 matrix)
                    print("split: dim %d, threshVal %.2f, thresh ineqal: %s, the weighted error is %.3f"
                          % (i, threshVal, inequal, weightedError[0, 0]))
                    if weightedError < minError:  # keep the split with the smallest error
                        minError = weightedError
                        bestClasEst = predictedVals.copy()
                        bestStump['dim'] = i
                        bestStump['thresh'] = threshVal
                        bestStump['ineq'] = inequal
        return bestStump, minError, bestClasEst

    if __name__ == '__main__':
        dataArr, classLabels = loadSimpData()
        D = np.asmatrix(np.ones((5, 1)) / 5)  # uniform sample weights
        bestStump, minError, bestClasEst = buildStump(dataArr, classLabels, D)
        print('bestStump:\n', bestStump)
        print('minError:\n', minError)
        print('bestClasEst:\n', bestClasEst)
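buildStump() accepts sample weights D, but the example only ever passes uniform weights. As a rough sketch of why the weights matter, the snippet below compares the two stumps that tie at 0.2 under a skewed weight vector. The `skewed` weights and the helper `weighted_error` are hypothetical; the two prediction vectors are written out by hand from the rules named in the comments.

```python
import numpy as np

y = np.array([1.0, 1.0, -1.0, -1.0, 1.0])        # labels from loadSimpData()
pred_a = np.array([-1.0, 1.0, -1.0, -1.0, 1.0])  # stump: dim 0, thresh 1.3, 'lt'
pred_b = np.array([1.0, 1.0, -1.0, -1.0, -1.0])  # stump: dim 1, thresh 1.0, 'lt'


def weighted_error(pred, y, D):
    """Sum of the weights of the misclassified samples."""
    return float(np.sum(D[pred != y]))


uniform = np.full(5, 0.2)
skewed = np.array([0.6, 0.1, 0.1, 0.1, 0.1])  # hypothetical: sample 0 counts most

for name, D in [("uniform", uniform), ("skewed", skewed)]:
    print(name, weighted_error(pred_a, y, D), weighted_error(pred_b, y, D))
```

With uniform weights both stumps score 0.2. With the skewed weights the first stump, which gets sample 0 wrong, rises to 0.6, while the second stays at 0.1, so the search would prefer the second.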
References:
机器学习实战 (Machine Learning in Action)